30#define DEBUG_TYPE "dxil-flatten-arrays"
35class DXILFlattenArraysLegacy :
public ModulePass {
46 Value *RootPointerOperand;
51class DXILFlattenArraysVisitor
52 :
public InstVisitor<DXILFlattenArraysVisitor, bool> {
54 DXILFlattenArraysVisitor(
56 : GlobalMap(GlobalMap) {}
78 static bool isMultiDimensionalArray(
Type *
T);
79 static std::pair<unsigned, Type *> getElementCountAndType(
Type *ArrayTy);
95bool DXILFlattenArraysVisitor::finish() {
96 GEPChainInfoMap.clear();
101bool DXILFlattenArraysVisitor::isMultiDimensionalArray(
Type *
T) {
102 if (
ArrayType *ArrType = dyn_cast<ArrayType>(
T))
103 return isa<ArrayType>(ArrType->getElementType());
107std::pair<unsigned, Type *>
108DXILFlattenArraysVisitor::getElementCountAndType(
Type *ArrayTy) {
109 unsigned TotalElements = 1;
110 Type *CurrArrayTy = ArrayTy;
111 while (
auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
112 TotalElements *= InnerArrayTy->getNumElements();
113 CurrArrayTy = InnerArrayTy->getElementType();
115 return std::make_pair(TotalElements, CurrArrayTy);
118ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
121 "Indicies and dimmensions should be the same");
122 unsigned FlatIndex = 0;
123 unsigned Multiplier = 1;
125 for (
int I = Indices.
size() - 1;
I >= 0; --
I) {
126 unsigned DimSize = Dims[
I];
127 ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[
I]);
128 assert(CIndex &&
"This function expects all indicies to be ConstantInt");
130 Multiplier *= DimSize;
135Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
137 if (Indices.
size() == 1)
141 unsigned Multiplier = 1;
143 for (
int I = Indices.
size() - 1;
I >= 0; --
I) {
144 unsigned DimSize = Dims[
I];
147 FlatIndex = Builder.
CreateAdd(FlatIndex, ScaledIndex);
148 Multiplier *= DimSize;
153bool DXILFlattenArraysVisitor::visitLoadInst(
LoadInst &LI) {
155 for (
unsigned I = 0;
I < NumOperands; ++
I) {
158 if (CE &&
CE->getOpcode() == Instruction::GetElementPtr) {
160 cast<GetElementPtrInst>(
CE->getAsInstruction());
169 visitGetElementPtrInst(*OldGEP);
176bool DXILFlattenArraysVisitor::visitStoreInst(
StoreInst &SI) {
177 unsigned NumOperands =
SI.getNumOperands();
178 for (
unsigned I = 0;
I < NumOperands; ++
I) {
179 Value *CurrOpperand =
SI.getOperand(
I);
181 if (CE &&
CE->getOpcode() == Instruction::GetElementPtr) {
183 cast<GetElementPtrInst>(
CE->getAsInstruction());
189 SI.replaceAllUsesWith(NewStore);
190 SI.eraseFromParent();
191 visitGetElementPtrInst(*OldGEP);
198bool DXILFlattenArraysVisitor::visitAllocaInst(
AllocaInst &AI) {
204 auto [TotalElements,
BaseType] = getElementCountAndType(ArrType);
217 if (GEPChainInfoMap.contains(cast<GEPOperator>(&
GEP)))
220 Value *PtrOperand =
GEP.getPointerOperand();
224 assert(!isa<PHINode>(PtrOperand) &&
225 "Pointer operand of GEP should not be a PHI Node");
229 if (
auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
230 PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
232 cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
238 Builder.
CreateGEP(
GEP.getSourceElementType(), OldGEPI, Indices,
239 GEP.getName(),
GEP.getNoWrapFlags());
240 assert(isa<GetElementPtrInst>(NewGEP) &&
241 "Expected newly-created GEP to be an instruction");
244 GEP.replaceAllUsesWith(NewGEPI);
245 GEP.eraseFromParent();
246 visitGetElementPtrInst(*OldGEPI);
247 visitGetElementPtrInst(*NewGEPI);
256 unsigned BitWidth =
DL.getIndexTypeSizeInBits(
GEP.getType());
258 [[maybe_unused]]
bool Success =
GEP.collectOffset(
265 if (
auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
270 if (!GEPChainInfoMap.contains(PtrOpGEP))
273 GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
274 Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
275 Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
276 for (
auto &VariableOffset : PGEPInfo.VariableOffsets)
277 Info.VariableOffsets.insert(VariableOffset);
278 Info.ConstantOffset += PGEPInfo.ConstantOffset;
280 Info.RootPointerOperand = PtrOperand;
285 Type *RootTy =
GEP.getSourceElementType();
286 if (
auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
287 if (GlobalMap.contains(GlobalVar))
291 }
else if (
auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
292 RootTy = Alloca->getAllocatedType();
293 assert(!isMultiDimensionalArray(RootTy) &&
294 "Expected root array type to be flattened");
297 if (!isa<ArrayType>(RootTy))
300 Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
306 bool ReplaceThisGEP =
GEP.users().empty();
308 if (!isa<GetElementPtrInst>(
User))
309 ReplaceThisGEP =
true;
311 if (ReplaceThisGEP) {
312 unsigned BytesPerElem =
313 DL.getTypeAllocSize(
Info.RootFlattenedArrayType->getArrayElementType());
315 "Bytes per element should be a power of 2");
322 Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
323 assert(ConstantOffset < UINT32_MAX &&
324 "Constant byte offset for flat GEP index must fit within 32 bits");
326 for (
auto [VarIndex, Multiplier] :
Info.VariableOffsets) {
327 assert(Multiplier.getActiveBits() <= 32 &&
328 "The multiplier for a flat GEP index must fit within 32 bits");
329 assert(VarIndex->getType()->isIntegerTy(32) &&
330 "Expected i32-typed GEP indices");
332 if (Multiplier.getZExtValue() % BytesPerElem != 0) {
337 Builder.
getInt32(Multiplier.getZExtValue()));
342 Builder.
getInt32(Multiplier.getZExtValue() / BytesPerElem));
343 FlattenedIndex = Builder.
CreateAdd(FlattenedIndex, VI);
348 Info.RootFlattenedArrayType,
Info.RootPointerOperand,
349 {ZeroIndex, FlattenedIndex},
GEP.getName(),
GEP.getNoWrapFlags());
355 if (!isa<GEPOperator>(NewGEP))
357 Info.RootFlattenedArrayType,
Info.RootPointerOperand,
358 {ZeroIndex, FlattenedIndex},
GEP.getNoWrapFlags(),
GEP.getName(),
363 GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
364 GEP.replaceAllUsesWith(NewGEP);
365 GEP.eraseFromParent();
372 GEPChainInfoMap.insert({cast<GEPOperator>(&
GEP), std::move(Info)});
373 PotentiallyDeadInstrs.emplace_back(&
GEP);
377bool DXILFlattenArraysVisitor::visit(
Function &
F) {
378 bool MadeChange =
false;
391 auto *ArrayTy = dyn_cast<ArrayType>(
Init->getType());
393 Elements.push_back(
Init);
396 unsigned ArrSize = ArrayTy->getNumElements();
397 if (isa<ConstantAggregateZero>(
Init)) {
398 for (
unsigned I = 0;
I < ArrSize; ++
I)
404 if (
auto *ArrayConstant = dyn_cast<ConstantArray>(
Init)) {
405 for (
unsigned I = 0;
I < ArrayConstant->getNumOperands(); ++
I) {
408 }
else if (
auto *DataArrayConstant = dyn_cast<ConstantDataArray>(
Init)) {
409 for (
unsigned I = 0;
I < DataArrayConstant->getNumElements(); ++
I) {
414 "Expected a ConstantArray or ConstantDataArray for array initializer!");
422 if (isa<ConstantAggregateZero>(
Init))
426 if (isa<UndefValue>(
Init))
429 if (!isa<ArrayType>(OrigType))
434 assert(FlattenedType->getNumElements() == FlattenedElements.
size() &&
435 "The number of collected elements should match the FlattenedType");
443 Type *OrigType =
G.getValueType();
444 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
447 ArrayType *ArrType = cast<ArrayType>(OrigType);
449 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
456 nullptr,
G.getName() +
".1dim", &
G,
457 G.getThreadLocalMode(),
G.getAddressSpace(),
458 G.isExternallyInitialized());
462 if (
G.getAlignment() > 0) {
466 if (
G.hasInitializer()) {
472 GlobalMap[&
G] = NewGlobal;
477 bool MadeChange =
false;
480 DXILFlattenArraysVisitor Impl(GlobalMap);
482 if (
F.isDeclaration())
484 MadeChange |= Impl.visit(
F);
486 for (
auto &[Old, New] : GlobalMap) {
487 Old->replaceAllUsesWith(New);
488 Old->eraseFromParent();
502bool DXILFlattenArraysLegacy::runOnModule(
Module &M) {
506char DXILFlattenArraysLegacy::ID = 0;
509 "DXIL Array Flattener",
false,
false)
514 return new DXILFlattenArraysLegacy();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Analysis containing CSE Info
static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)
static bool flattenArrays(Module &M)
static Constant * transformInitializer(Constant *Init, Type *OrigType, ArrayType *FlattenedType, LLVMContext &Ctx)
static void flattenGlobalArrays(Module &M, SmallDenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
Class for arbitrary precision integers.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Type * getAllocatedType() const
Return the type that is being allocated by the instruction.
void setAlignment(Align Align)
A container for analyses that lazily runs them and caches their results.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
LLVM Basic Block Representation.
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
static LLVM_ABI ConstantAggregateZero * get(Type *Ty)
static LLVM_ABI Constant * get(ArrayType *T, ArrayRef< Constant * > V)
A constant value that is initialized with an expression using other constant values.
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
A parsed version of the target data layout string, and methods for querying it.
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns random concrete value if an operand is either a ...
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
void setUnnamedAddr(UnnamedAddr Val)
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalVariable.
This instruction compares its operands according to the predicate given to the constructor.
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
BasicBlock::iterator GetInsertPoint() const
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
RetTy visitFreezeInst(FreezeInst &I)
RetTy visitFCmpInst(FCmpInst &I)
RetTy visitExtractElementInst(ExtractElementInst &I)
RetTy visitShuffleVectorInst(ShuffleVectorInst &I)
RetTy visitBitCastInst(BitCastInst &I)
void visit(Iterator Start, Iterator End)
RetTy visitPHINode(PHINode &I)
RetTy visitUnaryOperator(UnaryOperator &I)
RetTy visitStoreInst(StoreInst &I)
RetTy visitInsertElementInst(InsertElementInst &I)
RetTy visitAllocaInst(AllocaInst &I)
RetTy visitBinaryOperator(BinaryOperator &I)
RetTy visitICmpInst(ICmpInst &I)
RetTy visitCallInst(CallInst &I)
RetTy visitCastInst(CastInst &I)
RetTy visitSelectInst(SelectInst &I)
RetTy visitGetElementPtrInst(GetElementPtrInst &I)
void visitInstruction(Instruction &I)
RetTy visitLoadInst(LoadInst &I)
LLVM_ABI void insertBefore(InstListType::iterator InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified position.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
void setAlignment(Align Align)
Align getAlign() const
Return the alignment of the access that is being performed.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
void setAlignment(Align Align)
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
@ CE
Windows NT (Windows on ARM)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createDXILFlattenArraysLegacyPass()
Pass to flatten arrays into a one dimensional DXIL legal form.
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not...
constexpr unsigned BitWidth
A MapVector that performs no allocations if smaller than a certain size.