LLVM 22.0.0git
SPIRVRegularizer.cpp
Go to the documentation of this file.
1//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10// the pass was taken from SPIRV-LLVM translator.
11//
12//===----------------------------------------------------------------------===//
13
#include "SPIRV.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <list>
#include <vector>
22
23#define DEBUG_TYPE "spirv-regularizer"
24
25using namespace llvm;
26
27namespace {
28struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
30
31public:
32 static char ID;
33 SPIRVRegularizer() : FunctionPass(ID) {}
34 bool runOnFunction(Function &F) override;
35 StringRef getPassName() const override { return "SPIR-V Regularizer"; }
36
37 void getAnalysisUsage(AnalysisUsage &AU) const override {
39 }
40 void visitCallInst(CallInst &CI);
41
42private:
43 void visitCallScalToVec(CallInst *CI, StringRef MangledName,
44 StringRef DemangledName);
45 void runLowerConstExpr(Function &F);
46};
47} // namespace
48
49char SPIRVRegularizer::ID = 0;
50
51INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
52 false)
53
54// Since SPIR-V cannot represent constant expression, constant expressions
55// in LLVM IR need to be lowered to instructions. For each function,
56// the constant expressions used by instructions of the function are replaced
57// by instructions placed in the entry block since it dominates all other BBs.
58// Each constant expression only needs to be lowered once in each function
59// and all uses of it by instructions in that function are replaced by
60// one instruction.
61// TODO: remove redundant instructions for common subexpression.
62void SPIRVRegularizer::runLowerConstExpr(Function &F) {
63 LLVMContext &Ctx = F.getContext();
64 std::list<Instruction *> WorkList;
65 for (auto &II : instructions(F))
66 WorkList.push_back(&II);
67
68 auto FBegin = F.begin();
69 while (!WorkList.empty()) {
70 Instruction *II = WorkList.front();
71
72 auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
73 if (isa<Function>(V))
74 return V;
75 auto *CE = cast<ConstantExpr>(V);
76 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
77 auto ReplInst = CE->getAsInstruction();
78 auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
79 ReplInst->insertBefore(InsPoint->getIterator());
80 LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
81 std::vector<Instruction *> Users;
82 // Do not replace use during iteration of use. Do it in another loop.
83 for (auto U : CE->users()) {
84 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
85 auto InstUser = dyn_cast<Instruction>(U);
86 // Only replace users in scope of current function.
87 if (InstUser && InstUser->getParent()->getParent() == &F)
88 Users.push_back(InstUser);
89 }
90 for (auto &User : Users) {
91 if (ReplInst->getParent() == User->getParent() &&
92 User->comesBefore(ReplInst))
93 ReplInst->moveBefore(User->getIterator());
94 User->replaceUsesOfWith(CE, ReplInst);
95 }
96 return ReplInst;
97 };
98
99 WorkList.pop_front();
100 auto LowerConstantVec = [&II, &LowerOp, &WorkList,
101 &Ctx](ConstantVector *Vec,
102 unsigned NumOfOp) -> Value * {
103 if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
104 return isa<ConstantExpr>(V) || isa<Function>(V);
105 })) {
106 // Expand a vector of constexprs and construct it back with
107 // series of insertelement instructions.
108 std::list<Value *> OpList;
109 std::transform(Vec->op_begin(), Vec->op_end(),
110 std::back_inserter(OpList),
111 [LowerOp](Value *V) { return LowerOp(V); });
112 Value *Repl = nullptr;
113 unsigned Idx = 0;
114 auto *PhiII = dyn_cast<PHINode>(II);
115 Instruction *InsPoint =
116 PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
117 std::list<Instruction *> ReplList;
118 for (auto V : OpList) {
119 if (auto *Inst = dyn_cast<Instruction>(V))
120 ReplList.push_back(Inst);
122 (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
123 ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
124 InsPoint->getIterator());
125 }
126 WorkList.splice(WorkList.begin(), ReplList);
127 return Repl;
128 }
129 return nullptr;
130 };
131 for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
132 auto *Op = II->getOperand(OI);
133 if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
134 Value *ReplInst = LowerConstantVec(Vec, OI);
135 if (ReplInst)
136 II->replaceUsesOfWith(Op, ReplInst);
137 } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
138 WorkList.push_front(cast<Instruction>(LowerOp(CE)));
139 } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
140 auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
141 if (!ConstMD)
142 continue;
143 Constant *C = ConstMD->getValue();
144 Value *ReplInst = nullptr;
145 if (auto *Vec = dyn_cast<ConstantVector>(C))
146 ReplInst = LowerConstantVec(Vec, OI);
147 if (auto *CE = dyn_cast<ConstantExpr>(C))
148 ReplInst = LowerOp(CE);
149 if (!ReplInst)
150 continue;
151 Metadata *RepMD = ValueAsMetadata::get(ReplInst);
152 Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
153 II->setOperand(OI, RepMDVal);
154 WorkList.push_front(cast<Instruction>(ReplInst));
155 }
156 }
157 }
158}
159
160// It fixes calls to OCL builtins that accept vector arguments and one of them
161// is actually a scalar splat.
162void SPIRVRegularizer::visitCallInst(CallInst &CI) {
163 auto F = CI.getCalledFunction();
164 if (!F)
165 return;
166
167 auto MangledName = F->getName();
168 char *NameStr = itaniumDemangle(F->getName().data());
169 if (!NameStr)
170 return;
171 StringRef DemangledName(NameStr);
172
173 // TODO: add support for other builtins.
174 if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
175 DemangledName.starts_with("min") || DemangledName.starts_with("max"))
176 visitCallScalToVec(&CI, MangledName, DemangledName);
177 free(NameStr);
178}
179
180void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
181 StringRef DemangledName) {
182 // Check if all arguments have the same type - it's simple case.
183 auto Uniform = true;
184 Type *Arg0Ty = CI->getOperand(0)->getType();
185 auto IsArg0Vector = isa<VectorType>(Arg0Ty);
186 for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
187 Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
188 if (Uniform)
189 return;
190
191 auto *OldF = CI->getCalledFunction();
192 Function *NewF = nullptr;
193 auto [It, Inserted] = Old2NewFuncs.try_emplace(OldF);
194 if (Inserted) {
196 SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
197 auto *NewFTy =
198 FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
199 NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
200 *OldF->getParent());
202 auto NewFArgIt = NewF->arg_begin();
203 for (auto &Arg : OldF->args()) {
204 auto ArgName = Arg.getName();
205 NewFArgIt->setName(ArgName);
206 VMap[&Arg] = &(*NewFArgIt++);
207 }
209 CloneFunctionInto(NewF, OldF, VMap,
210 CloneFunctionChangeType::LocalChangesOnly, Returns);
211 NewF->setAttributes(Attrs);
212 It->second = NewF;
213 } else {
214 NewF = It->second;
215 }
216 assert(NewF);
217
218 // This produces an instruction sequence that implements a splat of
219 // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
220 // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
221 // For instance (transcoding/OpMin.ll), this call
222 // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
223 // is translated to
224 // %8 = OpUndef %v2uint
225 // %14 = OpConstantComposite %v2uint %uint_1 %uint_10
226 // ...
227 // %10 = OpCompositeInsert %v2uint %uint_5 %8 0
228 // %11 = OpVectorShuffle %v2uint %10 %8 0 0
229 // %call = OpExtInst %v2uint %1 s_min %14 %11
230 auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
231 PoisonValue *PVal = PoisonValue::get(Arg0Ty);
233 PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
234 ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
235 Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
236 Value *NewVec =
237 new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
238 CI->setOperand(1, NewVec);
239 CI->replaceUsesOfWith(OldF, NewF);
241}
242
243bool SPIRVRegularizer::runOnFunction(Function &F) {
244 runLowerConstExpr(F);
245 visit(F);
246 for (auto &OldNew : Old2NewFuncs) {
247 Function *OldF = OldNew.first;
248 Function *NewF = OldNew.second;
249 NewF->takeName(OldF);
250 OldF->eraseFromParent();
251 }
252 return true;
253}
254
256 return new SPIRVRegularizer();
257}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Expand Atomic instructions
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
This header defines various interfaces for pass management in LLVM.
iv Induction Variable Users
Definition: IVUsers.cpp:48
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:56
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
#define DEBUG_TYPE
#define LLVM_DEBUG(...)
Definition: Debug.h:119
Represent the analysis usage information of a pass.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signature does not match the call signature.
Definition: InstrTypes.h:1348
void mutateFunctionType(FunctionType *FTy)
Definition: InstrTypes.h:1207
unsigned arg_size() const
Definition: InstrTypes.h:1290
This class represents a function call, abstracting a target machine's calling convention.
Constant Vector Declarations.
Definition: Constants.h:517
static LLVM_ABI Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
Definition: Constants.cpp:1474
This is an important base class in LLVM.
Definition: Constant.h:43
This class represents an Operation in the Expression.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:166
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:209
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:352
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition: Function.cpp:448
arg_iterator arg_begin()
Definition: Function.h:866
void setAttributes(AttributeList Attrs)
Set the attribute list for this Function.
Definition: Function.h:355
static InsertElementInst * Create(Value *Vec, Value *NewElt, Value *Idx, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Base class for instruction visitors.
Definition: InstVisitor.h:78
RetTy visitCallInst(CallInst &I)
Definition: InstVisitor.h:215
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:319
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
static LLVM_ABI MetadataAsValue * get(LLVMContext &Context, Metadata *MD)
Definition: Metadata.cpp:103
Root of the metadata hierarchy.
Definition: Metadata.h:63
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
Definition: Pass.cpp:112
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:85
In order to facilitate speculative execution, many instructions do not invoke immediate undefined beh...
Definition: Constants.h:1468
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1885
This instruction constructs a fixed permutation of two input vectors.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:55
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
void setOperand(unsigned i, Value *Val)
Definition: User.h:237
Value * getOperand(unsigned i) const
Definition: User.h:232
static LLVM_ABI ValueAsMetadata * get(Value *V)
Definition: Metadata.cpp:502
LLVM Value Representation.
Definition: Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1101
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:322
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
Definition: Value.cpp:396
self_iterator getIterator()
Definition: ilist_node.h:134
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
DEMANGLE_ABI char * itaniumDemangle(std::string_view mangled_name, bool ParseParams=true)
Returns a non-NULL pointer to a NUL-terminated C style string that should be explicitly freed,...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:207
FunctionPass * createSPIRVRegularizerPass()
LLVM_ABI void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, CloneFunctionChangeType Changes, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.