LLVM 22.0.0git
DXILCBufferAccess.cpp
Go to the documentation of this file.
1//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "DXILCBufferAccess.h"
10#include "DirectX.h"
13#include "llvm/IR/IRBuilder.h"
15#include "llvm/IR/IntrinsicsDirectX.h"
17#include "llvm/Pass.h"
20
21#define DEBUG_TYPE "dxil-cbuffer-access"
22using namespace llvm;
23
24namespace {
25/// Helper for building a `load.cbufferrow` intrinsic given a simple type.
26struct CBufferRowIntrin {
27 Intrinsic::ID IID;
28 Type *RetTy;
29 unsigned int EltSize;
30 unsigned int NumElts;
31
32 CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
33 assert(Ty == Ty->getScalarType() && "Expected scalar type");
34
35 switch (DL.getTypeSizeInBits(Ty)) {
36 case 16:
37 IID = Intrinsic::dx_resource_load_cbufferrow_8;
38 RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
39 EltSize = 2;
40 NumElts = 8;
41 break;
42 case 32:
43 IID = Intrinsic::dx_resource_load_cbufferrow_4;
44 RetTy = StructType::get(Ty, Ty, Ty, Ty);
45 EltSize = 4;
46 NumElts = 4;
47 break;
48 case 64:
49 IID = Intrinsic::dx_resource_load_cbufferrow_2;
50 RetTy = StructType::get(Ty, Ty);
51 EltSize = 8;
52 NumElts = 2;
53 break;
54 default:
55 llvm_unreachable("Only 16, 32, and 64 bit types supported");
56 }
57 }
58};
59
60// Helper for creating CBuffer handles and loading data from them
61struct CBufferResource {
62 GlobalVariable *GVHandle;
63 GlobalVariable *Member;
64 size_t MemberOffset;
65
66 LoadInst *Handle;
67
68 CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member,
69 size_t MemberOffset)
70 : GVHandle(GVHandle), Member(Member), MemberOffset(MemberOffset) {}
71
72 const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); }
73 Type *getValueType() { return Member->getValueType(); }
75 return Member->users();
76 }
77
78 /// Get the byte offset of a Pointer-typed Value * `Val` relative to Member.
79 /// `Val` can either be Member itself, or a GEP of a constant offset from
80 /// Member
81 size_t getOffsetForCBufferGEP(Value *Val) {
82 assert(isa<PointerType>(Val->getType()) &&
83 "Expected a pointer-typed value");
84
85 if (Val == Member)
86 return 0;
87
88 if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
89 // Since we should always have a constant offset, we should only ever have
90 // a single GEP of indirection from the Global.
91 assert(GEP->getPointerOperand() == Member &&
92 "Indirect access to resource handle");
93
94 const DataLayout &DL = getDataLayout();
95 APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
96 bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
97 (void)Success;
98 assert(Success && "Offsets into cbuffer globals must be constant");
99
100 if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))
101 ConstantOffset =
102 hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
103
104 return ConstantOffset.getZExtValue();
105 }
106
107 llvm_unreachable("Expected Val to be a GlobalVariable or GEP");
108 }
109
110 /// Create a handle for this cbuffer resource using the IRBuilder `Builder`
111 /// and sets the handle as the current one to use for subsequent calls to
112 /// `loadValue`
113 void createAndSetCurrentHandle(IRBuilder<> &Builder) {
114 Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle,
115 GVHandle->getName());
116 }
117
118 /// Load a value of type `Ty` at offset `Offset` using the handle from the
119 /// last call to `createAndSetCurrentHandle`
120 Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset,
121 const Twine &Name = "") {
122 assert(Handle &&
123 "Expected a handle for this cbuffer global resource to be created "
124 "before loading a value from it");
125 const DataLayout &DL = getDataLayout();
126
127 size_t TargetOffset = MemberOffset + Offset;
128 CBufferRowIntrin Intrin(DL, Ty->getScalarType());
129 // The cbuffer consists of some number of 16-byte rows.
130 unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes;
131 unsigned int CurrentIndex =
132 (TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
133
134 auto *CBufLoad = Builder.CreateIntrinsic(
135 Intrin.RetTy, Intrin.IID,
136 {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
137 Name + ".load");
138 auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
139 Name + ".extract");
140
141 Value *Result = nullptr;
142 unsigned int Remaining =
143 ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
144
145 if (Remaining == 0) {
146 // We only have a single element, so we're done.
147 Result = Elt;
148
149 // However, if we loaded a <1 x T>, then we need to adjust the type here.
150 if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
151 assert(VT->getNumElements() == 1 &&
152 "Can't have multiple elements here");
153 Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
154 Builder.getInt32(0), Name);
155 }
156 return Result;
157 }
158
159 // Walk each element and extract it, wrapping to new rows as needed.
160 SmallVector<Value *> Extracts{Elt};
161 while (Remaining--) {
162 CurrentIndex %= Intrin.NumElts;
163
164 if (CurrentIndex == 0)
165 CBufLoad = Builder.CreateIntrinsic(
166 Intrin.RetTy, Intrin.IID,
167 {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
168 nullptr, Name + ".load");
169
170 Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
171 Name + ".extract"));
172 }
173
174 // Finally, we build up the original loaded value.
175 Result = PoisonValue::get(Ty);
176 for (int I = 0, E = Extracts.size(); I < E; ++I)
177 Result =
178 Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I),
179 Name + formatv(".upto{}", I));
180 return Result;
181 }
182};
183
184} // namespace
185
186/// Replace load via cbuffer global with a load from the cbuffer handle itself.
187static void replaceLoad(LoadInst *LI, CBufferResource &CBR,
189 size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand());
190 IRBuilder<> Builder(LI);
191 CBR.createAndSetCurrentHandle(Builder);
192 Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName());
193 LI->replaceAllUsesWith(Result);
194 DeadInsts.push_back(LI);
195}
196
197/// This function recursively copies N array elements from the cbuffer resource
198/// CBR to the MemCpy Destination. Recursion is used to unravel multidimensional
199/// arrays into a sequence of scalar/vector extracts and stores.
201 CBufferResource &CBR, ArrayType *ArrTy,
202 size_t ArrOffset, size_t N,
203 const Twine &Name = "") {
204 const DataLayout &DL = MCI->getDataLayout();
205 Type *ElemTy = ArrTy->getElementType();
206 size_t ElemTySize = DL.getTypeAllocSize(ElemTy);
207 for (unsigned I = 0; I < N; ++I) {
208 size_t Offset = ArrOffset + I * ElemTySize;
209
210 // Recursively copy nested arrays
211 if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) {
212 copyArrayElemsForMemCpy(Builder, MCI, CBR, ElemArrTy, Offset,
213 ElemArrTy->getNumElements(), Name);
214 continue;
215 }
216
217 // Load CBuffer value and store it in Dest
218 APInt CBufArrayOffset(
219 DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset);
220 CBufArrayOffset =
221 hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy);
222 Value *CBufferVal =
223 CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name);
224 Value *GEP =
225 Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(),
226 {Builder.getInt32(Offset)}, Name + ".dest");
227 Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile());
228 }
229}
230
231/// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle
232/// itself. Assumes the cbuffer global is an array, and the length of bytes to
233/// copy is divisible by array element allocation size.
234/// The memcpy source must also be a direct cbuffer global reference, not a GEP.
235static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) {
236
237 ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());
238 assert(ArrTy && "MemCpy lowering is only supported for array types");
239
240 // This assumption vastly simplifies the implementation
241 if (MCI->getSource() != CBR.Member)
243 "Expected MemCpy source to be a cbuffer global variable");
244
245 ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());
246 uint64_t ByteLength = Length->getZExtValue();
247
248 // If length to copy is zero, no memcpy is needed
249 if (ByteLength == 0) {
250 MCI->eraseFromParent();
251 return;
252 }
253
254 const DataLayout &DL = CBR.getDataLayout();
255
256 Type *ElemTy = ArrTy->getElementType();
257 size_t ElemSize = DL.getTypeAllocSize(ElemTy);
258 assert(ByteLength % ElemSize == 0 &&
259 "Length of bytes to MemCpy must be divisible by allocation size of "
260 "source/destination array elements");
261 size_t ElemsToCpy = ByteLength / ElemSize;
262
263 IRBuilder<> Builder(MCI);
264 CBR.createAndSetCurrentHandle(Builder);
265
266 copyArrayElemsForMemCpy(Builder, MCI, CBR, ArrTy, 0, ElemsToCpy,
267 "memcpy." + MCI->getDest()->getName() + "." +
268 MCI->getSource()->getName());
269
270 MCI->eraseFromParent();
271}
272
273static void replaceAccessesWithHandle(CBufferResource &CBR) {
275
276 SmallVector<User *> ToProcess{CBR.users()};
277 while (!ToProcess.empty()) {
278 User *Cur = ToProcess.pop_back_val();
279
280 // If we have a load instruction, replace the access.
281 if (auto *LI = dyn_cast<LoadInst>(Cur)) {
282 replaceLoad(LI, CBR, DeadInsts);
283 continue;
284 }
285
286 // If we have a memcpy instruction, replace it with multiple accesses and
287 // subsequent stores to the destination
288 if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) {
289 replaceMemCpy(MCI, CBR);
290 continue;
291 }
292
293 // Otherwise, walk users looking for a load...
294 if (isa<GetElementPtrInst>(Cur) || isa<GEPOperator>(Cur)) {
295 ToProcess.append(Cur->user_begin(), Cur->user_end());
296 continue;
297 }
298
299 llvm_unreachable("Unexpected user of Global");
300 }
302}
303
305 std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
306 if (!CBufMD)
307 return false;
308
309 for (const hlsl::CBufferMapping &Mapping : *CBufMD)
310 for (const hlsl::CBufferMember &Member : Mapping.Members) {
311 CBufferResource CBR(Mapping.Handle, Member.GV, Member.Offset);
313 Member.GV->removeFromParent();
314 }
315
316 CBufMD->eraseFromModule();
317 return true;
318}
319
322 bool Changed = replaceCBufferAccesses(M);
323
324 if (!Changed)
325 return PreservedAnalyses::all();
326 return PA;
327}
328
329namespace {
330class DXILCBufferAccessLegacy : public ModulePass {
331public:
332 bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
333 StringRef getPassName() const override { return "DXIL CBuffer Access"; }
334 DXILCBufferAccessLegacy() : ModulePass(ID) {}
335
336 static char ID; // Pass identification.
337};
338char DXILCBufferAccessLegacy::ID = 0;
339} // end anonymous namespace
340
341INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
342 false, false)
343
345 return new DXILCBufferAccessLegacy();
346}
#define Success
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool replaceCBufferAccesses(Module &M)
static void copyArrayElemsForMemCpy(IRBuilder<> &Builder, MemCpyInst *MCI, CBufferResource &CBR, ArrayType *ArrTy, size_t ArrOffset, size_t N, const Twine &Name="")
This function recursively copies N array elements from the cbuffer resource CBR to the MemCpy Destina...
static void replaceLoad(LoadInst *LI, CBufferResource &CBR, SmallVectorImpl< WeakTrackingVH > &DeadInsts)
Replace load via cbuffer global with a load from the cbuffer handle itself.
static void replaceAccessesWithHandle(CBufferResource &CBR)
static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR)
Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle itself.
return RetTy
std::string Name
#define DEBUG_TYPE
Hexagon Common GEP
iv users
Definition: IVUsers.cpp:48
#define I(x, y, z)
Definition: MD5.cpp:58
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:56
static Type * getValueType(Value *V)
Returns the type of the given value/instruction V.
Class for arbitrary precision integers.
Definition: APInt.h:78
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1540
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:255
This is the shared class of boolean and integer constants.
Definition: Constants.h:87
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
LLVM_ABI const DataLayout & getDataLayout() const
Get the data layout of the module this global belongs to.
Definition: Globals.cpp:132
Type * getValueType() const
Definition: GlobalValue.h:298
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2571
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2618
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Definition: IRBuilder.h:562
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition: IRBuilder.h:1931
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:834
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition: IRBuilder.h:522
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1847
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1860
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Definition: IRBuilder.h:552
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const DataLayout & getDataLayout() const
Get the data layout of the module this instruction belongs to.
Definition: Instruction.cpp:86
An instruction for reading from memory.
Definition: Instructions.h:180
Value * getPointerOperand()
Definition: Instructions.h:259
This class wraps the llvm.memcpy intrinsic.
Value * getLength() const
Value * getDest() const
This is just like getRawDest, but it strips off any cast instructions (including addrspacecast) that ...
bool isVolatile() const
Value * getSource() const
This is just like getRawSource, but it strips off any cast instructions that feed it,...
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:255
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1885
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:118
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:574
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:55
static LLVM_ABI StructType * get(LLVMContext &Context, ArrayRef< Type * > Elements, bool isPacked=false)
This static method is the primary way to create a literal StructType.
Definition: Type.cpp:414
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:82
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:352
LLVM Value Representation.
Definition: Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
user_iterator user_begin()
Definition: Value.h:402
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:546
user_iterator user_end()
Definition: Value.h:410
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:322
static std::optional< CBufferMetadata > get(Module &M)
Definition: CBuffer.cpp:34
A range adaptor for a pair of iterators.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
const unsigned CBufferRowSizeInBytes
Definition: HLSLResource.h:24
APInt translateCBufArrayOffset(const DataLayout &DL, APInt Offset, ArrayType *Ty)
Definition: CBuffer.cpp:66
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:477
@ Length
Definition: DWP.cpp:477
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:533
ModulePass * createDXILCBufferAccessLegacyPass()
Pass to translate loads in the cbuffer address space to intrinsics.
auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
LLVM_ABI void reportFatalUsageError(Error Err)
Report a fatal error that does not indicate a bug in LLVM.
Definition: Error.cpp:180
#define N