LLVM 22.0.0git
SplitModuleByCategory.cpp
Go to the documentation of this file.
1//===-------- SplitModuleByCategory.cpp - split a module by categories ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// See comments in the header.
9//===----------------------------------------------------------------------===//
10
12#include "llvm/ADT/SetVector.h"
15#include "llvm/IR/Function.h"
18#include "llvm/IR/Module.h"
19#include "llvm/Support/Debug.h"
21
22#include <map>
23#include <utility>
24
25using namespace llvm;
26
27#define DEBUG_TYPE "split-module-by-category"
28
29namespace {
30
31// A vector that contains a group of function with the same category.
32using EntryPointSet = SetVector<const Function *>;
33
34/// Represents a group of functions with one category.
35struct EntryPointGroup {
36 int ID;
37 EntryPointSet Functions;
38
39 EntryPointGroup() = default;
40
41 EntryPointGroup(int ID, EntryPointSet &&Functions = EntryPointSet())
42 : ID(ID), Functions(std::move(Functions)) {}
43
44 void clear() { Functions.clear(); }
45
46#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
47 LLVM_DUMP_METHOD void dump() const {
48 constexpr size_t INDENT = 4;
49 dbgs().indent(INDENT) << "ENTRY POINTS"
50 << " " << ID << " {\n";
51 for (const Function *F : Functions)
52 dbgs().indent(INDENT) << " " << F->getName() << "\n";
53
54 dbgs().indent(INDENT) << "}\n";
55 }
56#endif
57};
58
59/// Annotates an llvm::Module with information necessary to perform and track
60/// the result of code (llvm::Module instances) splitting:
61/// - entry points group from the module.
62class ModuleDesc {
63 std::unique_ptr<Module> M;
64 EntryPointGroup EntryPoints;
65
66public:
67 ModuleDesc(std::unique_ptr<Module> M,
68 EntryPointGroup &&EntryPoints = EntryPointGroup())
69 : M(std::move(M)), EntryPoints(std::move(EntryPoints)) {
70 assert(this->M && "Module should be non-null");
71 }
72
73 Module &getModule() { return *M; }
74 const Module &getModule() const { return *M; }
75
76 std::unique_ptr<Module> releaseModule() {
77 EntryPoints.clear();
78 return std::move(M);
79 }
80
81#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
82 LLVM_DUMP_METHOD void dump() const {
83 dbgs() << "ModuleDesc[" << M->getName() << "] {\n";
84 EntryPoints.dump();
85 dbgs() << "}\n";
86 }
87#endif
88};
89
90bool isKernel(const Function &F) {
91 return F.getCallingConv() == CallingConv::SPIR_KERNEL ||
92 F.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
93 F.getCallingConv() == CallingConv::PTX_Kernel;
94}
95
96// Represents "dependency" or "use" graph of global objects (functions and
97// global variables) in a module. It is used during code split to
98// understand which global variables and functions (other than entry points)
99// should be included into a split module.
100//
101// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent
102// the fact that if "A" is included into a module, then "B" should be included
103// as well.
104//
105// Examples of dependencies which are represented in this graph:
106// - Function FA calls function FB
107// - Function FA uses global variable GA
108// - Global variable GA references (initialized with) function FB
109// - Function FA stores address of a function FB somewhere
110//
111// The following cases are treated as dependencies between global objects:
112// 1. Global object A is used by a global object B in any way (store,
113// bitcast, phi node, call, etc.): "B" -> "A" edge will be added to the
114// graph (B depends on A, so including B requires including A);
115// 2. function A performs an indirect call of a function with signature S and
116// there is a function B with signature S. "A" -> "B" edge will be added to
117// the graph;
118class DependencyGraph {
119public:
/// A set of global values, e.g. the candidate callees that share one
/// function signature.
120 using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;
121
/// Builds the dependency graph for module \p M:
///  - for every function and every global variable, each global object that
///    uses it gets an edge to it (case (1) in the class comment);
///  - every function containing an indirect call gets an edge to each
///    non-kernel function whose type matches the call's signature (case (2)).
122 DependencyGraph(const Module &M) {
123 // Group non-kernel functions by their signature (function type) to handle case (2) described above
125 FuncTypeToFuncsMap;
126 for (const Function &F : M.functions()) {
127 // Kernels can't be called (either directly or indirectly).
128 if (isKernel(F))
129 continue;
130
131 FuncTypeToFuncsMap[F.getFunctionType()].insert(&F);
132 }
133
134 for (const Function &F : M.functions()) {
135 // case (1), see comment above the class definition
136 for (const Value *U : F.users())
137 addUserToGraphRecursively(cast<const User>(U), &F);
138
139 // case (2), see comment above the class definition
140 for (const Instruction &I : instructions(F)) {
// Direct calls need no special handling: the call instruction is a user of
// the callee, so case (1) already recorded that edge.
141 const CallBase *CB = dyn_cast<CallBase>(&I);
142 if (!CB || !CB->isIndirectCall()) // Direct calls were handled above
143 continue;
144
// Conservatively treat every non-kernel function with the call's signature
// as a possible callee.
// NOTE(review): operator[] presumably default-inserts an empty bucket when no
// function has this signature, which makes the insert below a no-op.
145 const FunctionType *Signature = CB->getFunctionType();
146 GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature];
147 Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end());
148 }
149 }
150
151 // And every global variable (but their handling is a bit simpler)
// Only case (1) applies here, because variables are never call targets.
152 for (const GlobalVariable &GV : M.globals())
153 for (const Value *U : GV.users())
154 addUserToGraphRecursively(cast<const User>(U), &GV);
155 }
156
158 dependencies(const GlobalValue *Val) const {
// Returns the direct dependencies of Val, i.e. the globals that must be
// included whenever Val is included. find() (not operator[]) keeps this
// method const and never inserts into Graph. A global with no recorded edges
// yields an empty range over EmptySet, so callers can iterate unconditionally.
// EmptySet is a member declared outside this view; the conditional operator
// requires it to have the same iterator type as Graph's mapped sets.
159 auto It = Graph.find(Val);
160 return (It == Graph.end())
161 ? make_range(EmptySet.begin(), EmptySet.end())
162 : make_range(It->second.begin(), It->second.end());
163 }
164
165private:
166 void addUserToGraphRecursively(const User *Root, const GlobalValue *V) {
168 WorkList.push_back(Root);
169
170 while (!WorkList.empty()) {
171 const User *U = WorkList.pop_back_val();
172 if (const auto *I = dyn_cast<const Instruction>(U)) {
173 const Function *UFunc = I->getFunction();
174 Graph[UFunc].insert(V);
175 } else if (isa<const Constant>(U)) {
176 if (const auto *GV = dyn_cast<const GlobalVariable>(U))
177 Graph[GV].insert(V);
178 // This could be a global variable or some constant expression (like
179 // bitcast or gep). We trace users of this constant further to reach
180 // global objects they are used by and add them to the graph.
181 for (const User *UU : U->users())
182 WorkList.push_back(UU);
183 } else {
184 llvm_unreachable("Unhandled type of function user");
185 }
186 }
187 }
188
191};
192
/// Fills GVs with every global value the split module for
/// \p ModuleEntryPoints must keep:
///  - the entry points themselves;
///  - every global variable of M that is not discardable-if-unused;
///  - everything transitively reachable from those through \p DG, except
///    function declarations.
///
/// GVs (a SetVector) doubles as the worklist. It is scanned by index in
/// insertion order while newly found dependencies are appended to its end.
193void collectFunctionsAndGlobalVariablesToExtract(
195 const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) {
196 // We start with module entry points
197 for (const Function *F : ModuleEntryPoints.Functions)
198 GVs.insert(F);
199
200 // Non-discardable global variables are also included in the initial set
201 for (const GlobalVariable &GV : M.globals())
202 if (!GV.isDiscardableIfUnused())
203 GVs.insert(&GV);
204
205 // GVs has SetVector type, which inserts a value only if it is not yet
206 // present. Each global is therefore appended, and visited, at most once.
// As a result the loop terminates even if the dependency graph has cycles.
207 size_t Idx = 0;
208 while (Idx < GVs.size()) {
209 const GlobalValue *Obj = GVs[Idx++];
210
211 for (const GlobalValue *Dep : DG.dependencies(Obj)) {
212 if (const auto *Func = dyn_cast<const Function>(Dep)) {
// Function declarations are not added; they have no body to clone or trace.
213 if (!Func->isDeclaration())
214 GVs.insert(Func);
215 } else {
216 GVs.insert(Dep); // Global variables are added unconditionally
217 }
218 }
219 }
220}
221
/// Clones \p M into a new module that keeps definitions only for the globals
/// contained in GVs. It then rewrites the functions of \p ModuleEntryPoints
/// to refer to their clones, looked up in VMap (the original-to-clone value
/// map that CloneModule fills). Returns the new module together with the
/// remapped entry-point group.
222ModuleDesc extractSubModule(const Module &M,
224 EntryPointGroup &&ModuleEntryPoints) {
226 // Clone definitions only for needed globals. Others will be added as
227 // declarations and removed later.
228 std::unique_ptr<Module> SubM = CloneModule(
229 M, VMap, [&](const GlobalValue *GV) { return GVs.contains(GV); });
230 // Replace entry points with cloned ones.
// The group still points at functions of the source module M, so each one
// is mapped to its counterpart in SubM. A SetVector keeps the original order.
231 EntryPointSet NewEPs;
232 const EntryPointSet &EPs = ModuleEntryPoints.Functions;
234 EPs, [&](const Function *F) { NewEPs.insert(cast<Function>(VMap[F])); });
235 ModuleEntryPoints.Functions = std::move(NewEPs);
236 return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)};
237}
238
239// Produces a copy of the input LLVM IR module M that contains the entry point
240// functions of ModuleEntryPoints plus only those functions and global variables
241// they transitively depend on according to DG (uses, calls, indirect calls).
242ModuleDesc extractCallGraph(const Module &M,
243 EntryPointGroup &&ModuleEntryPoints,
244 const DependencyGraph &DG) {
// GVs receives the globals to keep: the entry points plus their dependencies.
246 collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG);
247
248 ModuleDesc SplitM = extractSubModule(M, GVs, std::move(ModuleEntryPoints));
249 LLVM_DEBUG(SplitM.dump());
250 return SplitM;
251}
252
253using EntryPointGroupVec = SmallVector<EntryPointGroup>;
254
255/// Module Splitter.
256/// It gets a module and a collection of entry points groups.
257/// Each group specifies subset entry points from input module that should be
258/// included in a split module.
259class ModuleSplitter {
260private:
261 std::unique_ptr<Module> M;
262 EntryPointGroupVec Groups;
263 DependencyGraph DG;
264
265private:
266 EntryPointGroup drawEntryPointGroup() {
267 assert(Groups.size() > 0 && "Reached end of entry point groups list.");
268 EntryPointGroup Group = std::move(Groups.back());
269 Groups.pop_back();
270 return Group;
271 }
272
273public:
274 ModuleSplitter(std::unique_ptr<Module> Module, EntryPointGroupVec &&GroupVec)
275 : M(std::move(Module)), Groups(std::move(GroupVec)), DG(*M) {
276 assert(!Groups.empty() && "Entry points groups collection is empty!");
277 }
278
279 /// Gets next subsequence of entry points in an input module and provides
280 /// split submodule containing these entry points and their dependencies.
281 ModuleDesc getNextSplit() {
282 return extractCallGraph(*M, drawEntryPointGroup(), DG);
283 }
284
285 /// Check that there are still submodules to split.
286 bool hasMoreSplits() const { return Groups.size() > 0; }
287};
288
289EntryPointGroupVec selectEntryPointGroups(
290 const Module &M, function_ref<std::optional<int>(const Function &F)> EPC) {
291 // std::map is used here to ensure stable ordering of entry point groups,
292 // which is based on their contents, this greatly helps LIT tests
293 // Note: EPC is allowed to return big identifiers. Therefore, we use
294 // std::map + SmallVector approach here.
295 std::map<int, EntryPointSet> EntryPointsMap;
296
297 for (const auto &F : M.functions())
298 if (std::optional<int> Category = EPC(F); Category)
299 EntryPointsMap[*Category].insert(&F);
300
301 EntryPointGroupVec Groups;
302 Groups.reserve(EntryPointsMap.size());
303 for (auto &[Key, EntryPoints] : EntryPointsMap)
304 Groups.emplace_back(Key, std::move(EntryPoints));
305
306 return Groups;
307}
308
309} // namespace
310
312 std::unique_ptr<Module> M,
313 function_ref<std::optional<int>(const Function &F)> EntryPointCategorizer,
314 function_ref<void(std::unique_ptr<Module> Part)> Callback) {
// Bucket the functions of M by the category EntryPointCategorizer assigns;
// functions for which it returns std::nullopt are not entry points.
// NOTE(review): if no function receives a category, Groups is empty. The
// ModuleSplitter constructor's assert(!Groups.empty()) then fires in
// assert-enabled builds, while in other builds the loop below never runs.
// Confirm that callers guarantee at least one entry point.
315 EntryPointGroupVec Groups = selectEntryPointGroups(*M, EntryPointCategorizer);
316 ModuleSplitter Splitter(std::move(M), std::move(Groups));
// Each iteration extracts one entry-point group (plus its dependencies) into
// its own module and transfers ownership of that module to Callback.
317 while (Splitter.hasMoreSplits()) {
318 ModuleDesc MD = Splitter.getNextSplit();
319 Callback(MD.releaseModule());
320 }
321}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Expand Atomic instructions
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:638
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
static SmallVector< const DIVariable *, 2 > dependencies(DbgVariable *Var)
Return all DIVariables that appear in count: expressions.
static ThreadSafeModule extractSubModule(ThreadSafeModule &TSM, StringRef Suffix, GVPredicate ShouldExtract)
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition: Debug.h:119
static const X86InstrFMA3Group Groups[]
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1116
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1205
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:214
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
A vector that has set insertion semantics.
Definition: SetVector.h:59
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:104
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:168
bool contains(const key_type &key) const
Check if the SetVector contains the given key.
Definition: SetVector.h:269
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:541
bool empty() const
Definition: SmallVector.h:82
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
LLVM Value Representation.
Definition: Value.h:75
An efficient, type-erasing, non-owning reference to a callable.
A range adaptor for a pair of iterators.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ AMDGPU_KERNEL
Used for AMDGPU code object kernels.
Definition: CallingConv.h:200
@ SPIR_KERNEL
Used for SPIR kernel functions.
Definition: CallingConv.h:144
@ PTX_Kernel
Call to a PTX kernel. Passes all arguments in parameter space.
Definition: CallingConv.h:125
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
UnaryFunction for_each(R &&Range, UnaryFunction F)
Provide wrappers to std::for_each which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1737
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ABI void splitModuleTransitiveFromEntryPoints(std::unique_ptr< Module > M, function_ref< std::optional< int >(const Function &F)> EntryPointCategorizer, function_ref< void(std::unique_ptr< Module > Part)> Callback)
Splits the given module M into parts.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:207
LLVM_ABI std::unique_ptr< Module > CloneModule(const Module &M)
Return an exact copy of the specified module.
Definition: CloneModule.cpp:40