LLVM 22.0.0git
SPIRVMergeRegionExitTargets.cpp
Go to the documentation of this file.
1//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Merge the multiple exit targets of a convergence region into a single block.
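// Each exit target is assigned a constant value. Every exiting block stores the
// value of the target it branches to into a function-local variable, and the
// new exit target loads that variable and switches on it to re-route control
// to the correct basic block.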
12//
13//===----------------------------------------------------------------------===//
14
16#include "SPIRV.h"
17#include "SPIRVSubtarget.h"
18#include "SPIRVUtils.h"
19#include "llvm/ADT/DenseMap.h"
23#include "llvm/IR/Dominators.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Intrinsics.h"
30
31using namespace llvm;
32
33namespace {
34
35class SPIRVMergeRegionExitTargets : public FunctionPass {
36public:
// Pass identifier handed to the FunctionPass base class below; it is defined
// out of line (initialized to 0) next to the INITIALIZE_PASS macros.
static char ID;

// Constructs the pass, registering it under |ID| with the FunctionPass base.
SPIRVMergeRegionExitTargets() : FunctionPass(ID) {}
40
41 // Gather all the successors of |BB|.
42 // This function asserts if the terminator neither a branch, switch or return.
43 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
44 std::unordered_set<BasicBlock *> output;
45 auto *T = BB->getTerminator();
46
47 if (auto *BI = dyn_cast<BranchInst>(T)) {
48 output.insert(BI->getSuccessor(0));
49 if (BI->isConditional())
50 output.insert(BI->getSuccessor(1));
51 return output;
52 }
53
54 if (auto *SI = dyn_cast<SwitchInst>(T)) {
55 output.insert(SI->getDefaultDest());
56 for (auto &Case : SI->cases())
57 output.insert(Case.getCaseSuccessor());
58 return output;
59 }
60
61 assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
62 return output;
63 }
64
65 /// Create a value in BB set to the value associated with the branch the block
66 /// terminator will take.
67 llvm::Value *createExitVariable(
68 BasicBlock *BB,
69 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
70 auto *T = BB->getTerminator();
71 if (isa<ReturnInst>(T))
72 return nullptr;
73
74 IRBuilder<> Builder(BB);
75 Builder.SetInsertPoint(T);
76
77 if (auto *BI = dyn_cast<BranchInst>(T)) {
78
79 BasicBlock *LHSTarget = BI->getSuccessor(0);
80 BasicBlock *RHSTarget =
81 BI->isConditional() ? BI->getSuccessor(1) : nullptr;
82
83 Value *LHS = TargetToValue.lookup(LHSTarget);
84 Value *RHS = TargetToValue.lookup(RHSTarget);
85
86 if (LHS == nullptr || RHS == nullptr)
87 return LHS == nullptr ? RHS : LHS;
88 return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
89 }
90
91 // TODO: add support for switch cases.
92 llvm_unreachable("Unhandled terminator type.");
93 }
94
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
/// Return blocks are left untouched. For branch and switch terminators, every
/// successor slot pointing at a block in |ToReplace| is rewritten to
/// |NewTarget|. Any other terminator kind trips the assertion at the end.
/// NOTE(review): that assertion compiles away under NDEBUG, so unsupported
/// terminators are silently left unchanged in release builds; consider
/// llvm_unreachable for consistency with createExitVariable.
                         const SmallPtrSet<BasicBlock *, 4> &ToReplace,
                         BasicBlock *NewTarget) {
  auto *T = BB->getTerminator();
  if (isa<ReturnInst>(T))
    return;

  if (auto *BI = dyn_cast<BranchInst>(T)) {
    // One or two successor slots, depending on whether the branch is
    // conditional.
    for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
      if (ToReplace.count(BI->getSuccessor(i)) != 0)
        BI->setSuccessor(i, NewTarget);
    }
    return;
  }

  if (auto *SI = dyn_cast<SwitchInst>(T)) {
    // Iterating all successor indices covers the default destination as well
    // as every case destination.
    for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
      if (ToReplace.count(SI->getSuccessor(i)) != 0)
        SI->setSuccessor(i, NewTarget);
    }
    return;
  }

  assert(false && "Unhandled terminator type.");
}
121
122 AllocaInst *CreateVariable(Function &F, Type *Type,
123 BasicBlock::iterator Position) {
124 const DataLayout &DL = F.getDataLayout();
125 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
126 Position);
127 }
128
129 // Run the pass on the given convergence region, ignoring the sub-regions.
130 // Returns true if the CFG changed, false otherwise.
131 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
133 // Gather all the exit targets for this region.
135 for (BasicBlock *Exit : CR->Exits) {
136 for (BasicBlock *Target : gatherSuccessors(Exit)) {
137 if (CR->Blocks.count(Target) == 0)
138 ExitTargets.insert(Target);
139 }
140 }
141
142 // If we have zero or one exit target, nothing do to.
143 if (ExitTargets.size() <= 1)
144 return false;
145
146 // Create the new single exit target.
147 auto F = CR->Entry->getParent();
148 auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
149 IRBuilder<> Builder(NewExitTarget);
150
151 AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
152 F->begin()->getFirstInsertionPt());
153
154 // CodeGen output needs to be stable. Using the set as-is would order
155 // the targets differently depending on the allocation pattern.
156 // Sorting per basic-block ordering in the function.
157 std::vector<BasicBlock *> SortedExitTargets;
158 std::vector<BasicBlock *> SortedExits;
159 for (BasicBlock &BB : *F) {
160 if (ExitTargets.count(&BB) != 0)
161 SortedExitTargets.push_back(&BB);
162 if (CR->Exits.count(&BB) != 0)
163 SortedExits.push_back(&BB);
164 }
165
166 // Creating one constant per distinct exit target. This will be route to the
167 // correct target.
169 for (BasicBlock *Target : SortedExitTargets)
170 TargetToValue.insert(
171 std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
172
173 // Creating one variable per exit node, set to the constant matching the
174 // targeted external block.
175 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
176 for (auto Exit : SortedExits) {
177 llvm::Value *Value = createExitVariable(Exit, TargetToValue);
178 IRBuilder<> B2(Exit);
179 B2.SetInsertPoint(Exit->getFirstInsertionPt());
180 B2.CreateStore(Value, Variable);
181 ExitToVariable.emplace_back(std::make_pair(Exit, Value));
182 }
183
184 llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
185
186 // Creating the switch to jump to the correct exit target.
187 llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
188 SortedExitTargets.size() - 1);
189 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
190 BasicBlock *BB = SortedExitTargets[i];
191 Sw->addCase(TargetToValue[BB], BB);
192 }
193
194 // Fix exit branches to redirect to the new exit.
195 for (auto Exit : CR->Exits)
196 replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
197
198 CR = CR->Parent;
199 while (CR) {
200 CR->Blocks.insert(NewExitTarget);
201 CR = CR->Parent;
202 }
203
204 return true;
205 }
206
207 /// Run the pass on the given convergence region and sub-regions (DFS).
208 /// Returns true if a region/sub-region was modified, false otherwise.
209 /// This returns as soon as one region/sub-region has been modified.
210 bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
211 for (auto *Child : CR->Children)
212 if (runOnConvergenceRegion(LI, Child))
213 return true;
214
215 return runOnConvergenceRegionNoRecurse(LI, CR);
216 }
217
218#if !NDEBUG
219 /// Validates each edge exiting the region has the same destination basic
220 /// block.
221 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
222 for (auto *Child : CR->Children)
223 validateRegionExits(Child);
224
225 std::unordered_set<BasicBlock *> ExitTargets;
226 for (auto *Exit : CR->Exits) {
227 auto Set = gatherSuccessors(Exit);
228 for (auto *BB : Set) {
229 if (CR->Blocks.count(BB) == 0)
230 ExitTargets.insert(BB);
231 }
232 }
233
234 assert(ExitTargets.size() <= 1);
235 }
236#endif
237
238 virtual bool runOnFunction(Function &F) override {
239 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
240 auto *TopLevelRegion =
241 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
242 .getRegionInfo()
243 .getWritableTopLevelRegion();
244
245 // FIXME: very inefficient method: each time a region is modified, we bubble
246 // back up, and recompute the whole convergence region tree. Once the
247 // algorithm is completed and test coverage good enough, rewrite this pass
248 // to be efficient instead of simple.
249 bool modified = false;
250 while (runOnConvergenceRegion(LI, TopLevelRegion)) {
251 modified = true;
252 }
253
254#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
255 validateRegionExits(TopLevelRegion);
256#endif
257 return modified;
258 }
259
// Declares the analyses this pass depends on. runOnFunction calls
// getAnalysis<> on LoopInfoWrapperPass and
// SPIRVConvergenceRegionAnalysisWrapperPass. Both therefore have to be
// registered here with addRequired<>() so the legacy pass manager provides
// them.
void getAnalysisUsage(AnalysisUsage &AU) const override {

}
268};
269} // namespace
270
// Definition of the pass identifier declared in the class.
char SPIRVMergeRegionExitTargets::ID = 0;

// Registers the pass with the legacy pass registry. The two trailing `false`
// arguments are the (cfg, analysis) flags: not CFG-only, not an analysis.
// NOTE(review): the registered argument and description say "split region exit
// blocks", while the pass merges region exit targets. They are left unchanged
// because command-line users and tests may depend on these strings.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
282
  // Hands out a freshly heap-allocated pass instance. The caller becomes
  // responsible for it (typically by adding it to a pass manager).
  return new SPIRVMergeRegionExitTargets();
}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
#define F(x, y, z)
Definition: MD5.cpp:55
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:39
convergence region
split region exit blocks
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
Value * RHS
Value * LHS
an instruction to allocate memory on the stack
Definition: Instructions.h:64
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
Definition: BasicBlock.h:62
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:206
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:213
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:170
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:233
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:203
unsigned size() const
Definition: DenseMap.h:120
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:230
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:322
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
LLVM_ABI void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:597
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
Definition: Pass.cpp:112
SmallVector< ConvergenceRegion * > Children
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type size() const
Definition: SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:470
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:401
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:541
Multiway switch.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LLVM Value Representation.
Definition: Value.h:75
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ Exit
Definition: COFF.h:863
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
FunctionPass * createSPIRVMergeRegionExitTargetsPass()