LLVM 22.0.0git
SVEIntrinsicOpts.cpp
Go to the documentation of this file.
//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Performs general IR level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - replaces fixed-length predicate loads and stores (<N x i8> accesses that
//   are reinterpreted to/from svbool_t via llvm.vector.insert /
//   llvm.vector.extract and a bitcast) with scalable predicate loads and
//   stores, when the function's vscale_range attribute fixes vscale exactly.
//
//===----------------------------------------------------------------------===//
22
#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include <optional>
38
39using namespace llvm;
40using namespace llvm::PatternMatch;
41
42#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
43
44namespace {
45struct SVEIntrinsicOpts : public ModulePass {
46 static char ID; // Pass identification, replacement for typeid
47 SVEIntrinsicOpts() : ModulePass(ID) {}
48
49 bool runOnModule(Module &M) override;
50 void getAnalysisUsage(AnalysisUsage &AU) const override;
51
52private:
53 bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
55 bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
56 bool optimizePredicateStore(Instruction *I);
57 bool optimizePredicateLoad(Instruction *I);
58
59 bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);
60
61 /// Operates at the function-scope. I.e., optimizations are applied local to
62 /// the functions themselves.
63 bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
64};
65} // end anonymous namespace
66
67void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
69 AU.setPreservesCFG();
70}
71
72char SVEIntrinsicOpts::ID = 0;
73static const char *name = "SVE intrinsics optimizations";
74INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
76INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
77
79 return new SVEIntrinsicOpts();
80}
81
82/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
83/// ptrue will introduce zeroing. For example:
84///
85/// %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
86/// %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
87/// %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
88///
89/// %1 is promoted, because it is converted:
90///
91/// <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
92///
93/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
94static bool isPTruePromoted(IntrinsicInst *PTrue) {
95 // Find all users of this intrinsic that are calls to convert-to-svbool
96 // reinterpret intrinsics.
98 for (User *User : PTrue->users()) {
99 if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
100 ConvertToUses.push_back(cast<IntrinsicInst>(User));
101 }
102 }
103
104 // If no such calls were found, this is ptrue is not promoted.
105 if (ConvertToUses.empty())
106 return false;
107
108 // Otherwise, try to find users of the convert-to-svbool intrinsics that are
109 // calls to the convert-from-svbool intrinsic, and would result in some lanes
110 // being zeroed.
111 const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
112 for (IntrinsicInst *ConvertToUse : ConvertToUses) {
113 for (User *User : ConvertToUse->users()) {
114 auto *IntrUser = dyn_cast<IntrinsicInst>(User);
115 if (IntrUser && IntrUser->getIntrinsicID() ==
116 Intrinsic::aarch64_sve_convert_from_svbool) {
117 const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
118
119 // Would some lanes become zeroed by the conversion?
120 if (IntrUserVTy->getElementCount().getKnownMinValue() >
121 PTrueVTy->getElementCount().getKnownMinValue())
122 // This is a promoted ptrue.
123 return true;
124 }
125 }
126 }
127
128 // If no matching calls were found, this is not a promoted ptrue.
129 return false;
130}
131
132/// Attempts to coalesce ptrues in a basic block.
133bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
135 if (PTrues.size() <= 1)
136 return false;
137
138 // Find the ptrue with the most lanes.
139 auto *MostEncompassingPTrue =
140 *llvm::max_element(PTrues, [](auto *PTrue1, auto *PTrue2) {
141 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
142 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
143 return PTrue1VTy->getElementCount().getKnownMinValue() <
144 PTrue2VTy->getElementCount().getKnownMinValue();
145 });
146
147 // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving
148 // behind only the ptrues to be coalesced.
149 PTrues.remove(MostEncompassingPTrue);
151
152 // Hoist MostEncompassingPTrue to the start of the basic block. It is always
153 // safe to do this, since ptrue intrinsic calls are guaranteed to have no
154 // predecessors.
155 MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());
156
157 LLVMContext &Ctx = BB.getContext();
158 IRBuilder<> Builder(Ctx);
159 Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
160
161 auto *MostEncompassingPTrueVTy =
162 cast<VectorType>(MostEncompassingPTrue->getType());
163 auto *ConvertToSVBool = Builder.CreateIntrinsic(
164 Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
165 {MostEncompassingPTrue});
166
167 bool ConvertFromCreated = false;
168 for (auto *PTrue : PTrues) {
169 auto *PTrueVTy = cast<VectorType>(PTrue->getType());
170
171 // Only create the converts if the types are not already the same, otherwise
172 // just use the most encompassing ptrue.
173 if (MostEncompassingPTrueVTy != PTrueVTy) {
174 ConvertFromCreated = true;
175
176 Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
177 auto *ConvertFromSVBool =
178 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
179 {PTrueVTy}, {ConvertToSVBool});
180 PTrue->replaceAllUsesWith(ConvertFromSVBool);
181 } else
182 PTrue->replaceAllUsesWith(MostEncompassingPTrue);
183
184 PTrue->eraseFromParent();
185 }
186
187 // We never used the ConvertTo so remove it
188 if (!ConvertFromCreated)
189 ConvertToSVBool->eraseFromParent();
190
191 return true;
192}
193
194/// The goal of this function is to remove redundant calls to the SVE ptrue
195/// intrinsic in each basic block within the given functions.
196///
197/// SVE ptrues have two representations in LLVM IR:
198/// - a logical representation -- an arbitrary-width scalable vector of i1s,
199/// i.e. <vscale x N x i1>.
200/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
201/// scalable vector of i1s, i.e. <vscale x 16 x i1>.
202///
203/// The SVE ptrue intrinsic is used to create a logical representation of an SVE
204/// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
205/// P1 creates a logical SVE predicate that is at least as wide as the logical
206/// SVE predicate created by P2, then all of the bits that are true in the
207/// physical representation of P2 are necessarily also true in the physical
208/// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to
209/// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
210/// convert.{to,from}.svbool.
211///
212/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
213/// if they match the following conditions:
214///
215/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
216/// SV_ALL indicates that all bits of the predicate vector are to be set to
217/// true. SV_POW2 indicates that all bits of the predicate vector up to the
218/// largest power-of-two are to be set to true.
219/// - the result of the call to the intrinsic is not promoted to a wider
220/// predicate. In this case, keeping the extra ptrue leads to better codegen
221/// -- coalescing here would create an irreducible chain of SVE reinterprets
222/// via convert.{to,from}.svbool.
223///
224/// EXAMPLE:
225///
226/// %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
227/// ; Logical: <1, 1, 1, 1, 1, 1, 1, 1>
228/// ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
229/// ...
230///
231/// %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
232/// ; Logical: <1, 1, 1, 1>
233/// ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
234/// ...
235///
236/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
237///
238/// %1 = <vscale x 8 x i1> ptrue(i32 i31)
239/// %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
240/// %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
241///
242bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
244 bool Changed = false;
245
246 for (auto *F : Functions) {
247 for (auto &BB : *F) {
250
251 // For each basic block, collect the used ptrues and try to coalesce them.
252 for (Instruction &I : BB) {
253 if (I.use_empty())
254 continue;
255
256 auto *IntrI = dyn_cast<IntrinsicInst>(&I);
257 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
258 continue;
259
260 const auto PTruePattern =
261 cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
262
263 if (PTruePattern == AArch64SVEPredPattern::all)
264 SVAllPTrues.insert(IntrI);
265 if (PTruePattern == AArch64SVEPredPattern::pow2)
266 SVPow2PTrues.insert(IntrI);
267 }
268
269 Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
270 Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
271 }
272 }
273
274 return Changed;
275}
276
277// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
278// scalable stores as late as possible
279bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
280 auto *F = I->getFunction();
281 auto Attr = F->getFnAttribute(Attribute::VScaleRange);
282 if (!Attr.isValid())
283 return false;
284
285 unsigned MinVScale = Attr.getVScaleRangeMin();
286 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
287 // The transform needs to know the exact runtime length of scalable vectors
288 if (!MaxVScale || MinVScale != MaxVScale)
289 return false;
290
291 auto *PredType =
292 ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
293 auto *FixedPredType =
294 FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
295
296 // If we have a store..
297 auto *Store = dyn_cast<StoreInst>(I);
298 if (!Store || !Store->isSimple())
299 return false;
300
301 // ..that is storing a predicate vector sized worth of bits..
302 if (Store->getOperand(0)->getType() != FixedPredType)
303 return false;
304
305 // ..where the value stored comes from a vector extract..
306 auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
307 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
308 return false;
309
310 // ..that is extracting from index 0..
311 if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
312 return false;
313
314 // ..where the value being extract from comes from a bitcast
315 auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
316 if (!BitCast)
317 return false;
318
319 // ..and the bitcast is casting from predicate type
320 if (BitCast->getOperand(0)->getType() != PredType)
321 return false;
322
323 IRBuilder<> Builder(I->getContext());
324 Builder.SetInsertPoint(I);
325
326 Builder.CreateStore(BitCast->getOperand(0), Store->getPointerOperand());
327
328 Store->eraseFromParent();
329 if (IntrI->use_empty())
330 IntrI->eraseFromParent();
331 if (BitCast->use_empty())
332 BitCast->eraseFromParent();
333
334 return true;
335}
336
337// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
338// scalable loads as late as possible
339bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
340 auto *F = I->getFunction();
341 auto Attr = F->getFnAttribute(Attribute::VScaleRange);
342 if (!Attr.isValid())
343 return false;
344
345 unsigned MinVScale = Attr.getVScaleRangeMin();
346 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
347 // The transform needs to know the exact runtime length of scalable vectors
348 if (!MaxVScale || MinVScale != MaxVScale)
349 return false;
350
351 auto *PredType =
352 ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
353 auto *FixedPredType =
354 FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
355
356 // If we have a bitcast..
357 auto *BitCast = dyn_cast<BitCastInst>(I);
358 if (!BitCast || BitCast->getType() != PredType)
359 return false;
360
361 // ..whose operand is a vector_insert..
362 auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
363 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
364 return false;
365
366 // ..that is inserting into index zero of an undef vector..
367 if (!isa<UndefValue>(IntrI->getOperand(0)) ||
368 !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
369 return false;
370
371 // ..where the value inserted comes from a load..
372 auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
373 if (!Load || !Load->isSimple())
374 return false;
375
376 // ..that is loading a predicate vector sized worth of bits..
377 if (Load->getType() != FixedPredType)
378 return false;
379
380 IRBuilder<> Builder(I->getContext());
381 Builder.SetInsertPoint(Load);
382
383 auto *LoadPred = Builder.CreateLoad(PredType, Load->getPointerOperand());
384
385 BitCast->replaceAllUsesWith(LoadPred);
386 BitCast->eraseFromParent();
387 if (IntrI->use_empty())
388 IntrI->eraseFromParent();
389 if (Load->use_empty())
390 Load->eraseFromParent();
391
392 return true;
393}
394
395bool SVEIntrinsicOpts::optimizeInstructions(
397 bool Changed = false;
398
399 for (auto *F : Functions) {
400 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
401
402 // Traverse the DT with an rpo walk so we see defs before uses, allowing
403 // simplification to be done incrementally.
404 BasicBlock *Root = DT->getRoot();
406 for (auto *BB : RPOT) {
407 for (Instruction &I : make_early_inc_range(*BB)) {
408 switch (I.getOpcode()) {
409 case Instruction::Store:
410 Changed |= optimizePredicateStore(&I);
411 break;
412 case Instruction::BitCast:
413 Changed |= optimizePredicateLoad(&I);
414 break;
415 }
416 }
417 }
418 }
419
420 return Changed;
421}
422
423bool SVEIntrinsicOpts::optimizeFunctions(
425 bool Changed = false;
426
427 Changed |= optimizePTrueIntrinsicCalls(Functions);
428 Changed |= optimizeInstructions(Functions);
429
430 return Changed;
431}
432
433bool SVEIntrinsicOpts::runOnModule(Module &M) {
434 bool Changed = false;
436
437 // Check for SVE intrinsic declarations first so that we only iterate over
438 // relevant functions. Where an appropriate declaration is found, store the
439 // function(s) where it is used so we can target these only.
440 for (auto &F : M.getFunctionList()) {
441 if (!F.isDeclaration())
442 continue;
443
444 switch (F.getIntrinsicID()) {
445 case Intrinsic::vector_extract:
446 case Intrinsic::vector_insert:
447 case Intrinsic::aarch64_sve_ptrue:
448 for (User *U : F.users())
449 Functions.insert(cast<Instruction>(U)->getFunction());
450 break;
451 default:
452 break;
453 }
454 }
455
456 if (!Functions.empty())
457 Changed |= optimizeFunctions(Functions);
458
459 return Changed;
460}
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:39
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static const char * name
static bool isPTruePromoted(IntrinsicInst *PTrue)
Checks if a ptrue intrinsic call is promoted.
#define DEBUG_TYPE
This file implements a set that has insertion order iteration characteristics.
static Function * getFunction(FunctionType *Ty, const Twine &Name, Module *M)
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:270
LLVM Basic Block Representation.
Definition: BasicBlock.h:62
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:393
LLVM_ABI LLVMContext & getContext() const
Get the context in which this basic block lives.
Definition: BasicBlock.cpp:131
NodeT * getRoot() const
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:165
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:803
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:49
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:255
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:112
static LLVM_ABI ScalableVectorType * get(Type *ElementType, unsigned MinNumElts)
Definition: Type.cpp:825
bool remove(const value_type &X)
Remove an item from the set vector.
Definition: SetVector.h:198
bool remove_if(UnaryPredicate P)
Remove items from the set vector based on a predicate function.
Definition: SetVector.h:247
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:104
bool empty() const
Determine if the SetVector is empty or not.
Definition: SetVector.h:99
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:168
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:356
bool empty() const
Definition: SmallVector.h:82
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
iterator_range< user_iterator > users()
Definition: Value.h:426
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:663
ModulePass * createSVEIntrinsicOptsPass()
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:2049