LLVM 22.0.0git
JumpTableToSwitch.cpp
Go to the documentation of this file.
1//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/ADT/DenseSet.h"
11#include "llvm/ADT/STLExtras.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/LLVMContext.h"
23#include "llvm/Support/Error.h"
25#include <limits>
26
27using namespace llvm;
28
// Hidden command-line option "jump-table-to-switch-size-threshold"
// (default 10): the largest jump table size this pass is willing to split.
// NOTE(review): the line that starts this declaration is not visible in this
// listing.
30 JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
31 cl::desc("Only split jump tables with size less or "
32 "equal than JumpTableSizeThreshold."),
33 cl::init(10));
34
35// TODO: Consider adding a cost model for profitability analysis of this
36// transformation. Currently we replace a jump table with a switch if all the
37// functions in the jump table are smaller than the provided threshold.
// Hidden command-line option "jump-table-to-switch-function-size-threshold"
// (default 50). parseJumpTable rejects a table as soon as one of its target
// functions has a getInstructionCount() greater than this value.
// NOTE(review): the line that starts this declaration is not visible in this
// listing.
39 "jump-table-to-switch-function-size-threshold", cl::Hidden,
40 cl::desc("Only split jump tables containing functions whose sizes are less "
41 "or equal than this threshold."),
42 cl::init(50));
43
45
46#define DEBUG_TYPE "jump-table-to-switch"
47
48namespace {
// Description of a recognized jump table, produced by parseJumpTable and
// consumed by expandToSwitch.
49struct JumpTableTy {
// The single variable GEP index that selects the table entry. expandToSwitch
// uses it as the condition of the generated switch.
50 Value *Index;
// NOTE(review): a `Funcs` member is used elsewhere in this file
// (JumpTable.Funcs.push_back(...), JT.Funcs.size()), but its declaration is
// not visible in this listing.
52};
53} // anonymous namespace
54
// Tries to interpret GEP as an index into a constant global table of
// functions. On success returns the variable index together with the table's
// target functions, in table order. Returns std::nullopt as soon as any of
// the checks below fails.
55static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
56 PointerType *PtrTy) {
57 Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
58 if (!Ptr)
59 return std::nullopt;
60
// The GEP base must be a constant global variable with a definitive
// initializer, so that the table contents can be read from it.
61 GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
62 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
63 return std::nullopt;
64
65 Function &F = *GEP->getParent()->getParent();
66 const DataLayout &DL = F.getDataLayout();
67 const unsigned BitWidth =
68 DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
// NOTE(review): the declaration of VariableOffsets is not visible in this
// listing.
70 APInt ConstantOffset(BitWidth, 0);
71 if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
72 return std::nullopt;
// Only GEPs with exactly one variable index and no constant offset are
// handled.
73 if (VariableOffsets.size() != 1)
74 return std::nullopt;
75 // TODO: consider supporting more general patterns
76 if (!ConstantOffset.isZero())
77 return std::nullopt;
// The APInt paired with the variable index is treated as the byte stride
// between entries. The global's alloc size must be a whole multiple of that
// stride; the quotient N is the number of entries.
78 APInt StrideBytes = VariableOffsets.front().second;
79 const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
80 if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
81 return std::nullopt;
82 const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
// NOTE(review): the condition guarding the following return is not visible in
// this listing; presumably it compares N against JumpTableSizeThreshold.
// Confirm against the real source.
84 return std::nullopt;
85
86 JumpTableTy JumpTable;
87 JumpTable.Index = VariableOffsets.front().first;
88 JumpTable.Funcs.reserve(N);
89 for (uint64_t Index = 0; Index < N; ++Index) {
90 // ConstantOffset is zero.
91 APInt Offset = Index * StrideBytes;
92 Constant *C =
// NOTE(review): the initializer expression of C is not visible in this
// listing.
// Every entry must resolve to a Function that has a body and whose
// instruction count does not exceed FunctionSizeThreshold. Any other entry
// rejects the whole table.
94 auto *Func = dyn_cast_or_null<Function>(C);
95 if (!Func || Func->isDeclaration() ||
96 Func->getInstructionCount() > FunctionSizeThreshold)
97 return std::nullopt;
98 JumpTable.Funcs.push_back(Func);
99 }
100 return JumpTable;
101}
102
// Replaces the indirect call CB with a switch on JT.Index that dispatches to
// one direct call per entry of JT.Funcs.
//
// The block containing CB is split at CB. For each table entry i, a block
// "call.<i>" is created that holds a clone of CB retargeted to JT.Funcs[i],
// and case value i of the switch jumps to it. The switch default is a new
// block containing only `unreachable`. For non-void calls the clones are
// gathered in a PHI created immediately before CB. Finally CB is erased and
// the split-out tail block is returned.
//
// NOTE(review): part of the parameter list (the declarations of ORE and
// GetGuidForFunction) and the declarations of DTUpdates and GuidToCounter are
// not visible in this listing.
103static BasicBlock *
104expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
107 GetGuidForFunction) {
108 const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
109
111 BasicBlock *BB = CB->getParent();
112 BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
113 BB->getName() + Twine(".tail"));
// The BB -> Tail edge is recorded as deleted; the switch created below becomes
// BB's terminator instead.
// NOTE(review): the statement that removes BB's old terminator is presumably
// on a line not visible in this listing.
114 DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
116
117 Function &F = *BB->getParent();
118 BasicBlock *BBUnreachable = BasicBlock::Create(
119 F.getContext(), "default.switch.case.unreachable", &F, Tail);
120 IRBuilder<> BuilderUnreachable(BBUnreachable);
121 BuilderUnreachable.CreateUnreachable();
122
123 IRBuilder<> Builder(BB);
124 SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
125 DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
126
// The PHI (non-void calls only) is inserted right before CB and reserves one
// incoming slot per table entry.
127 IRBuilder<> BuilderTail(CB);
128 PHINode *PHI =
129 IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
130 const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
131
// BranchWeights is filled in switch-successor order: the default target first,
// then one weight per case.
132 SmallVector<uint64_t> BranchWeights;
134 const bool HadProfile = isValueProfileMD(ProfMD);
135 if (HadProfile) {
136 // The assumptions, coming in, are that the functions in JT.Funcs are
137 // defined in this module (from parseJumpTable).
// NOTE(review): the beginning of the statement completed on the next line is
// not visible in this listing.
139 JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
140 BranchWeights.reserve(JT.Funcs.size() + 1);
141 // The first is the default target, which is the unreachable block created
142 // above.
143 BranchWeights.push_back(0U);
144 uint64_t TotalCount = 0;
145 auto Targets = getValueProfDataFromInst(
146 *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
147 std::numeric_limits<uint32_t>::max(), TotalCount);
148
// Record each profiled target's count under its GUID; the assert checks that
// no GUID is inserted twice.
149 for (const auto &[G, C] : Targets) {
150 [[maybe_unused]] auto It = GuidToCounter.insert({G, C});
151 assert(It.second);
152 }
153 }
154 for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
155 BasicBlock *B = BasicBlock::Create(Func->getContext(),
156 "call." + Twine(Index), &F, Tail);
157 DTUpdates.push_back({DominatorTree::Insert, BB, B});
158 DTUpdates.push_back({DominatorTree::Insert, B, Tail});
159
160 CallBase *Call = cast<CallBase>(CB->clone());
161 // The MD_prof metadata (VP kind), if it existed, can be dropped, it doesn't
162 // make sense on a direct call. Note that the values are used for the branch
163 // weights of the switch.
164 Call->setMetadata(LLVMContext::MD_prof, nullptr);
165 Call->setCalledFunction(Func);
166 Call->insertInto(B, B->end());
167 Switch->addCase(
168 cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
169 GlobalValue::GUID FctID = GetGuidForFunction(*Func);
170 // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
171 // just some of the jump targets are taken (for the given profile).
// A GUID of 0, or one with no recorded count, contributes a weight of 0.
172 BranchWeights.push_back(FctID == 0U ? 0U
173 : GuidToCounter.lookup_or(FctID, 0U));
// NOTE(review): one statement between the weight push and the PHI update is
// not visible in this listing.
175 if (PHI)
176 PHI->addIncoming(Call, B);
177 }
// All recorded CFG edge changes are handed to the DomTreeUpdater in one batch.
178 DTU.applyUpdates(DTUpdates);
179 ORE.emit([&]() {
180 return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
181 << "expanded indirect call into switch";
182 });
// With a value profile (and metadata fixes enabled) the collected counts are
// downscaled and attached to the switch as branch weights.
183 if (HadProfile && !ProfcheckDisableMetadataFixes) {
184 // At least one of the targets must've been taken.
185 assert(llvm::any_of(BranchWeights, [](uint64_t V) { return V != 0; }));
186 setBranchWeights(*Switch, downscaleWeights(BranchWeights),
187 /*IsExpected=*/false);
188 } else
// NOTE(review): the body of this else branch and the statement guarded by the
// following `if (PHI)` are not visible in this listing.
190 if (PHI)
192 CB->eraseFromParent();
193 return Tail;
194}
195
// NOTE(review): the header of this function and the statements that obtain
// ORE, DT and PDT are not visible in this listing. The body below is the pass
// driver: it scans F for indirect calls through a jump table and expands each
// one with expandToSwitch.
202 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
203 bool Changed = false;
// Build a function -> GUID map from the module's InstrProfSymtab. If the
// symtab cannot be created, an error is emitted on the LLVMContext.
204 InstrProfSymtab Symtab;
205 if (auto E = Symtab.create(*F.getParent()))
206 F.getContext().emitError(
207 "Could not create indirect call table, likely corrupted IR" +
208 toString(std::move(E)));
// NOTE(review): the declaration of FToGuid is not visible in this listing.
210 for (const auto &[G, FPtr] : Symtab.getIDToNameMap())
211 FToGuid.insert({FPtr, G});
212
// For each block: find the first eligible call, expand it, then keep scanning
// in the tail block that expandToSwitch split off, until a block yields no
// further expansion.
213 for (BasicBlock &BB : make_early_inc_range(F)) {
214 BasicBlock *CurrentBB = &BB;
215 while (CurrentBB) {
216 BasicBlock *SplittedOutTail = nullptr;
217 for (Instruction &I : make_early_inc_range(*CurrentBB)) {
// Only CallInsts with no statically known callee that are not musttail are
// considered.
218 auto *Call = dyn_cast<CallInst>(&I);
219 if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
220 continue;
221 auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
222 // Skip atomic or volatile loads.
223 if (!L || !L->isSimple())
224 continue;
225 auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
226 if (!GEP)
227 continue;
228 auto *PtrTy = dyn_cast<PointerType>(L->getType());
229 assert(PtrTy && "call operand must be a pointer");
230 std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
231 if (!JumpTable)
232 continue;
233 SplittedOutTail = expandToSwitch(
234 Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
// NOTE(review): the condition selecting between the two returns below is not
// visible in this listing. The fallback looks the function up in FToGuid and
// yields 0 when it is absent.
236 return AssignGUIDPass::getGUID(Fct);
237 return FToGuid.lookup_or(&Fct, 0U);
238 });
239 Changed = true;
// The call was erased and its block split, so stop iterating this block and
// resume in the tail.
240 break;
241 }
242 CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
243 }
244 }
245
// Nothing was rewritten: every analysis remains valid.
246 if (!Changed)
247 return PreservedAnalyses::all();
248
// NOTE(review): the construction of PA and the statements guarded by the two
// ifs below are not visible in this listing.
250 if (DT)
252 if (PDT)
254 return PA;
255}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Rewrite undef for PHI
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file defines the DenseSet and SmallDenseSet classes.
Hexagon Common GEP
cl::opt< bool > ProfcheckDisableMetadataFixes
static cl::opt< unsigned > FunctionSizeThreshold("jump-table-to-switch-function-size-threshold", cl::Hidden, cl::desc("Only split jump tables containing functions whose sizes are less " "or equal than this threshold."), cl::init(50))
static BasicBlock * expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU, OptimizationRemarkEmitter &ORE, llvm::function_ref< GlobalValue::GUID(const Function &)> GetGuidForFunction)
static cl::opt< unsigned > JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, cl::desc("Only split jump tables with size less or " "equal than JumpTableSizeThreshold."), cl::init(10))
static std::optional< JumpTableTy > parseJumpTable(GetElementPtrInst *GEP, PointerType *PtrTy)
#define DEBUG_TYPE
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
This file contains the declarations for profiling metadata utility functions.
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
Class for arbitrary precision integers.
Definition: APInt.h:78
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1540
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:380
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:255
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
Definition: PassManager.h:431
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:412
static LLVM_ABI uint64_t getGUID(const Function &F)
static LLVM_ABI const char * GUIDMetadataName
LLVM Basic Block Representation.
Definition: BasicBlock.h:62
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:206
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:213
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
Definition: BasicBlock.h:233
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1116
This is an important base class in LLVM.
Definition: Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
ValueT lookup_or(const_arg_type_t< KeyT > Val, U &&Default) const
Definition: DenseMap.h:213
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:230
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:284
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:165
void applyUpdates(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Definition: Instructions.h:949
MDNode * getMetadata(unsigned KindID) const
Get the current metadata attachments for the given kind, if any.
Definition: Value.h:576
Type * getValueType() const
Definition: GlobalValue.h:298
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
bool isConstant() const
If the value is a global constant, its value is immutable throughout the runtime execution of the program.
bool hasDefinitiveInitializer() const
hasDefinitiveInitializer - Whether the global variable has an initializer, and any other instances of...
UnreachableInst * CreateUnreachable()
Definition: IRBuilder.h:1339
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:2494
SwitchInst * CreateSwitch(Value *V, BasicBlock *Dest, unsigned NumCases=10, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a switch instruction with the specified value, default dest, and with a hint for the number of...
Definition: IRBuilder.h:1220
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
A symbol table used for function [IR]PGO name look-up with keys (such as pointers,...
Definition: InstrProf.h:506
const std::vector< std::pair< uint64_t, Function * > > & getIDToNameMap() const
Definition: InstrProf.h:668
LLVM_ABI Error create(object::SectionRef &Section)
Create InstrProfSymtab from an object file section which contains function PGO names.
LLVM_ABI Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
Definition: Instruction.h:428
size_type size() const
Definition: MapVector.h:56
std::pair< KeyT, ValueT > & front()
Definition: MapVector.h:79
The optimization diagnostic interface.
LLVM_ABI void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
Analysis pass which computes a PostDominatorTree.
PostDominatorTree Class - Concrete subclass of DominatorTree that is used to compute the post-dominat...
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition: Analysis.h:132
void reserve(size_type N)
Definition: SmallVector.h:664
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
Multiway switch.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:82
static LLVM_ABI Type * getVoidTy(LLVMContext &C)
LLVM Value Representation.
Definition: Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:546
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1098
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:322
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:34
@ Tail
Attempts to make calls as fast as possible while guaranteeing that tail call optimization can always be performed.
Definition: CallingConv.h:76
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:444
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:477
LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I)
Specify that the branch weights for this terminator cannot be known at compile time.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1744
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition: STLExtras.h:2491
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:663
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1751
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef< uint32_t > Weights, bool IsExpected)
Create a new branch_weights metadata node and add or overwrite a prof metadata reference to instructi...
LLVM_ABI Constant * ConstantFoldLoadFromConst(Constant *C, Type *Ty, const APInt &Offset, const DataLayout &DL)
Extract value of C at the given Offset reinterpreted as Ty.
LLVM_ABI SmallVector< InstrProfValueData, 4 > getValueProfDataFromInst(const Instruction &Inst, InstrProfValueKind ValueKind, uint32_t MaxNumValueData, uint64_t &TotalC, bool GetNoICPValue=false)
Extract the value profile data from Inst and returns them if Inst is annotated with value profile dat...
Definition: InstrProf.cpp:1402
LLVM_ABI bool isValueProfileMD(const MDNode *ProfileData)
Checks if an MDNode contains value profiling Metadata.
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:223
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
const char * toString(DWARFSectionKind Kind)
LLVM_ABI SmallVector< uint32_t > downscaleWeights(ArrayRef< uint64_t > Weights, std::optional< uint64_t > KnownMaxCount=std::nullopt)
downscale the given weights preserving the ratio.
#define N
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Run the pass over the function.
A MapVector that performs no allocations if smaller than a certain size.
Definition: MapVector.h:249