1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
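//
// As a rough illustration (hypothetical user code, not taken from this file),
// the runtime-call deduplication targets patterns such as:
//
//   int A = omp_get_thread_num();
//   ...                            // nothing in between changes the thread id
//   int B = omp_get_thread_num();  // redundant, can reuse the value of A
//
// where the second call can be folded into the first one.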
19
20#include "llvm/Transforms/IPO/OpenMPOpt.h"
21
22#include "llvm/ADT/DenseSet.h"
25#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/Statistic.h"
30#include "llvm/ADT/StringRef.h"
38#include "llvm/IR/Assumptions.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/Constants.h"
42#include "llvm/IR/Dominators.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
46#include "llvm/IR/InstrTypes.h"
47#include "llvm/IR/Instruction.h"
50#include "llvm/IR/IntrinsicsAMDGPU.h"
51#include "llvm/IR/IntrinsicsNVPTX.h"
52#include "llvm/IR/LLVMContext.h"
55#include "llvm/Support/Debug.h"
59
60#include <algorithm>
61#include <optional>
62#include <string>
63
64using namespace llvm;
65using namespace omp;
66
67#define DEBUG_TYPE "openmp-opt"
68
69static cl::opt<bool> DisableOpenMPOptimizations(
70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
73static cl::opt<bool> EnableParallelRegionMerging(
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86 cl::Hidden);
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
90static cl::opt<bool> HideMemoryTransferLatency(
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
96static cl::opt<bool> DisableOpenMPOptDeglobalization(
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
101static cl::opt<bool> DisableOpenMPOptSPMDization(
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
106static cl::opt<bool> DisableOpenMPOptFolding(
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
111static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
116static cl::opt<bool> DisableOpenMPOptBarrierElimination(
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
121static cl::opt<bool> PrintModuleAfterOptimizations(
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
126static cl::opt<bool> PrintModuleBeforeOptimizations(
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
131static cl::opt<bool> AlwaysInlineDeviceFunctions(
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
141static cl::opt<unsigned>
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
146static cl::opt<unsigned>
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));
150
151STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152 "Number of OpenMP runtime calls deduplicated");
153STATISTIC(NumOpenMPParallelRegionsDeleted,
154 "Number of OpenMP parallel regions deleted");
155STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156 "Number of OpenMP runtime functions identified");
157STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158 "Number of OpenMP runtime function uses identified");
159STATISTIC(NumOpenMPTargetRegionKernels,
160 "Number of OpenMP target region entry points (=kernels) identified");
161STATISTIC(NumNonOpenMPTargetRegionKernels,
162 "Number of non-OpenMP target region kernels identified");
163STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164 "Number of OpenMP target region entry points (=kernels) executed in "
165 "SPMD-mode instead of generic-mode");
166STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167 "Number of OpenMP target region entry points (=kernels) executed in "
168 "generic-mode without a state machine");
169STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170 "Number of OpenMP target region entry points (=kernels) executed in "
171 "generic-mode with customized state machines with fallback");
172STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173 "Number of OpenMP target region entry points (=kernels) executed in "
174 "generic-mode with customized state machines without fallback");
175STATISTIC(
176 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178STATISTIC(NumOpenMPParallelRegionsMerged,
179 "Number of OpenMP parallel regions merged");
180STATISTIC(NumBytesMovedToSharedMemory,
181 "Amount of memory pushed to shared memory");
182STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183
184#if !defined(NDEBUG)
185static constexpr auto TAG = "[" DEBUG_TYPE "]";
186#endif
187
188namespace KernelInfo {
189
190// struct ConfigurationEnvironmentTy {
191// uint8_t UseGenericStateMachine;
192// uint8_t MayUseNestedParallelism;
193// llvm::omp::OMPTgtExecModeFlags ExecMode;
194// int32_t MinThreads;
195// int32_t MaxThreads;
196// int32_t MinTeams;
197// int32_t MaxTeams;
198// };
199
200// struct DynamicEnvironmentTy {
201// uint16_t DebugIndentionLevel;
202// };
203
204// struct KernelEnvironmentTy {
205// ConfigurationEnvironmentTy Configuration;
206// IdentTy *Ident;
207// DynamicEnvironmentTy *DynamicEnv;
208// };
209
210#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
211 constexpr const unsigned MEMBER##Idx = IDX;
212
213KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214KERNEL_ENVIRONMENT_IDX(Ident, 1)
215
216#undef KERNEL_ENVIRONMENT_IDX
217
218#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
219 constexpr const unsigned MEMBER##Idx = IDX;
220
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228
229#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230
231#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
232 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
234 }
235
236KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238
239#undef KERNEL_ENVIRONMENT_GETTER
240
241#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
242 ConstantInt *get##MEMBER##FromKernelEnvironment( \
243 ConstantStruct *KernelEnvC) { \
244 ConstantStruct *ConfigC = \
245 getConfigurationFromKernelEnvironment(KernelEnvC); \
246 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
247 }
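// As a sketch, KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
// below expands (modulo formatting) to:
//
//   ConstantInt *getUseGenericStateMachineFromKernelEnvironment(
//       ConstantStruct *KernelEnvC) {
//     ConstantStruct *ConfigC =
//         getConfigurationFromKernelEnvironment(KernelEnvC);
//     return dyn_cast<ConstantInt>(
//         ConfigC->getAggregateElement(UseGenericStateMachineIdx));
//   }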
248
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256
257#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
258
259GlobalVariable *
260getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261 constexpr const int InitKernelEnvironmentArgNo = 0;
262 return cast<GlobalVariable>(
263 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264 ->stripPointerCasts());
265}
266
267ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268 GlobalVariable *KernelEnvGV =
269 getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270 return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271}
272} // namespace KernelInfo
273
274namespace {
275
276struct AAHeapToShared;
277
278struct AAICVTracker;
279
280/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281/// Attributor runs.
282struct OMPInformationCache : public InformationCache {
283 OMPInformationCache(Module &M, AnalysisGetter &AG,
284 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285 bool OpenMPPostLink)
286 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287 OpenMPPostLink(OpenMPPostLink) {
288
289 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290 const Triple T(OMPBuilder.M.getTargetTriple());
291 switch (T.getArch()) {
292 case llvm::Triple::amdgcn:
293 case llvm::Triple::nvptx:
294 case llvm::Triple::nvptx64:
295 assert(OMPBuilder.Config.IsTargetDevice &&
296 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
297 OMPBuilder.Config.IsGPU = true;
298 break;
299 default:
300 OMPBuilder.Config.IsGPU = false;
301 break;
302 }
303 OMPBuilder.initialize();
304 initializeRuntimeFunctions(M);
305 initializeInternalControlVars();
306 }
307
308 /// Generic information that describes an internal control variable.
309 struct InternalControlVarInfo {
310 /// The kind, as described by InternalControlVar enum.
311 InternalControlVar Kind;
312
313 /// The name of the ICV.
314 StringRef Name;
315
316 /// Environment variable associated with this ICV.
317 StringRef EnvVarName;
318
319 /// Initial value kind.
320 ICVInitValue InitKind;
321
322 /// Initial value.
323 ConstantInt *InitValue;
324
325 /// Setter RTL function associated with this ICV.
326 RuntimeFunction Setter;
327
328 /// Getter RTL function associated with this ICV.
329 RuntimeFunction Getter;
330
331 /// RTL Function corresponding to the override clause of this ICV
332 RuntimeFunction Clause;
333 };
334
335 /// Generic information that describes a runtime function
336 struct RuntimeFunctionInfo {
337
338 /// The kind, as described by the RuntimeFunction enum.
339 RuntimeFunction Kind;
340
341 /// The name of the function.
342 StringRef Name;
343
344 /// Flag to indicate a variadic function.
345 bool IsVarArg;
346
347 /// The return type of the function.
348 Type *ReturnType;
349
350 /// The argument types of the function.
351 SmallVector<Type *, 8> ArgumentTypes;
352
353 /// The declaration if available.
354 Function *Declaration = nullptr;
355
356 /// Uses of this runtime function per function containing the use.
357 using UseVector = SmallVector<Use *, 16>;
358
359 /// Clear UsesMap for runtime function.
360 void clearUsesMap() { UsesMap.clear(); }
361
362 /// Boolean conversion that is true if the runtime function was found.
363 operator bool() const { return Declaration; }
364
365 /// Return the vector of uses in function \p F.
366 UseVector &getOrCreateUseVector(Function *F) {
367 std::shared_ptr<UseVector> &UV = UsesMap[F];
368 if (!UV)
369 UV = std::make_shared<UseVector>();
370 return *UV;
371 }
372
373 /// Return the vector of uses in function \p F or `nullptr` if there are
374 /// none.
375 const UseVector *getUseVector(Function &F) const {
376 auto I = UsesMap.find(&F);
377 if (I != UsesMap.end())
378 return I->second.get();
379 return nullptr;
380 }
381
382 /// Return how many functions contain uses of this runtime function.
383 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
384
385 /// Return the number of arguments (or the minimal number for variadic
386 /// functions).
387 size_t getNumArgs() const { return ArgumentTypes.size(); }
388
389 /// Run the callback \p CB on each use and forget the use if the result is
390 /// true. The callback will be fed the function in which the use was
391 /// encountered as second argument.
392 void foreachUse(SmallVectorImpl<Function *> &SCC,
393 function_ref<bool(Use &, Function &)> CB) {
394 for (Function *F : SCC)
395 foreachUse(CB, F);
396 }
397
398 /// Run the callback \p CB on each use within the function \p F and forget
399 /// the use if the result is true.
400 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
401 SmallVector<unsigned, 8> ToBeDeleted;
402 ToBeDeleted.clear();
403
404 unsigned Idx = 0;
405 UseVector &UV = getOrCreateUseVector(F);
406
407 for (Use *U : UV) {
408 if (CB(*U, *F))
409 ToBeDeleted.push_back(Idx);
410 ++Idx;
411 }
412
413 // Remove the to-be-deleted indices in reverse order: each removal swaps the
414 // last element into the removed slot, so smaller pending indices stay valid.
415 while (!ToBeDeleted.empty()) {
416 unsigned Idx = ToBeDeleted.pop_back_val();
417 UV[Idx] = UV.back();
418 UV.pop_back();
419 }
420 }
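// Usage sketch (hypothetical callback, mirroring how callers later in this
// file use it):
//
//   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
//     if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
//       // ... inspect, transform, or erase CI ...
//       return true;  // drop this use from the map
//     }
//     return false;   // keep the use for later traversals
//   });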
421
422 private:
423 /// Map from functions to all uses of this runtime function contained in
424 /// them.
425 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
426
427 public:
428 /// Iterators for the uses of this runtime function.
429 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
430 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
431 };
432
433 /// An OpenMP-IR-Builder instance
434 OpenMPIRBuilder OMPBuilder;
435
436 /// Map from runtime function kind to the runtime function description.
437 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
438 RuntimeFunction::OMPRTL___last>
439 RFIs;
440
441 /// Map from function declarations/definitions to their runtime enum type.
442 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
443
444 /// Map from ICV kind to the ICV description.
445 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
446 InternalControlVar::ICV___last>
447 ICVs;
448
449 /// Helper to initialize all internal control variable information for those
450 /// defined in OMPKinds.def.
451 void initializeInternalControlVars() {
452#define ICV_RT_SET(_Name, RTL) \
453 { \
454 auto &ICV = ICVs[_Name]; \
455 ICV.Setter = RTL; \
456 }
457#define ICV_RT_GET(Name, RTL) \
458 { \
459 auto &ICV = ICVs[Name]; \
460 ICV.Getter = RTL; \
461 }
462#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
463 { \
464 auto &ICV = ICVs[Enum]; \
465 ICV.Name = _Name; \
466 ICV.Kind = Enum; \
467 ICV.InitKind = Init; \
468 ICV.EnvVarName = _EnvVarName; \
469 switch (ICV.InitKind) { \
470 case ICV_IMPLEMENTATION_DEFINED: \
471 ICV.InitValue = nullptr; \
472 break; \
473 case ICV_ZERO: \
474 ICV.InitValue = ConstantInt::get( \
475 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
476 break; \
477 case ICV_FALSE: \
478 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
479 break; \
480 case ICV_LAST: \
481 break; \
482 } \
483 }
484#include "llvm/Frontend/OpenMP/OMPKinds.def"
485 }
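// Note on the include above: OMPKinds.def is an X-macro file. Every ICV it
// lists is described by an ICV_DATA_ENV(Enum, Name, EnvVarName, InitKind)
// entry (plus optional ICV_RT_SET / ICV_RT_GET entries), and the macro
// definitions above turn each entry into the corresponding initialization
// block, so ICVs[...] is fully populated once the include has been processed.
// The exact entry spellings live in OMPKinds.def, not here.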
486
487 /// Returns true if the function declaration \p F matches the runtime
488 /// function types, that is, return type \p RTFRetType, and argument types
489 /// \p RTFArgTypes.
490 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
491 SmallVector<Type *, 8> &RTFArgTypes) {
492 // TODO: We should output information to the user (under debug output
493 // and via remarks).
494
495 if (!F)
496 return false;
497 if (F->getReturnType() != RTFRetType)
498 return false;
499 if (F->arg_size() != RTFArgTypes.size())
500 return false;
501
502 auto *RTFTyIt = RTFArgTypes.begin();
503 for (Argument &Arg : F->args()) {
504 if (Arg.getType() != *RTFTyIt)
505 return false;
506
507 ++RTFTyIt;
508 }
509
510 return true;
511 }
512
513 // Helper to collect all uses of the declaration in the UsesMap.
514 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
515 unsigned NumUses = 0;
516 if (!RFI.Declaration)
517 return NumUses;
518 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
519
520 if (CollectStats) {
521 NumOpenMPRuntimeFunctionsIdentified += 1;
522 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
523 }
524
525 // TODO: We directly convert uses into proper calls and unknown uses.
526 for (Use &U : RFI.Declaration->uses()) {
527 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
528 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
529 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
530 ++NumUses;
531 }
532 } else {
533 RFI.getOrCreateUseVector(nullptr).push_back(&U);
534 ++NumUses;
535 }
536 }
537 return NumUses;
538 }
539
540 // Helper function to recollect uses of a runtime function.
541 void recollectUsesForFunction(RuntimeFunction RTF) {
542 auto &RFI = RFIs[RTF];
543 RFI.clearUsesMap();
544 collectUses(RFI, /*CollectStats*/ false);
545 }
546
547 // Helper function to recollect uses of all runtime functions.
548 void recollectUses() {
549 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
550 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
551 }
552
553 // Helper function to inherit the calling convention of the function callee.
554 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
555 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
556 CI->setCallingConv(Fn->getCallingConv());
557 }
558
559 // Helper function to determine if it's legal to create a call to the runtime
560 // functions.
561 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
562 // We can always emit calls if we haven't yet linked in the runtime.
563 if (!OpenMPPostLink)
564 return true;
565
566 // Once the runtime has already been linked in, we cannot emit calls to
567 // any undefined functions.
568 for (RuntimeFunction Fn : Fns) {
569 RuntimeFunctionInfo &RFI = RFIs[Fn];
570
571 if (!RFI.Declaration || RFI.Declaration->isDeclaration())
572 return false;
573 }
574 return true;
575 }
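// For example, hideMemTransfersLatency() below only splits
// __tgt_target_data_begin_mapper calls when this check succeeds for the
// mapper_issue / mapper_wait pair, i.e. when those entry points can still be
// emitted (pre-link) or already have definitions (post-link).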
576
577 /// Helper to initialize all runtime function information for those defined
578 /// in OpenMPKinds.def.
579 void initializeRuntimeFunctions(Module &M) {
580
581 // Helper macros for handling __VA_ARGS__ in OMP_RTL
582#define OMP_TYPE(VarName, ...) \
583 Type *VarName = OMPBuilder.VarName; \
584 (void)VarName;
585
586#define OMP_ARRAY_TYPE(VarName, ...) \
587 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
588 (void)VarName##Ty; \
589 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
590 (void)VarName##PtrTy;
591
592#define OMP_FUNCTION_TYPE(VarName, ...) \
593 FunctionType *VarName = OMPBuilder.VarName; \
594 (void)VarName; \
595 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
596 (void)VarName##Ptr;
597
598#define OMP_STRUCT_TYPE(VarName, ...) \
599 StructType *VarName = OMPBuilder.VarName; \
600 (void)VarName; \
601 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
602 (void)VarName##Ptr;
603
604#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
605 { \
606 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
607 Function *F = M.getFunction(_Name); \
608 RTLFunctions.insert(F); \
609 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
610 RuntimeFunctionIDMap[F] = _Enum; \
611 auto &RFI = RFIs[_Enum]; \
612 RFI.Kind = _Enum; \
613 RFI.Name = _Name; \
614 RFI.IsVarArg = _IsVarArg; \
615 RFI.ReturnType = OMPBuilder._ReturnType; \
616 RFI.ArgumentTypes = std::move(ArgsTypes); \
617 RFI.Declaration = F; \
618 unsigned NumUses = collectUses(RFI); \
619 (void)NumUses; \
620 LLVM_DEBUG({ \
621 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
622 << " found\n"; \
623 if (RFI.Declaration) \
624 dbgs() << TAG << "-> got " << NumUses << " uses in " \
625 << RFI.getNumFunctionsWithUses() \
626 << " different functions.\n"; \
627 }); \
628 } \
629 }
630#include "llvm/Frontend/OpenMP/OMPKinds.def"
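// The include above expands one OMP_RTL(Enum, Name, IsVarArg, ReturnType, ...)
// row per runtime function declared in OMPKinds.def. Roughly, the row for
// __kmpc_barrier supplies the enum OMPRTL___kmpc_barrier, the string name, and
// the expected signature; the block defined above then looks the function up
// in the module, checks the signature via declMatchesRTFTypes, and records its
// uses. (The exact row contents are defined in OMPKinds.def.)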
631
632 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
633 // functions, except if `optnone` is present.
634 if (isOpenMPDevice(M)) {
635 for (Function &F : M) {
636 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
637 if (F.hasFnAttribute(Attribute::NoInline) &&
638 F.getName().starts_with(Prefix) &&
639 !F.hasFnAttribute(Attribute::OptimizeNone))
640 F.removeFnAttr(Attribute::NoInline);
641 }
642 }
643
644 // TODO: We should attach the attributes defined in OMPKinds.def.
645 }
646
647 /// Collection of known OpenMP runtime functions.
648 DenseSet<const Function *> RTLFunctions;
649
650 /// Indicates if we have already linked in the OpenMP device library.
651 bool OpenMPPostLink = false;
652};
653
654template <typename Ty, bool InsertInvalidates = true>
655struct BooleanStateWithSetVector : public BooleanState {
656 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
657 bool insert(const Ty &Elem) {
658 if (InsertInvalidates)
659 BooleanState::indicatePessimisticFixpoint();
660 return Set.insert(Elem);
661 }
662
663 const Ty &operator[](int Idx) const { return Set[Idx]; }
664 bool operator==(const BooleanStateWithSetVector &RHS) const {
665 return BooleanState::operator==(RHS) && Set == RHS.Set;
666 }
667 bool operator!=(const BooleanStateWithSetVector &RHS) const {
668 return !(*this == RHS);
669 }
670
671 bool empty() const { return Set.empty(); }
672 size_t size() const { return Set.size(); }
673
674 /// "Clamp" this state with \p RHS.
675 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
676 BooleanState::operator^=(RHS);
677 Set.insert_range(RHS.Set);
678 return *this;
679 }
680
681private:
682 /// A set to keep track of elements.
683 SetVector<Ty> Set;
684
685public:
686 typename decltype(Set)::iterator begin() { return Set.begin(); }
687 typename decltype(Set)::iterator end() { return Set.end(); }
688 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
689 typename decltype(Set)::const_iterator end() const { return Set.end(); }
690};
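// Usage note: insert() records the element and, when InsertInvalidates is
// true, also signals a pessimistic fixpoint on the underlying BooleanState so
// the surrounding analysis treats the collected set as potentially incomplete.
// The "clamp" operator^= unions the element sets of the two states in addition
// to combining their boolean state.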
691
692template <typename Ty, bool InsertInvalidates = true>
693using BooleanStateWithPtrSetVector =
694 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
695
696struct KernelInfoState : AbstractState {
697 /// Flag to track if we reached a fixpoint.
698 bool IsAtFixpoint = false;
699
700 /// The parallel regions (identified by the outlined parallel functions) that
701 /// can be reached from the associated function.
702 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
703 ReachedKnownParallelRegions;
704
705 /// State to track what parallel region we might reach.
706 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
707
708 /// State to track if we are in SPMD-mode, assumed or known, and why we decided
709 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
710 /// false.
711 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
712
713 /// The __kmpc_target_init call in this kernel, if any. If we find more than
714 /// one we abort as the kernel is malformed.
715 CallBase *KernelInitCB = nullptr;
716
717 /// The constant kernel environment as taken from and passed to
718 /// __kmpc_target_init.
719 ConstantStruct *KernelEnvC = nullptr;
720
721 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
722 /// one we abort as the kernel is malformed.
723 CallBase *KernelDeinitCB = nullptr;
724
725 /// Flag to indicate if the associated function is a kernel entry.
726 bool IsKernelEntry = false;
727
728 /// State to track what kernel entries can reach the associated function.
729 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
730
731 /// State to indicate if we can track parallel level of the associated
732 /// function. We will give up tracking if we encounter an unknown caller or the
733 /// caller is __kmpc_parallel_51.
734 BooleanStateWithSetVector<uint8_t> ParallelLevels;
735
736 /// Flag that indicates if the kernel has nested parallelism.
737 bool NestedParallelism = false;
738
739 /// Abstract State interface
740 ///{
741
742 KernelInfoState() = default;
743 KernelInfoState(bool BestState) {
744 if (!BestState)
745 indicatePessimisticFixpoint();
746 }
747
748 /// See AbstractState::isValidState(...)
749 bool isValidState() const override { return true; }
750
751 /// See AbstractState::isAtFixpoint(...)
752 bool isAtFixpoint() const override { return IsAtFixpoint; }
753
754 /// See AbstractState::indicatePessimisticFixpoint(...)
755 ChangeStatus indicatePessimisticFixpoint() override {
756 IsAtFixpoint = true;
757 ParallelLevels.indicatePessimisticFixpoint();
758 ReachingKernelEntries.indicatePessimisticFixpoint();
759 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
760 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
761 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
762 NestedParallelism = true;
763 return ChangeStatus::CHANGED;
764 }
765
766 /// See AbstractState::indicateOptimisticFixpoint(...)
767 ChangeStatus indicateOptimisticFixpoint() override {
768 IsAtFixpoint = true;
769 ParallelLevels.indicateOptimisticFixpoint();
770 ReachingKernelEntries.indicateOptimisticFixpoint();
771 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
772 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
773 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
774 return ChangeStatus::UNCHANGED;
775 }
776
777 /// Return the assumed state
778 KernelInfoState &getAssumed() { return *this; }
779 const KernelInfoState &getAssumed() const { return *this; }
780
781 bool operator==(const KernelInfoState &RHS) const {
782 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
783 return false;
784 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
785 return false;
786 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
787 return false;
788 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
789 return false;
790 if (ParallelLevels != RHS.ParallelLevels)
791 return false;
792 if (NestedParallelism != RHS.NestedParallelism)
793 return false;
794 return true;
795 }
796
797 /// Returns true if this kernel may contain any OpenMP parallel regions.
798 bool mayContainParallelRegion() {
799 return !ReachedKnownParallelRegions.empty() ||
800 !ReachedUnknownParallelRegions.empty();
801 }
802
803 /// Return empty set as the best state of potential values.
804 static KernelInfoState getBestState() { return KernelInfoState(true); }
805
806 static KernelInfoState getBestState(KernelInfoState &KIS) {
807 return getBestState();
808 }
809
810 /// Return full set as the worst state of potential values.
811 static KernelInfoState getWorstState() { return KernelInfoState(false); }
812
813 /// "Clamp" this state with \p KIS.
814 KernelInfoState operator^=(const KernelInfoState &KIS) {
815 // Do not merge two different _init and _deinit call sites.
816 if (KIS.KernelInitCB) {
817 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
818 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
819 "assumptions.");
820 KernelInitCB = KIS.KernelInitCB;
821 }
822 if (KIS.KernelDeinitCB) {
823 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
824 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
825 "assumptions.");
826 KernelDeinitCB = KIS.KernelDeinitCB;
827 }
828 if (KIS.KernelEnvC) {
829 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
830 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
831 "assumptions.");
832 KernelEnvC = KIS.KernelEnvC;
833 }
834 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
835 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
836 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
837 NestedParallelism |= KIS.NestedParallelism;
838 return *this;
839 }
840
841 KernelInfoState operator&=(const KernelInfoState &KIS) {
842 return (*this ^= KIS);
843 }
844
845 ///}
846};
847
848/// Used to map the values physically (in the IR) stored in an offload
849/// array, to a vector in memory.
850struct OffloadArray {
851 /// Physical array (in the IR).
852 AllocaInst *Array = nullptr;
853 /// Mapped values.
854 SmallVector<Value *, 8> StoredValues;
855 /// Last stores made in the offload array.
856 SmallVector<StoreInst *, 8> LastAccesses;
857
858 OffloadArray() = default;
859
860 /// Initializes the OffloadArray with the values stored in \p Array before
861 /// instruction \p Before is reached. Returns false if the initialization
862 /// fails.
863 /// This MUST be used immediately after the construction of the object.
864 bool initialize(AllocaInst &Array, Instruction &Before) {
865 if (!Array.getAllocatedType()->isArrayTy())
866 return false;
867
868 if (!getValues(Array, Before))
869 return false;
870
871 this->Array = &Array;
872 return true;
873 }
874
875 static const unsigned DeviceIDArgNum = 1;
876 static const unsigned BasePtrsArgNum = 3;
877 static const unsigned PtrsArgNum = 4;
878 static const unsigned SizesArgNum = 5;
879
880private:
881 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
882 /// \p Array, leaving StoredValues with the values stored before the
883 /// instruction \p Before is reached.
884 bool getValues(AllocaInst &Array, Instruction &Before) {
885 // Initialize container.
886 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
887 StoredValues.assign(NumValues, nullptr);
888 LastAccesses.assign(NumValues, nullptr);
889
890 // TODO: This assumes the instruction \p Before is in the same
891 // BasicBlock as Array. Make it general, for any control flow graph.
892 BasicBlock *BB = Array.getParent();
893 if (BB != Before.getParent())
894 return false;
895
896 const DataLayout &DL = Array.getDataLayout();
897 const unsigned int PointerSize = DL.getPointerSize();
898
899 for (Instruction &I : *BB) {
900 if (&I == &Before)
901 break;
902
903 if (!isa<StoreInst>(&I))
904 continue;
905
906 auto *S = cast<StoreInst>(&I);
907 int64_t Offset = -1;
908 auto *Dst =
909 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
910 if (Dst == &Array) {
911 int64_t Idx = Offset / PointerSize;
912 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
913 LastAccesses[Idx] = S;
914 }
915 }
916
917 return isFilled();
918 }
919
920 /// Returns true if all values in StoredValues and
921 /// LastAccesses are not nullptrs.
922 bool isFilled() {
923 const unsigned NumValues = StoredValues.size();
924 for (unsigned I = 0; I < NumValues; ++I) {
925 if (!StoredValues[I] || !LastAccesses[I])
926 return false;
927 }
928
929 return true;
930 }
931};
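// For illustration, OffloadArray::initialize() is meant to consume IR shaped
// roughly like the following (names and types are hypothetical):
//
//   %offload_baseptrs = alloca [2 x ptr]
//   store ptr %a, ptr %offload_baseptrs
//   %gep = getelementptr [2 x ptr], ptr %offload_baseptrs, i64 0, i64 1
//   store ptr %b, ptr %gep
//   call void @__tgt_target_data_begin_mapper(...)   ; the "Before" point
//
// getValues() records the underlying objects of %a and %b in StoredValues
// (indexed by store offset divided by the pointer size) and the two stores in
// LastAccesses.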
932
933struct OpenMPOpt {
934
935 using OptimizationRemarkGetter =
936 function_ref<OptimizationRemarkEmitter &(Function *)>;
937
938 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
939 OptimizationRemarkGetter OREGetter,
940 OMPInformationCache &OMPInfoCache, Attributor &A)
941 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
942 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
943
944 /// Check if any remarks are enabled for openmp-opt
945 bool remarksEnabled() {
946 auto &Ctx = M.getContext();
947 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
948 }
949
950 /// Run all OpenMP optimizations on the underlying SCC.
951 bool run(bool IsModulePass) {
952 if (SCC.empty())
953 return false;
954
955 bool Changed = false;
956
957 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
958 << " functions\n");
959
960 if (IsModulePass) {
961 Changed |= runAttributor(IsModulePass);
962
963 // Recollect uses, in case Attributor deleted any.
964 OMPInfoCache.recollectUses();
965
966 // TODO: This should be folded into buildCustomStateMachine.
967 Changed |= rewriteDeviceCodeStateMachine();
968
969 if (remarksEnabled())
970 analysisGlobalization();
971 } else {
972 if (PrintICVValues)
973 printICVs();
974 if (PrintOpenMPKernels)
975 printKernels();
976
977 Changed |= runAttributor(IsModulePass);
978
979 // Recollect uses, in case Attributor deleted any.
980 OMPInfoCache.recollectUses();
981
982 Changed |= deleteParallelRegions();
983
984 if (HideMemoryTransferLatency)
985 Changed |= hideMemTransfersLatency();
986 Changed |= deduplicateRuntimeCalls();
987 if (EnableParallelRegionMerging) {
988 if (mergeParallelRegions()) {
989 deduplicateRuntimeCalls();
990 Changed = true;
991 }
992 }
993 }
994
995 if (OMPInfoCache.OpenMPPostLink)
996 Changed |= removeRuntimeSymbols();
997
998 return Changed;
999 }
1000
1001 /// Print initial ICV values for testing.
1002 /// FIXME: This should be done from the Attributor once it is added.
1003 void printICVs() const {
1004 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1005 ICV_proc_bind};
1006
1007 for (Function *F : SCC) {
1008 for (auto ICV : ICVs) {
1009 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1010 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1011 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1012 << " Value: "
1013 << (ICVInfo.InitValue
1014 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1015 : "IMPLEMENTATION_DEFINED");
1016 };
1017
1018 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1019 }
1020 }
1021 }
1022
1023 /// Print OpenMP GPU kernels for testing.
1024 void printKernels() const {
1025 for (Function *F : SCC) {
1026 if (!omp::isOpenMPKernel(*F))
1027 continue;
1028
1029 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1030 return ORA << "OpenMP GPU kernel "
1031 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1032 };
1033
1034 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPUAnalysis", Remark);
1035
1036 }
1037
1038 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1039 /// given it has to be the callee or a nullptr is returned.
1040 static CallInst *getCallIfRegularCall(
1041 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1042 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1043 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1044 (!RFI ||
1045 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1046 return CI;
1047 return nullptr;
1048 }
1049
1050 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1051 /// the callee or a nullptr is returned.
1052 static CallInst *getCallIfRegularCall(
1053 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1054 CallInst *CI = dyn_cast<CallInst>(&V);
1055 if (CI && !CI->hasOperandBundles() &&
1056 (!RFI ||
1057 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1058 return CI;
1059 return nullptr;
1060 }
1061
1062private:
1063 /// Merge parallel regions when it is safe.
1064 bool mergeParallelRegions() {
1065 const unsigned CallbackCalleeOperand = 2;
1066 const unsigned CallbackFirstArgOperand = 3;
1067 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1068
1069 // Check if there are any __kmpc_fork_call calls to merge.
1070 OMPInformationCache::RuntimeFunctionInfo &RFI =
1071 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1072
1073 if (!RFI.Declaration)
1074 return false;
1075
1076 // Unmergable calls that prevent merging a parallel region.
1077 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1078 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1079 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1080 };
1081
1082 bool Changed = false;
1083 LoopInfo *LI = nullptr;
1084 DominatorTree *DT = nullptr;
1085
1086 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1087
1088 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1089 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1090 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1091 BasicBlock *CGEndBB =
1092 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1093 assert(StartBB != nullptr && "StartBB should not be null");
1094 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1095 assert(EndBB != nullptr && "EndBB should not be null");
1096 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1097 return Error::success();
1098 };
1099
1100 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1101 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1102 ReplacementValue = &Inner;
1103 return CodeGenIP;
1104 };
1105
1106 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1107
1108 /// Create a sequential execution region within a merged parallel region,
1109 /// encapsulated in a master construct with a barrier for synchronization.
1110 auto CreateSequentialRegion = [&](Function *OuterFn,
1111 BasicBlock *OuterPredBB,
1112 Instruction *SeqStartI,
1113 Instruction *SeqEndI) {
1114 // Isolate the instructions of the sequential region to a separate
1115 // block.
1116 BasicBlock *ParentBB = SeqStartI->getParent();
1117 BasicBlock *SeqEndBB =
1118 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1119 BasicBlock *SeqAfterBB =
1120 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1121 BasicBlock *SeqStartBB =
1122 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1123
1124 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1125 "Expected a different CFG");
1126 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1127 ParentBB->getTerminator()->eraseFromParent();
1128
1129 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1130 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1131 BasicBlock *CGEndBB =
1132 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1133 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1134 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1135 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1136 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1137 return Error::success();
1138 };
1139 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1140
1141 // Find outputs from the sequential region to outside users and
1142 // broadcast their values to them.
1143 for (Instruction &I : *SeqStartBB) {
1144 SmallPtrSet<Instruction *, 4> OutsideUsers;
1145 for (User *Usr : I.users()) {
1146 Instruction &UsrI = *cast<Instruction>(Usr);
1147 // Ignore outputs to lifetime intrinsics; code extraction for the merged
1148 // parallel region will fix them.
1149 if (UsrI.isLifetimeStartOrEnd())
1150 continue;
1151
1152 if (UsrI.getParent() != SeqStartBB)
1153 OutsideUsers.insert(&UsrI);
1154 }
1155
1156 if (OutsideUsers.empty())
1157 continue;
1158
1159 // Emit an alloca in the outer region to store the broadcasted
1160 // value.
1161 const DataLayout &DL = M.getDataLayout();
1162 AllocaInst *AllocaI = new AllocaInst(
1163 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1164 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1165
1166 // Emit a store instruction in the sequential BB to update the
1167 // value.
1168 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1169
1170 // Emit a load instruction and replace the use of the output value
1171 // with it.
1172 for (Instruction *UsrI : OutsideUsers) {
1173 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1174 I.getName() + ".seq.output.load",
1175 UsrI->getIterator());
1176 UsrI->replaceUsesOfWith(&I, LoadI);
1177 }
1178 }
1179
1180 OpenMPIRBuilder::LocationDescription Loc(
1181 InsertPointTy(ParentBB, ParentBB->end()), DL);
1182 OpenMPIRBuilder::InsertPointTy SeqAfterIP = cantFail(
1183 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
1184 cantFail(
1185 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));
1186
1187 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1188
1189 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1190 << "\n");
1191 };
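// Conceptually (source-level sketch, not code from this file), merging
//
//   #pragma omp parallel { A; }
//   S;
//   #pragma omp parallel { B; }
//
// yields a single parallel region in which the sequential code S is guarded by
// the master construct and followed by the barrier created above:
//
//   #pragma omp parallel { A;  master { S; }  barrier;  B; }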
1192
1193 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1194 // contained in BB and only separated by instructions that can be
1195 // redundantly executed in parallel. The block BB is split before the first
1196 // call (in MergableCIs) and after the last so the entire region we merge
1197 // into a single parallel region is contained in a single basic block
1198 // without any other instructions. We use the OpenMPIRBuilder to outline
1199 // that block and call the resulting function via __kmpc_fork_call.
1200 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1201 BasicBlock *BB) {
1202 // TODO: Change the interface to allow single CIs expanded, e.g, to
1203 // include an outer loop.
1204 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1205
1206 auto Remark = [&](OptimizationRemark OR) {
1207 OR << "Parallel region merged with parallel region"
1208 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1209 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1210 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1211 if (CI != MergableCIs.back())
1212 OR << ", ";
1213 }
1214 return OR << ".";
1215 };
1216
1217 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1218
1219 Function *OriginalFn = BB->getParent();
1220 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1221 << " parallel regions in " << OriginalFn->getName()
1222 << "\n");
1223
1224 // Isolate the calls to merge in a separate block.
1225 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1226 BasicBlock *AfterBB =
1227 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1228 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1229 "omp.par.merged");
1230
1231 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1232 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1233 BB->getTerminator()->eraseFromParent();
1234
1235 // Create sequential regions for sequential instructions that are
1236 // in-between mergable parallel regions.
1237 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1238 It != End; ++It) {
1239 Instruction *ForkCI = *It;
1240 Instruction *NextForkCI = *(It + 1);
1241
1242 // Continue if there are no in-between instructions.
1243 if (ForkCI->getNextNode() == NextForkCI)
1244 continue;
1245
1246 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1247 NextForkCI->getPrevNode());
1248 }
1249
1250 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1251 DL);
1252 IRBuilder<>::InsertPoint AllocaIP(
1253 &OriginalFn->getEntryBlock(),
1254 OriginalFn->getEntryBlock().getFirstInsertionPt());
1255 // Create the merged parallel region with default proc binding, to
1256 // avoid overriding binding settings, and without explicit cancellation.
1257 OpenMPIRBuilder::InsertPointTy AfterIP =
1258 cantFail(OMPInfoCache.OMPBuilder.createParallel(
1259 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1260 OMP_PROC_BIND_default, /* IsCancellable */ false));
1261 BranchInst::Create(AfterBB, AfterIP.getBlock());
1262
1263 // Perform the actual outlining.
1264 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1265
1266 Function *OutlinedFn = MergableCIs.front()->getCaller();
1267
1268 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1269 // callbacks.
1270 SmallVector<Value *, 8> Args;
1271 for (auto *CI : MergableCIs) {
1272 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1273 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1274 Args.clear();
1275 Args.push_back(OutlinedFn->getArg(0));
1276 Args.push_back(OutlinedFn->getArg(1));
1277 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1278 ++U)
1279 Args.push_back(CI->getArgOperand(U));
1280
1281 CallInst *NewCI =
1282 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1283 if (CI->getDebugLoc())
1284 NewCI->setDebugLoc(CI->getDebugLoc());
1285
1286 // Forward parameter attributes from the callback to the callee.
1287 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1288 ++U)
1289 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1290 NewCI->addParamAttr(
1291 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1292
1293 // Emit an explicit barrier to replace the implicit fork-join barrier.
1294 if (CI != MergableCIs.back()) {
1295 // TODO: Remove barrier if the merged parallel region includes the
1296 // 'nowait' clause.
1297 cantFail(OMPInfoCache.OMPBuilder.createBarrier(
1298 InsertPointTy(NewCI->getParent(),
1299 NewCI->getNextNode()->getIterator()),
1300 OMPD_parallel));
1301 }
1302
1303 CI->eraseFromParent();
1304 }
1305
1306 assert(OutlinedFn != OriginalFn && "Outlining failed");
1307 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1308 CGUpdater.reanalyzeFunction(*OriginalFn);
1309
1310 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1311
1312 return true;
1313 };
1314
1315 // Helper function that identifies sequences of
1316 // __kmpc_fork_call uses in a basic block.
1317 auto DetectPRsCB = [&](Use &U, Function &F) {
1318 CallInst *CI = getCallIfRegularCall(U, &RFI);
1319 BB2PRMap[CI->getParent()].insert(CI);
1320
1321 return false;
1322 };
1323
1324 BB2PRMap.clear();
1325 RFI.foreachUse(SCC, DetectPRsCB);
1326 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1327 // Find mergable parallel regions within a basic block that are
1328 // safe to merge, that is any in-between instructions can safely
1329 // execute in parallel after merging.
1330 // TODO: support merging across basic-blocks.
1331 for (auto &It : BB2PRMap) {
1332 auto &CIs = It.getSecond();
1333 if (CIs.size() < 2)
1334 continue;
1335
1336 BasicBlock *BB = It.getFirst();
1337 SmallVector<CallInst *, 4> MergableCIs;
1338
1339 /// Returns true if the instruction is mergable, false otherwise.
1340 /// A terminator instruction is unmergable by definition since merging
1341 /// works within a BB. Instructions before the mergable region are
1342 /// mergable if they are not calls to OpenMP runtime functions that may
1343 /// set different execution parameters for subsequent parallel regions.
1344 /// Instructions in-between parallel regions are mergable if they are not
1345 /// calls to any non-intrinsic function since that may call a non-mergable
1346 /// OpenMP runtime function.
1347 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1348 // We do not merge across BBs, hence return false (unmergable) if the
1349 // instruction is a terminator.
1350 if (I.isTerminator())
1351 return false;
1352
1353 if (!isa<CallInst>(&I))
1354 return true;
1355
1356 CallInst *CI = cast<CallInst>(&I);
1357 if (IsBeforeMergableRegion) {
1358 Function *CalledFunction = CI->getCalledFunction();
1359 if (!CalledFunction)
1360 return false;
1361 // Return false (unmergable) if the call before the parallel
1362 // region calls an explicit affinity (proc_bind) or number of
1363 // threads (num_threads) compiler-generated function. Those settings
1364 // may be incompatible with following parallel regions.
1365 // TODO: ICV tracking to detect compatibility.
1366 for (const auto &RFI : UnmergableCallsInfo) {
1367 if (CalledFunction == RFI.Declaration)
1368 return false;
1369 }
1370 } else {
1371 // Return false (unmergable) if there is a call instruction
1372 // in-between parallel regions when it is not an intrinsic. It
1373 // may call an unmergable OpenMP runtime function in its callpath.
1374 // TODO: Keep track of possible OpenMP calls in the callpath.
1375 if (!isa<IntrinsicInst>(CI))
1376 return false;
1377 }
1378
1379 return true;
1380 };
1381 // Find maximal number of parallel region CIs that are safe to merge.
1382 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1383 Instruction &I = *It;
1384 ++It;
1385
1386 if (CIs.count(&I)) {
1387 MergableCIs.push_back(cast<CallInst>(&I));
1388 continue;
1389 }
1390
1391 // Continue expanding if the instruction is mergable.
1392 if (IsMergable(I, MergableCIs.empty()))
1393 continue;
1394
1395 // Forward the instruction iterator to skip the next parallel region
1396 // since there is an unmergable instruction which can affect it.
1397 for (; It != End; ++It) {
1398 Instruction &SkipI = *It;
1399 if (CIs.count(&SkipI)) {
1400 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1401 << " due to " << I << "\n");
1402 ++It;
1403 break;
1404 }
1405 }
1406
1407 // Store mergable regions found.
1408 if (MergableCIs.size() > 1) {
1409 MergableCIsVector.push_back(MergableCIs);
1410 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1411 << " parallel regions in block " << BB->getName()
1412 << " of function " << BB->getParent()->getName()
1413 << "\n";);
1414 }
1415
1416 MergableCIs.clear();
1417 }
1418
1419 if (!MergableCIsVector.empty()) {
1420 Changed = true;
1421
1422 for (auto &MergableCIs : MergableCIsVector)
1423 Merge(MergableCIs, BB);
1424 MergableCIsVector.clear();
1425 }
1426 }
1427
1428 if (Changed) {
1429 /// Re-collect use for fork calls, emitted barrier calls, and
1430 /// any emitted master/end_master calls.
1431 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1432 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1433 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1434 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1435 }
1436
1437 return Changed;
1438 }
1439
1440 /// Try to delete parallel regions if possible.
1441 bool deleteParallelRegions() {
1442 const unsigned CallbackCalleeOperand = 2;
1443
1444 OMPInformationCache::RuntimeFunctionInfo &RFI =
1445 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1446
1447 if (!RFI.Declaration)
1448 return false;
1449
1450 bool Changed = false;
1451 auto DeleteCallCB = [&](Use &U, Function &) {
1452 CallInst *CI = getCallIfRegularCall(U);
1453 if (!CI)
1454 return false;
1455 auto *Fn = dyn_cast<Function>(
1456 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1457 if (!Fn)
1458 return false;
1459 if (!Fn->onlyReadsMemory())
1460 return false;
1461 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1462 return false;
1463
1464 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1465 << CI->getCaller()->getName() << "\n");
1466
1467 auto Remark = [&](OptimizationRemark OR) {
1468 return OR << "Removing parallel region with no side-effects.";
1469 };
1470 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1471
1472 CI->eraseFromParent();
1473 Changed = true;
1474 ++NumOpenMPParallelRegionsDeleted;
1475 return true;
1476 };
1477
1478 RFI.foreachUse(SCC, DeleteCallCB);
1479
1480 return Changed;
1481 }
1482
1483 /// Try to eliminate runtime calls by reusing existing ones.
1484 bool deduplicateRuntimeCalls() {
1485 bool Changed = false;
1486
1487 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1488 OMPRTL_omp_get_num_threads,
1489 OMPRTL_omp_in_parallel,
1490 OMPRTL_omp_get_cancellation,
1491 OMPRTL_omp_get_supported_active_levels,
1492 OMPRTL_omp_get_level,
1493 OMPRTL_omp_get_ancestor_thread_num,
1494 OMPRTL_omp_get_team_size,
1495 OMPRTL_omp_get_active_level,
1496 OMPRTL_omp_in_final,
1497 OMPRTL_omp_get_proc_bind,
1498 OMPRTL_omp_get_num_places,
1499 OMPRTL_omp_get_num_procs,
1500 OMPRTL_omp_get_place_num,
1501 OMPRTL_omp_get_partition_num_places,
1502 OMPRTL_omp_get_partition_place_nums};
1503
1504 // Global-tid is handled separately.
1505 SmallSetVector<Value *, 16> GTIdArgs;
1506 collectGlobalThreadIdArguments(GTIdArgs);
1507 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1508 << " global thread ID arguments\n");
1509
1510 for (Function *F : SCC) {
1511 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1512 Changed |= deduplicateRuntimeCalls(
1513 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1514
1515 // __kmpc_global_thread_num is special as we can replace it with an
1516 // argument in enough cases to make it worth trying.
1517 Value *GTIdArg = nullptr;
1518 for (Argument &Arg : F->args())
1519 if (GTIdArgs.count(&Arg)) {
1520 GTIdArg = &Arg;
1521 break;
1522 }
1523 Changed |= deduplicateRuntimeCalls(
1524 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1525 }
1526
1527 return Changed;
1528 }
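// The runtime calls listed above are deduplicated under the assumption that
// their results do not change within a single invocation of the surrounding
// function, so repeated calls can be folded into one. __kmpc_global_thread_num
// is handled specially: if the surrounding function already receives the
// global thread id as an argument (see GTIdArgs), the call is replaced by that
// argument instead.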
1529
1530 /// Tries to remove known runtime symbols that are optional from the module.
1531 bool removeRuntimeSymbols() {
1532 // The RPC client symbol is defined in `libc` and indicates that something
1533 // required an RPC server. If its users were all optimized out then we can
1534 // safely remove it.
1535 // TODO: This should be somewhere more common in the future.
1536 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1537 if (GV->hasNUsesOrMore(1))
1538 return false;
1539
1540 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1541 GV->eraseFromParent();
1542 return true;
1543 }
1544 return false;
1545 }
1546
1547 /// Tries to hide the latency of runtime calls that involve host to
1548 /// device memory transfers by splitting them into their "issue" and "wait"
1549 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1550 /// moved downwards as much as possible. The "issue" issues the memory transfer
1551 /// asynchronously, returning a handle. The "wait" waits in the returned
1552 /// handle for the memory transfer to finish.
1553 bool hideMemTransfersLatency() {
1554 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1555 bool Changed = false;
1556 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1557 auto *RTCall = getCallIfRegularCall(U, &RFI);
1558 if (!RTCall)
1559 return false;
1560
1561 OffloadArray OffloadArrays[3];
1562 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1563 return false;
1564
1565 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1566
1567 // TODO: Check if can be moved upwards.
1568 bool WasSplit = false;
1569 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1570 if (WaitMovementPoint)
1571 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1572
1573 Changed |= WasSplit;
1574 return WasSplit;
1575 };
1576 if (OMPInfoCache.runtimeFnsAvailable(
1577 {OMPRTL___tgt_target_data_begin_mapper_issue,
1578 OMPRTL___tgt_target_data_begin_mapper_wait}))
1579 RFI.foreachUse(SCC, SplitMemTransfers);
1580
1581 return Changed;
1582 }
1583
1584 void analysisGlobalization() {
1585 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1586
1587 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1588 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1589 auto Remark = [&](OptimizationRemarkMissed ORM) {
1590 return ORM
1591 << "Found thread data sharing on the GPU. "
1592 << "Expect degraded performance due to data globalization.";
1593 };
1594 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1595 }
1596
1597 return false;
1598 };
1599
1600 RFI.foreachUse(SCC, CheckGlobalization);
1601 }
1602
1603 /// Maps the values stored in the offload arrays passed as arguments to
1604 /// \p RuntimeCall into the offload arrays in \p OAs.
1605 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1606 MutableArrayRef<OffloadArray> OAs) {
1607 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1608
1609 // A runtime call that involves memory offloading looks something like:
1610 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1611 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1612 // ...)
1613 // So, the idea is to access the allocas that allocate space for these
1614 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1615 // Therefore:
1616 // i8** %offload_baseptrs.
1617 Value *BasePtrsArg =
1618 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1619 // i8** %offload_ptrs.
1620 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1621 // i8** %offload_sizes.
1622 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1623
1624 // Get values stored in **offload_baseptrs.
1625 auto *V = getUnderlyingObject(BasePtrsArg);
1626 if (!isa<AllocaInst>(V))
1627 return false;
1628 auto *BasePtrsArray = cast<AllocaInst>(V);
1629 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1630 return false;
1631
1632 // Get values stored in **offload_ptrs.
1633 V = getUnderlyingObject(PtrsArg);
1634 if (!isa<AllocaInst>(V))
1635 return false;
1636 auto *PtrsArray = cast<AllocaInst>(V);
1637 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1638 return false;
1639
1640 // Get values stored in **offload_sizes.
1641 V = getUnderlyingObject(SizesArg);
1642 // If it's a [constant] global array don't analyze it.
1643 if (isa<GlobalValue>(V))
1644 return isa<Constant>(V);
1645 if (!isa<AllocaInst>(V))
1646 return false;
1647
1648 auto *SizesArray = cast<AllocaInst>(V);
1649 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1650 return false;
1651
1652 return true;
1653 }
1654
1655 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1656 /// For now this is a way to test that the function getValuesInOffloadArrays
1657 /// is working properly.
1658 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1659 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1660 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1661
1662 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1663 std::string ValuesStr;
1664 raw_string_ostream Printer(ValuesStr);
1665 std::string Separator = " --- ";
1666
1667 for (auto *BP : OAs[0].StoredValues) {
1668 BP->print(Printer);
1669 Printer << Separator;
1670 }
1671 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1672 ValuesStr.clear();
1673
1674 for (auto *P : OAs[1].StoredValues) {
1675 P->print(Printer);
1676 Printer << Separator;
1677 }
1678 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1679 ValuesStr.clear();
1680
1681 for (auto *S : OAs[2].StoredValues) {
1682 S->print(Printer);
1683 Printer << Separator;
1684 }
1685 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1686 }
1687
1688 /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1689 /// be moved. Returns nullptr if the movement is not possible or not worth it.
1690 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1691 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1692 // Make it traverse the CFG.
1693
1694 Instruction *CurrentI = &RuntimeCall;
1695 bool IsWorthIt = false;
1696 while ((CurrentI = CurrentI->getNextNode())) {
1697
1698 // TODO: Once we detect the regions to be offloaded we should use the
1699 // alias analysis manager to check if CurrentI may modify one of
1700 // the offloaded regions.
1701 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1702 if (IsWorthIt)
1703 return CurrentI;
1704
1705 return nullptr;
1706 }
1707
1708 // FIXME: For now, moving the call over anything without side effects
1709 // is considered worth it.
1710 IsWorthIt = true;
1711 }
1712
1713 // Return end of BasicBlock.
1714 return RuntimeCall.getParent()->getTerminator();
1715 }
1716
1717 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
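/// Illustrative sketch (operand lists abbreviated, not the exact signatures):
///   call void @__tgt_target_data_begin_mapper(%ident, %device_id, ...)
/// becomes the asynchronous pair
///   call void @__tgt_target_data_begin_mapper_issue(%ident, %device_id, ..., ptr %handle)
///   ; side-effect free code the transfer can overlap with
///   call void @__tgt_target_data_begin_mapper_wait(%device_id, ptr %handle)
/// with the "wait" placed at \p WaitMovementPoint.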
1718 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1719 Instruction &WaitMovementPoint) {
1720 // Create stack allocated handle (__tgt_async_info) at the beginning of the
1721 // function. It stores information about the async transfer, allowing us to
1722 // wait on it later.
1723 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1724 Function *F = RuntimeCall.getCaller();
1725 BasicBlock &Entry = F->getEntryBlock();
1726 IRBuilder.Builder.SetInsertPoint(&Entry,
1727 Entry.getFirstNonPHIOrDbgOrAlloca());
1728 Value *Handle = IRBuilder.Builder.CreateAlloca(
1729 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1730 Handle =
1731 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1732
1733 // Add "issue" runtime call declaration:
1734 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1735 // i8**, i8**, i64*, i64*)
1736 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1737 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1738
1739 // Change RuntimeCall call site for its asynchronous version.
1740 SmallVector<Value *, 16> Args;
1741 for (auto &Arg : RuntimeCall.args())
1742 Args.push_back(Arg.get());
1743 Args.push_back(Handle);
1744
1745 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1746 RuntimeCall.getIterator());
1747 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1748 RuntimeCall.eraseFromParent();
1749
1750 // Add "wait" runtime call declaration:
1751 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1752 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1753 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1754
1755 Value *WaitParams[2] = {
1756 IssueCallsite->getArgOperand(
1757 OffloadArray::DeviceIDArgNum), // device_id.
1758 Handle // handle to wait on.
1759 };
1760 CallInst *WaitCallsite = CallInst::Create(
1761 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1762 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1763
1764 return true;
1765 }
1766
1767 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1768 bool GlobalOnly, bool &SingleChoice) {
1769 if (CurrentIdent == NextIdent)
1770 return CurrentIdent;
1771
1772 // TODO: Figure out how to actually combine multiple debug locations. For
1773 // now we just keep an existing one if there is a single choice.
1774 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1775 SingleChoice = !CurrentIdent;
1776 return NextIdent;
1777 }
1778 return nullptr;
1779 }
1780
1781 /// Return a `struct ident_t*` value that represents the ones used in the
1782 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1783 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1784 /// return value we create one from scratch. We also do not yet combine
1785 /// information, e.g., the source locations, see combinedIdentStruct.
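/// For example (illustrative): if every call to \p RFI in \p F passes the same
/// global `ident_t` constant as its first argument, that constant is reused;
/// otherwise a default source-location string and ident are created via
/// getOrCreateDefaultSrcLocStr / getOrCreateIdent.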
1786 Value *
1787 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1788 Function &F, bool GlobalOnly) {
1789 bool SingleChoice = true;
1790 Value *Ident = nullptr;
1791 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1792 CallInst *CI = getCallIfRegularCall(U, &RFI);
1793 if (!CI || &F != &Caller)
1794 return false;
1795 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1796 /* GlobalOnly */ true, SingleChoice);
1797 return false;
1798 };
1799 RFI.foreachUse(SCC, CombineIdentStruct);
1800
1801 if (!Ident || !SingleChoice) {
1802 // The IRBuilder uses the insertion block to get to the module; this is
1803 // unfortunate but we work around it for now.
1804 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1805 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1806 &F.getEntryBlock(), F.getEntryBlock().begin()));
1807 // Create a fallback location if none was found.
1808 // TODO: Use the debug locations of the calls instead.
1809 uint32_t SrcLocStrSize;
1810 Constant *Loc =
1811 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1812 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1813 }
1814 return Ident;
1815 }
1816
1817 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1818 /// \p ReplVal if given.
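/// For example (illustrative IR sketch): two calls in \p F such as
///   %gtid1 = call i32 @__kmpc_global_thread_num(ptr @ident)
///   ...
///   %gtid2 = call i32 @__kmpc_global_thread_num(ptr @ident)
/// are collapsed so that users of %gtid2 reuse %gtid1 and the second call is
/// erased.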
1819 bool deduplicateRuntimeCalls(Function &F,
1820 OMPInformationCache::RuntimeFunctionInfo &RFI,
1821 Value *ReplVal = nullptr) {
1822 auto *UV = RFI.getUseVector(F);
1823 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1824 return false;
1825
1826 LLVM_DEBUG(
1827 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1828 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1829
1830 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1831 cast<Argument>(ReplVal)->getParent() == &F)) &&
1832 "Unexpected replacement value!");
1833
1834 // TODO: Use dominance to find a good position instead.
1835 auto CanBeMoved = [this](CallBase &CB) {
1836 unsigned NumArgs = CB.arg_size();
1837 if (NumArgs == 0)
1838 return true;
1839 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1840 return false;
1841 for (unsigned U = 1; U < NumArgs; ++U)
1842 if (isa<Instruction>(CB.getArgOperand(U)))
1843 return false;
1844 return true;
1845 };
1846
1847 if (!ReplVal) {
1848 auto *DT =
1849 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1850 if (!DT)
1851 return false;
1852 Instruction *IP = nullptr;
1853 for (Use *U : *UV) {
1854 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1855 if (IP)
1856 IP = DT->findNearestCommonDominator(IP, CI);
1857 else
1858 IP = CI;
1859 if (!CanBeMoved(*CI))
1860 continue;
1861 if (!ReplVal)
1862 ReplVal = CI;
1863 }
1864 }
1865 if (!ReplVal)
1866 return false;
1867 assert(IP && "Expected insertion point!");
1868 cast<Instruction>(ReplVal)->moveBefore(IP->getIterator());
1869 }
1870
1871 // If we use a call as a replacement value we need to make sure the ident is
1872 // valid at the new location. For now we just pick a global one, either
1873 // existing and used by one of the calls, or created from scratch.
1874 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1875 if (!CI->arg_empty() &&
1876 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1877 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1878 /* GlobalOnly */ true);
1879 CI->setArgOperand(0, Ident);
1880 }
1881 }
1882
1883 bool Changed = false;
1884 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1885 CallInst *CI = getCallIfRegularCall(U, &RFI);
1886 if (!CI || CI == ReplVal || &F != &Caller)
1887 return false;
1888 assert(CI->getCaller() == &F && "Unexpected call!");
1889
1890 auto Remark = [&](OptimizationRemark OR) {
1891 return OR << "OpenMP runtime call "
1892 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1893 };
1894 if (CI->getDebugLoc())
1895 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1896 else
1897 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1898
1899 CI->replaceAllUsesWith(ReplVal);
1900 CI->eraseFromParent();
1901 ++NumOpenMPRuntimeCallsDeduplicated;
1902 Changed = true;
1903 return true;
1904 };
1905 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1906
1907 return Changed;
1908 }
1909
1910 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1911 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1912 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1913 // initialization. We could define an AbstractAttribute instead and
1914 // run the Attributor here once it can be run as an SCC pass.
1915
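// For example (illustrative): if the result of
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
// is passed as the N-th argument at every call site of an internal function,
// that function's N-th argument is recorded as a GTId as well, and its users
// are then inspected transitively.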
1916 // Helper to check the argument \p ArgNo at all call sites of \p F for
1917 // a GTId.
1918 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1919 if (!F.hasLocalLinkage())
1920 return false;
1921 for (Use &U : F.uses()) {
1922 if (CallInst *CI = getCallIfRegularCall(U)) {
1923 Value *ArgOp = CI->getArgOperand(ArgNo);
1924 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1925 getCallIfRegularCall(
1926 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1927 continue;
1928 }
1929 return false;
1930 }
1931 return true;
1932 };
1933
1934 // Helper to identify uses of a GTId as GTId arguments.
1935 auto AddUserArgs = [&](Value &GTId) {
1936 for (Use &U : GTId.uses())
1937 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1938 if (CI->isArgOperand(&U))
1939 if (Function *Callee = CI->getCalledFunction())
1940 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1941 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1942 };
1943
1944 // The argument users of __kmpc_global_thread_num calls are GTIds.
1945 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1946 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1947
1948 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1949 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1950 AddUserArgs(*CI);
1951 return false;
1952 });
1953
1954 // Transitively search for more arguments by looking at the users of the
1955 // ones we know already. During the search the GTIdArgs vector is extended
1956 // so we cannot cache the size nor can we use a range based for.
1957 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1958 AddUserArgs(*GTIdArgs[U]);
1959 }
1960
1961 /// Kernel (=GPU) optimizations and utility functions
1962 ///
1963 ///{{
1964
1965 /// Cache to remember the unique kernel for a function.
1966 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1967
1968 /// Find the unique kernel that will execute \p F, if any.
1969 Kernel getUniqueKernelFor(Function &F);
1970
1971 /// Find the unique kernel that will execute \p I, if any.
1972 Kernel getUniqueKernelFor(Instruction &I) {
1973 return getUniqueKernelFor(*I.getFunction());
1974 }
1975
1976 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1977 /// the cases where we can avoid taking the address of a function.
1978 bool rewriteDeviceCodeStateMachine();
1979
1980 ///
1981 ///}}
1982
1983 /// Emit a remark generically
1984 ///
1985 /// This template function can be used to generically emit a remark. The
1986 /// RemarkKind should be one of the following:
1987 /// - OptimizationRemark to indicate a successful optimization attempt
1988 /// - OptimizationRemarkMissed to report a failed optimization attempt
1989 /// - OptimizationRemarkAnalysis to provide additional information about an
1990 /// optimization attempt
1991 ///
1992 /// The remark is built using a callback function provided by the caller that
1993 /// takes a RemarkKind as input and returns a RemarkKind.
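/// Typical usage in this file (illustrative; the remark name is one of the
/// documented "OMPxxx" identifiers):
///   auto Remark = [&](OptimizationRemark OR) {
///     return OR << "OpenMP runtime call "
///               << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
///   };
///   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);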
1994 template <typename RemarkKind, typename RemarkCallBack>
1995 void emitRemark(Instruction *I, StringRef RemarkName,
1996 RemarkCallBack &&RemarkCB) const {
1997 Function *F = I->getParent()->getParent();
1998 auto &ORE = OREGetter(F);
1999
2000 if (RemarkName.starts_with("OMP"))
2001 ORE.emit([&]() {
2002 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2003 << " [" << RemarkName << "]";
2004 });
2005 else
2006 ORE.emit(
2007 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2008 }
2009
2010 /// Emit a remark on a function.
2011 template <typename RemarkKind, typename RemarkCallBack>
2012 void emitRemark(Function *F, StringRef RemarkName,
2013 RemarkCallBack &&RemarkCB) const {
2014 auto &ORE = OREGetter(F);
2015
2016 if (RemarkName.starts_with("OMP"))
2017 ORE.emit([&]() {
2018 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2019 << " [" << RemarkName << "]";
2020 });
2021 else
2022 ORE.emit(
2023 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2024 }
2025
2026 /// The underlying module.
2027 Module &M;
2028
2029 /// The SCC we are operating on.
2030 SmallVectorImpl<Function *> &SCC;
2031
2032 /// Callback to update the call graph, the first argument is a removed call,
2033 /// the second an optional replacement call.
2034 CallGraphUpdater &CGUpdater;
2035
2036 /// Callback to get an OptimizationRemarkEmitter from a Function *
2037 OptimizationRemarkGetter OREGetter;
2038
2039 /// OpenMP-specific information cache. Also used for Attributor runs.
2040 OMPInformationCache &OMPInfoCache;
2041
2042 /// Attributor instance.
2043 Attributor &A;
2044
2045 /// Helper function to run Attributor on SCC.
2046 bool runAttributor(bool IsModulePass) {
2047 if (SCC.empty())
2048 return false;
2049
2050 registerAAs(IsModulePass);
2051
2052 ChangeStatus Changed = A.run();
2053
2054 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2055 << " functions, result: " << Changed << ".\n");
2056
2057 if (Changed == ChangeStatus::CHANGED)
2058 OMPInfoCache.invalidateAnalyses();
2059
2060 return Changed == ChangeStatus::CHANGED;
2061 }
2062
2063 void registerFoldRuntimeCall(RuntimeFunction RF);
2064
2065 /// Populate the Attributor with abstract attribute opportunities in the
2066 /// functions.
2067 void registerAAs(bool IsModulePass);
2068
2069public:
2070 /// Callback to register AAs for live functions, including internal functions
2071 /// marked live during the traversal.
2072 static void registerAAsForFunction(Attributor &A, const Function &F);
2073};
2074
2075Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2076 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2077 !OMPInfoCache.CGSCC->contains(&F))
2078 return nullptr;
2079
2080 // Use a scope to keep the lifetime of the CachedKernel short.
2081 {
2082 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2083 if (CachedKernel)
2084 return *CachedKernel;
2085
2086 // TODO: We should use an AA to create an (optimistic and callback
2087 // call-aware) call graph. For now we stick to simple patterns that
2088 // are less powerful, basically the worst fixpoint.
2089 if (isOpenMPKernel(F)) {
2090 CachedKernel = Kernel(&F);
2091 return *CachedKernel;
2092 }
2093
2094 CachedKernel = nullptr;
2095 if (!F.hasLocalLinkage()) {
2096
2097 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2098 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2099 return ORA << "Potentially unknown OpenMP target region caller.";
2100 };
2101 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2102
2103 return nullptr;
2104 }
2105 }
2106
2107 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2108 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2109 // Allow use in equality comparisons.
2110 if (Cmp->isEquality())
2111 return getUniqueKernelFor(*Cmp);
2112 return nullptr;
2113 }
2114 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2115 // Allow direct calls.
2116 if (CB->isCallee(&U))
2117 return getUniqueKernelFor(*CB);
2118
2119 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2120 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2121 // Allow the use in __kmpc_parallel_51 calls.
2122 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2123 return getUniqueKernelFor(*CB);
2124 return nullptr;
2125 }
2126 // Disallow every other use.
2127 return nullptr;
2128 };
2129
2130 // TODO: In the future we want to track more than just a unique kernel.
2131 SmallPtrSet<Kernel, 2> PotentialKernels;
2132 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2133 PotentialKernels.insert(GetUniqueKernelForUse(U));
2134 });
2135
2136 Kernel K = nullptr;
2137 if (PotentialKernels.size() == 1)
2138 K = *PotentialKernels.begin();
2139
2140 // Cache the result.
2141 UniqueKernelMap[&F] = K;
2142
2143 return K;
2144}
2145
2146bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2147 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2148 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2149
2150 bool Changed = false;
2151 if (!KernelParallelRFI)
2152 return Changed;
2153
2154 // If we have disabled state machine changes, exit
2155 if (DisableOpenMPOptStateMachineRewrite)
2156 return Changed;
2157
2158 for (Function *F : SCC) {
2159
2160 // Check if the function is a use in a __kmpc_parallel_51 call at
2161 // all.
2162 bool UnknownUse = false;
2163 bool KernelParallelUse = false;
2164 unsigned NumDirectCalls = 0;
2165
2166 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2167 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2168 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2169 if (CB->isCallee(&U)) {
2170 ++NumDirectCalls;
2171 return;
2172 }
2173
2174 if (isa<ICmpInst>(U.getUser())) {
2175 ToBeReplacedStateMachineUses.push_back(&U);
2176 return;
2177 }
2178
2179 // Find wrapper functions that represent parallel kernels.
2180 CallInst *CI =
2181 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2182 const unsigned int WrapperFunctionArgNo = 6;
2183 if (!KernelParallelUse && CI &&
2184 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2185 KernelParallelUse = true;
2186 ToBeReplacedStateMachineUses.push_back(&U);
2187 return;
2188 }
2189 UnknownUse = true;
2190 });
2191
2192 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2193 // use.
2194 if (!KernelParallelUse)
2195 continue;
2196
2197 // If this ever hits, we should investigate.
2198 // TODO: Checking the number of uses is not a necessary restriction and
2199 // should be lifted.
2200 if (UnknownUse || NumDirectCalls != 1 ||
2201 ToBeReplacedStateMachineUses.size() > 2) {
2202 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2203 return ORA << "Parallel region is used in "
2204 << (UnknownUse ? "unknown" : "unexpected")
2205 << " ways. Will not attempt to rewrite the state machine.";
2206 };
2207 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2208 continue;
2209 }
2210
2211 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2212 // up if the function is not called from a unique kernel.
2213 Kernel K = getUniqueKernelFor(*F);
2214 if (!K) {
2215 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2216 return ORA << "Parallel region is not called from a unique kernel. "
2217 "Will not attempt to rewrite the state machine.";
2218 };
2219 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2220 continue;
2221 }
2222
2223 // We now know F is a parallel body function called only from the kernel K.
2224 // We also identified the state machine uses in which we replace the
2225 // function pointer by a new global symbol for identification purposes. This
2226 // ensures only direct calls to the function are left.
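// Illustrative sketch (function names hypothetical): a state machine check
//   %is.this.par.body = icmp eq ptr %work_fn, @__omp_outlined__wrapper
// and the wrapper argument of the matching __kmpc_parallel_51 call are both
// rewritten to refer to the private global @<F>.ID created below, so the
// original function symbol is only reached through direct calls.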
2227
2228 Module &M = *F->getParent();
2229 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2230
2231 auto *ID = new GlobalVariable(
2232 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2233 UndefValue::get(Int8Ty), F->getName() + ".ID");
2234
2235 for (Use *U : ToBeReplacedStateMachineUses)
2236 A.changeUseAfterManifest(*U, *ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2237 ID, U->get()->getType()));
2238
2239 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2240
2241 Changed = true;
2242 }
2243
2244 return Changed;
2245}
2246
2247/// Abstract Attribute for tracking ICV values.
2248struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2249 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2250 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2251
2252 /// Returns true if value is assumed to be tracked.
2253 bool isAssumedTracked() const { return getAssumed(); }
2254
2255 /// Returns true if value is known to be tracked.
2256 bool isKnownTracked() const { return getAssumed(); }
2257
2258 /// Create an abstract attribute view for the position \p IRP.
2259 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2260
2261 /// Return the value with which \p I can be replaced for specific \p ICV.
2262 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2263 const Instruction *I,
2264 Attributor &A) const {
2265 return std::nullopt;
2266 }
2267
2268 /// Return an assumed unique ICV value if a single candidate is found. If
2269 /// there cannot be one, return a nullptr. If it is not clear yet, return
2270 /// std::nullopt.
2271 virtual std::optional<Value *>
2272 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2273
2274 // Currently only nthreads is being tracked.
2275 // This array will only grow over time.
2276 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
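// For example (illustrative): after the nthreads setter
//   call void @omp_set_num_threads(i32 4)
// the value 4 is tracked for ICV_nthreads, and a later getter call may be
// folded to 4 as long as no intervening call can change the ICV.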
2277
2278 /// See AbstractAttribute::getName()
2279 StringRef getName() const override { return "AAICVTracker"; }
2280
2281 /// See AbstractAttribute::getIdAddr()
2282 const char *getIdAddr() const override { return &ID; }
2283
2284 /// This function should return true if the type of the \p AA is AAICVTracker
2285 static bool classof(const AbstractAttribute *AA) {
2286 return (AA->getIdAddr() == &ID);
2287 }
2288
2289 static const char ID;
2290};
2291
2292struct AAICVTrackerFunction : public AAICVTracker {
2293 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2294 : AAICVTracker(IRP, A) {}
2295
2296 // FIXME: come up with better string.
2297 const std::string getAsStr(Attributor *) const override {
2298 return "ICVTrackerFunction";
2299 }
2300
2301 // FIXME: come up with some stats.
2302 void trackStatistics() const override {}
2303
2304 /// We don't manifest anything for this AA.
2305 ChangeStatus manifest(Attributor &A) override {
2306 return ChangeStatus::UNCHANGED;
2307 }
2308
2309 // Map of ICVs to their values at specific program points.
2310 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2311 InternalControlVar::ICV___last>
2312 ICVReplacementValuesMap;
2313
2314 ChangeStatus updateImpl(Attributor &A) override {
2315 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2316
2317 Function *F = getAnchorScope();
2318
2319 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2320
2321 for (InternalControlVar ICV : TrackableICVs) {
2322 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2323
2324 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2325 auto TrackValues = [&](Use &U, Function &) {
2326 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2327 if (!CI)
2328 return false;
2329
2330 // FIXME: Handle setters with more than one argument.
2331 /// Track new value.
2332 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2333 HasChanged = ChangeStatus::CHANGED;
2334
2335 return false;
2336 };
2337
2338 auto CallCheck = [&](Instruction &I) {
2339 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2340 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2341 HasChanged = ChangeStatus::CHANGED;
2342
2343 return true;
2344 };
2345
2346 // Track all changes of an ICV.
2347 SetterRFI.foreachUse(TrackValues, F);
2348
2349 bool UsedAssumedInformation = false;
2350 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2351 UsedAssumedInformation,
2352 /* CheckBBLivenessOnly */ true);
2353
2354 /// TODO: Figure out a way to avoid adding entry in
2355 /// ICVReplacementValuesMap
2356 Instruction *Entry = &F->getEntryBlock().front();
2357 if (HasChanged == ChangeStatus::CHANGED)
2358 ValuesMap.try_emplace(Entry);
2359 }
2360
2361 return HasChanged;
2362 }
2363
2364 /// Helper to check if \p I is a call and get the value for it if it is
2365 /// unique.
2366 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2367 InternalControlVar &ICV) const {
2368
2369 const auto *CB = dyn_cast<CallBase>(&I);
2370 if (!CB || CB->hasFnAttr("no_openmp") ||
2371 CB->hasFnAttr("no_openmp_routines") ||
2372 CB->hasFnAttr("no_openmp_constructs"))
2373 return std::nullopt;
2374
2375 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2376 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2377 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2378 Function *CalledFunction = CB->getCalledFunction();
2379
2380 // Indirect call, assume ICV changes.
2381 if (CalledFunction == nullptr)
2382 return nullptr;
2383 if (CalledFunction == GetterRFI.Declaration)
2384 return std::nullopt;
2385 if (CalledFunction == SetterRFI.Declaration) {
2386 if (ICVReplacementValuesMap[ICV].count(&I))
2387 return ICVReplacementValuesMap[ICV].lookup(&I);
2388
2389 return nullptr;
2390 }
2391
2392 // Since we don't know, assume it changes the ICV.
2393 if (CalledFunction->isDeclaration())
2394 return nullptr;
2395
2396 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2397 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2398
2399 if (ICVTrackingAA->isAssumedTracked()) {
2400 std::optional<Value *> URV =
2401 ICVTrackingAA->getUniqueReplacementValue(ICV);
2402 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2403 OMPInfoCache)))
2404 return URV;
2405 }
2406
2407 // If we don't know, assume it changes.
2408 return nullptr;
2409 }
2410
2411 // We don't check unique value for a function, so return std::nullopt.
2412 std::optional<Value *>
2413 getUniqueReplacementValue(InternalControlVar ICV) const override {
2414 return std::nullopt;
2415 }
2416
2417 /// Return the value with which \p I can be replaced for specific \p ICV.
2418 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2419 const Instruction *I,
2420 Attributor &A) const override {
2421 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2422 if (ValuesMap.count(I))
2423 return ValuesMap.lookup(I);
2424
2425 SmallVector<const Instruction *, 16> Worklist;
2426 SmallPtrSet<const Instruction *, 16> Visited;
2427 Worklist.push_back(I);
2428
2429 std::optional<Value *> ReplVal;
2430
2431 while (!Worklist.empty()) {
2432 const Instruction *CurrInst = Worklist.pop_back_val();
2433 if (!Visited.insert(CurrInst).second)
2434 continue;
2435
2436 const BasicBlock *CurrBB = CurrInst->getParent();
2437
2438 // Go up and look for all potential setters/calls that might change the
2439 // ICV.
2440 while ((CurrInst = CurrInst->getPrevNode())) {
2441 if (ValuesMap.count(CurrInst)) {
2442 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2443 // Unknown value, track new.
2444 if (!ReplVal) {
2445 ReplVal = NewReplVal;
2446 break;
2447 }
2448
2449 // If we found a new value, we can't know the icv value anymore.
2450 if (NewReplVal)
2451 if (ReplVal != NewReplVal)
2452 return nullptr;
2453
2454 break;
2455 }
2456
2457 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2458 if (!NewReplVal)
2459 continue;
2460
2461 // Unknown value, track new.
2462 if (!ReplVal) {
2463 ReplVal = NewReplVal;
2464 break;
2465 }
2466
2467 // We found a new value; if it differs from the tracked one we
2468 // can't know the ICV value anymore.
2469 if (ReplVal != NewReplVal)
2470 return nullptr;
2471 }
2472
2473 // If we are in the same BB and we have a value, we are done.
2474 if (CurrBB == I->getParent() && ReplVal)
2475 return ReplVal;
2476
2477 // Go through all predecessors and add terminators for analysis.
2478 for (const BasicBlock *Pred : predecessors(CurrBB))
2479 if (const Instruction *Terminator = Pred->getTerminator())
2480 Worklist.push_back(Terminator);
2481 }
2482
2483 return ReplVal;
2484 }
2485};
2486
2487struct AAICVTrackerFunctionReturned : AAICVTracker {
2488 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2489 : AAICVTracker(IRP, A) {}
2490
2491 // FIXME: come up with better string.
2492 const std::string getAsStr(Attributor *) const override {
2493 return "ICVTrackerFunctionReturned";
2494 }
2495
2496 // FIXME: come up with some stats.
2497 void trackStatistics() const override {}
2498
2499 /// We don't manifest anything for this AA.
2500 ChangeStatus manifest(Attributor &A) override {
2501 return ChangeStatus::UNCHANGED;
2502 }
2503
2504 // Map of ICVs to their values at specific program points.
2505 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2506 InternalControlVar::ICV___last>
2507 ICVReplacementValuesMap;
2508
2509 /// Return the value with which \p I can be replaced for specific \p ICV.
2510 std::optional<Value *>
2511 getUniqueReplacementValue(InternalControlVar ICV) const override {
2512 return ICVReplacementValuesMap[ICV];
2513 }
2514
2515 ChangeStatus updateImpl(Attributor &A) override {
2516 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2517 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2518 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2519
2520 if (!ICVTrackingAA->isAssumedTracked())
2521 return indicatePessimisticFixpoint();
2522
2523 for (InternalControlVar ICV : TrackableICVs) {
2524 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2525 std::optional<Value *> UniqueICVValue;
2526
2527 auto CheckReturnInst = [&](Instruction &I) {
2528 std::optional<Value *> NewReplVal =
2529 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2530
2531 // If we found a second ICV value there is no unique returned value.
2532 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2533 return false;
2534
2535 UniqueICVValue = NewReplVal;
2536
2537 return true;
2538 };
2539
2540 bool UsedAssumedInformation = false;
2541 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2542 UsedAssumedInformation,
2543 /* CheckBBLivenessOnly */ true))
2544 UniqueICVValue = nullptr;
2545
2546 if (UniqueICVValue == ReplVal)
2547 continue;
2548
2549 ReplVal = UniqueICVValue;
2550 Changed = ChangeStatus::CHANGED;
2551 }
2552
2553 return Changed;
2554 }
2555};
2556
2557struct AAICVTrackerCallSite : AAICVTracker {
2558 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2559 : AAICVTracker(IRP, A) {}
2560
2561 void initialize(Attributor &A) override {
2562 assert(getAnchorScope() && "Expected anchor function");
2563
2564 // We only initialize this AA for getters, so we need to know which ICV it
2565 // gets.
2566 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2567 for (InternalControlVar ICV : TrackableICVs) {
2568 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2569 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2570 if (Getter.Declaration == getAssociatedFunction()) {
2571 AssociatedICV = ICVInfo.Kind;
2572 return;
2573 }
2574 }
2575
2576 /// Unknown ICV.
2577 indicatePessimisticFixpoint();
2578 }
2579
2580 ChangeStatus manifest(Attributor &A) override {
2581 if (!ReplVal || !*ReplVal)
2582 return ChangeStatus::UNCHANGED;
2583
2584 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2585 A.deleteAfterManifest(*getCtxI());
2586
2587 return ChangeStatus::CHANGED;
2588 }
2589
2590 // FIXME: come up with better string.
2591 const std::string getAsStr(Attributor *) const override {
2592 return "ICVTrackerCallSite";
2593 }
2594
2595 // FIXME: come up with some stats.
2596 void trackStatistics() const override {}
2597
2598 InternalControlVar AssociatedICV;
2599 std::optional<Value *> ReplVal;
2600
2601 ChangeStatus updateImpl(Attributor &A) override {
2602 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2603 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2604
2605 // We don't have any information, so we assume it changes the ICV.
2606 if (!ICVTrackingAA->isAssumedTracked())
2607 return indicatePessimisticFixpoint();
2608
2609 std::optional<Value *> NewReplVal =
2610 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2611
2612 if (ReplVal == NewReplVal)
2613 return ChangeStatus::UNCHANGED;
2614
2615 ReplVal = NewReplVal;
2616 return ChangeStatus::CHANGED;
2617 }
2618
2619 // Return the value with which the associated value can be replaced for a
2620 // specific \p ICV.
2621 std::optional<Value *>
2622 getUniqueReplacementValue(InternalControlVar ICV) const override {
2623 return ReplVal;
2624 }
2625};
2626
2627struct AAICVTrackerCallSiteReturned : AAICVTracker {
2628 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2629 : AAICVTracker(IRP, A) {}
2630
2631 // FIXME: come up with better string.
2632 const std::string getAsStr(Attributor *) const override {
2633 return "ICVTrackerCallSiteReturned";
2634 }
2635
2636 // FIXME: come up with some stats.
2637 void trackStatistics() const override {}
2638
2639 /// We don't manifest anything for this AA.
2640 ChangeStatus manifest(Attributor &A) override {
2641 return ChangeStatus::UNCHANGED;
2642 }
2643
2644 // Map of ICVs to their values at specific program points.
2645 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2646 InternalControlVar::ICV___last>
2647 ICVReplacementValuesMap;
2648
2649 /// Return the value with which the associated value can be replaced for a
2650 /// specific \p ICV.
2651 std::optional<Value *>
2652 getUniqueReplacementValue(InternalControlVar ICV) const override {
2653 return ICVReplacementValuesMap[ICV];
2654 }
2655
2656 ChangeStatus updateImpl(Attributor &A) override {
2657 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2658 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2659 *this, IRPosition::returned(*getAssociatedFunction()),
2660 DepClassTy::REQUIRED);
2661
2662 // We don't have any information, so we assume it changes the ICV.
2663 if (!ICVTrackingAA->isAssumedTracked())
2664 return indicatePessimisticFixpoint();
2665
2666 for (InternalControlVar ICV : TrackableICVs) {
2667 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2668 std::optional<Value *> NewReplVal =
2669 ICVTrackingAA->getUniqueReplacementValue(ICV);
2670
2671 if (ReplVal == NewReplVal)
2672 continue;
2673
2674 ReplVal = NewReplVal;
2675 Changed = ChangeStatus::CHANGED;
2676 }
2677 return Changed;
2678 }
2679};
2680
2681/// Determines if \p BB exits the function unconditionally itself or reaches a
2682/// block that does so through only unique successors.
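/// For example (illustrative): in a chain A -> B where B is A's unique
/// successor and B ends in a return, both A and B satisfy this predicate; a
/// block with two distinct successors does not.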
2683static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2684 if (succ_empty(BB))
2685 return true;
2686 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2687 if (!Successor)
2688 return false;
2689 return hasFunctionEndAsUniqueSuccessor(Successor);
2690}
2691
2692struct AAExecutionDomainFunction : public AAExecutionDomain {
2693 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2694 : AAExecutionDomain(IRP, A) {}
2695
2696 ~AAExecutionDomainFunction() { delete RPOT; }
2697
2698 void initialize(Attributor &A) override {
2699 Function *F = getAnchorScope();
2700 assert(F && "Expected anchor function");
2701 RPOT = new ReversePostOrderTraversal<Function *>(F);
2702 }
2703
2704 const std::string getAsStr(Attributor *) const override {
2705 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2706 for (auto &It : BEDMap) {
2707 if (!It.getFirst())
2708 continue;
2709 TotalBlocks++;
2710 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2711 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2712 It.getSecond().IsReachingAlignedBarrierOnly;
2713 }
2714 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2715 std::to_string(AlignedBlocks) + " of " +
2716 std::to_string(TotalBlocks) +
2717 " executed by initial thread / aligned";
2718 }
2719
2720 /// See AbstractAttribute::trackStatistics().
2721 void trackStatistics() const override {}
2722
2723 ChangeStatus manifest(Attributor &A) override {
2724 LLVM_DEBUG({
2725 for (const BasicBlock &BB : *getAnchorScope()) {
2726 if (!isExecutedByInitialThreadOnly(BB))
2727 continue;
2728 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2729 << BB.getName() << " is executed by a single thread.\n";
2730 }
2731 });
2732
2733 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2734
2735 if (DisableOpenMPOptBarrierElimination)
2736 return Changed;
2737
2738 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2739 auto HandleAlignedBarrier = [&](CallBase *CB) {
2740 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2741 if (!ED.IsReachedFromAlignedBarrierOnly ||
2742 ED.EncounteredNonLocalSideEffect)
2743 return;
2744 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2745 return;
2746
2747 // We can remove this barrier, if it is one, or aligned barriers reaching
2748 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2749 // end should only be removed if the kernel end is their unique successor;
2750 // otherwise, they may have side-effects that aren't accounted for in the
2751 // kernel end in their other successors. If those barriers have other
2752 // barriers reaching them, those can be transitively removed as well as
2753 // long as the kernel end is also their unique successor.
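// Illustrative sketch: given two aligned barriers with no non-local side
// effects between them,
//   call void @barrier()   ; aligned
//   ...                    ; no non-local side effects
//   call void @barrier()   ; aligned, reached only from the first barrier
// the second barrier is redundant and is deleted here.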
2754 if (CB) {
2755 DeletedBarriers.insert(CB);
2756 A.deleteAfterManifest(*CB);
2757 ++NumBarriersEliminated;
2758 Changed = ChangeStatus::CHANGED;
2759 } else if (!ED.AlignedBarriers.empty()) {
2760 Changed = ChangeStatus::CHANGED;
2761 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2762 ED.AlignedBarriers.end());
2763 SmallSetVector<CallBase *, 16> Visited;
2764 while (!Worklist.empty()) {
2765 CallBase *LastCB = Worklist.pop_back_val();
2766 if (!Visited.insert(LastCB))
2767 continue;
2768 if (LastCB->getFunction() != getAnchorScope())
2769 continue;
2770 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2771 continue;
2772 if (!DeletedBarriers.count(LastCB)) {
2773 ++NumBarriersEliminated;
2774 A.deleteAfterManifest(*LastCB);
2775 continue;
2776 }
2777 // The final aligned barrier (LastCB) reaching the kernel end was
2778 // removed already. This means we can go one step further and remove
2779 // the barriers encountered last before (LastCB).
2780 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2781 Worklist.append(LastED.AlignedBarriers.begin(),
2782 LastED.AlignedBarriers.end());
2783 }
2784 }
2785
2786 // If we actually eliminated a barrier we need to eliminate the associated
2787 // llvm.assumes as well to avoid creating UB.
2788 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2789 for (auto *AssumeCB : ED.EncounteredAssumes)
2790 A.deleteAfterManifest(*AssumeCB);
2791 };
2792
2793 for (auto *CB : AlignedBarriers)
2794 HandleAlignedBarrier(CB);
2795
2796 // Handle the "kernel end barrier" for kernels too.
2797 if (omp::isOpenMPKernel(*getAnchorScope()))
2798 HandleAlignedBarrier(nullptr);
2799
2800 return Changed;
2801 }
2802
2803 bool isNoOpFence(const FenceInst &FI) const override {
2804 return getState().isValidState() && !NonNoOpFences.count(&FI);
2805 }
2806
2807 /// Merge barrier and assumption information from \p PredED into the successor
2808 /// \p ED.
2809 void
2810 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2811 const ExecutionDomainTy &PredED);
2812
2813 /// Merge all information from \p PredED into the successor \p ED. If
2814 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2815 /// represented by \p ED from this predecessor.
2816 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2817 const ExecutionDomainTy &PredED,
2818 bool InitialEdgeOnly = false);
2819
2820 /// Accumulate information for the entry block in \p EntryBBED.
2821 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2822
2823 /// See AbstractAttribute::updateImpl.
2824 ChangeStatus updateImpl(Attributor &A) override;
2825
2826 /// Query interface, see AAExecutionDomain
2827 ///{
2828 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2829 if (!isValidState())
2830 return false;
2831 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2832 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2833 }
2834
2835 bool isExecutedInAlignedRegion(Attributor &A,
2836 const Instruction &I) const override {
2837 assert(I.getFunction() == getAnchorScope() &&
2838 "Instruction is out of scope!");
2839 if (!isValidState())
2840 return false;
2841
2842 bool ForwardIsOk = true;
2843 const Instruction *CurI;
2844
2845 // Check forward until a call or the block end is reached.
2846 CurI = &I;
2847 do {
2848 auto *CB = dyn_cast<CallBase>(CurI);
2849 if (!CB)
2850 continue;
2851 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2852 return true;
2853 const auto &It = CEDMap.find({CB, PRE});
2854 if (It == CEDMap.end())
2855 continue;
2856 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2857 ForwardIsOk = false;
2858 break;
2859 } while ((CurI = CurI->getNextNode()));
2860
2861 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2862 ForwardIsOk = false;
2863
2864 // Check backward until a call or the block beginning is reached.
2865 CurI = &I;
2866 do {
2867 auto *CB = dyn_cast<CallBase>(CurI);
2868 if (!CB)
2869 continue;
2870 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2871 return true;
2872 const auto &It = CEDMap.find({CB, POST});
2873 if (It == CEDMap.end())
2874 continue;
2875 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2876 break;
2877 return false;
2878 } while ((CurI = CurI->getPrevNode()));
2879
2880 // Delayed decision on the forward pass to allow aligned barrier detection
2881 // in the backwards traversal.
2882 if (!ForwardIsOk)
2883 return false;
2884
2885 if (!CurI) {
2886 const BasicBlock *BB = I.getParent();
2887 if (BB == &BB->getParent()->getEntryBlock())
2888 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2889 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2890 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2891 })) {
2892 return false;
2893 }
2894 }
2895
2896 // On neither traversal did we find anything but aligned barriers.
2897 return true;
2898 }
2899
2900 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2901 assert(isValidState() &&
2902 "No request should be made against an invalid state!");
2903 return BEDMap.lookup(&BB);
2904 }
2905 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2906 getExecutionDomain(const CallBase &CB) const override {
2907 assert(isValidState() &&
2908 "No request should be made against an invalid state!");
2909 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2910 }
2911 ExecutionDomainTy getFunctionExecutionDomain() const override {
2912 assert(isValidState() &&
2913 "No request should be made against an invalid state!");
2914 return InterProceduralED;
2915 }
2916 ///}
2917
2918 // Check if the edge into the successor block contains a condition that only
2919 // lets the main thread execute it.
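// For example (illustrative): a guard such as
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
//   %is.main = icmp eq i32 %tid, 0
//   br i1 %is.main, label %main.only, label %rest
// makes the edge into %main.only an initial-thread-only edge.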
2920 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2921 BasicBlock &SuccessorBB) {
2922 if (!Edge || !Edge->isConditional())
2923 return false;
2924 if (Edge->getSuccessor(0) != &SuccessorBB)
2925 return false;
2926
2927 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2928 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2929 return false;
2930
2931 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2932 if (!C)
2933 return false;
2934
2935 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2936 if (C->isAllOnesValue()) {
2937 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2938 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2939 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2940 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2941 if (!CB)
2942 return false;
2943 ConstantStruct *KernelEnvC =
2944 KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2945 ConstantInt *ExecModeC =
2946 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2947 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2948 }
2949
2950 if (C->isZero()) {
2951 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2952 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2953 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2954 return true;
2955
2956 // Match: 0 == llvm.amdgcn.workitem.id.x()
2957 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2958 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2959 return true;
2960 }
2961
2962 return false;
2963 };
2964
2965 /// Mapping containing information about the function for other AAs.
2966 ExecutionDomainTy InterProceduralED;
2967
2968 enum Direction { PRE = 0, POST = 1 };
2969 /// Mapping containing information per block.
2970 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2971 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2972 CEDMap;
2973 SmallSetVector<CallBase *, 16> AlignedBarriers;
2974
2975 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2976
2977 /// Set \p R to \p V and report true if that changed \p R.
2978 static bool setAndRecord(bool &R, bool V) {
2979 bool Eq = (R == V);
2980 R = V;
2981 return !Eq;
2982 }
2983
2984 /// Collection of fences known to be non-no-op. All fences not in this set
2985 /// can be assumed to be no-ops.
2986 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2987};
2988
2989void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2990 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2991 for (auto *EA : PredED.EncounteredAssumes)
2992 ED.addAssumeInst(A, *EA);
2993
2994 for (auto *AB : PredED.AlignedBarriers)
2995 ED.addAlignedBarrier(A, *AB);
2996}
2997
2998bool AAExecutionDomainFunction::mergeInPredecessor(
2999 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
3000 bool InitialEdgeOnly) {
3001
3002 bool Changed = false;
3003 Changed |=
3004 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3005 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3006 ED.IsExecutedByInitialThreadOnly));
3007
3008 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3009 ED.IsReachedFromAlignedBarrierOnly &&
3010 PredED.IsReachedFromAlignedBarrierOnly);
3011 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3012 ED.EncounteredNonLocalSideEffect |
3013 PredED.EncounteredNonLocalSideEffect);
3014 // Do not track assumptions and barriers as part of Changed.
3015 if (ED.IsReachedFromAlignedBarrierOnly)
3016 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3017 else
3018 ED.clearAssumeInstAndAlignedBarriers();
3019 return Changed;
3020}
3021
3022bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3023 ExecutionDomainTy &EntryBBED) {
3024 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3025 auto PredForCallSite = [&](AbstractCallSite ACS) {
3026 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3027 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3028 DepClassTy::OPTIONAL);
3029 if (!EDAA || !EDAA->getState().isValidState())
3030 return false;
3031 CallSiteEDs.emplace_back(
3032 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3033 return true;
3034 };
3035
3036 ExecutionDomainTy ExitED;
3037 bool AllCallSitesKnown;
3038 if (A.checkForAllCallSites(PredForCallSite, *this,
3039 /* RequiresAllCallSites */ true,
3040 AllCallSitesKnown)) {
3041 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3042 mergeInPredecessor(A, EntryBBED, CSInED);
3043 ExitED.IsReachingAlignedBarrierOnly &=
3044 CSOutED.IsReachingAlignedBarrierOnly;
3045 }
3046
3047 } else {
3048 // We could not find all predecessors, so this is either a kernel or a
3049 // function with external linkage (or with some other weird uses).
3050 if (omp::isOpenMPKernel(*getAnchorScope())) {
3051 EntryBBED.IsExecutedByInitialThreadOnly = false;
3052 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3053 EntryBBED.EncounteredNonLocalSideEffect = false;
3054 ExitED.IsReachingAlignedBarrierOnly = false;
3055 } else {
3056 EntryBBED.IsExecutedByInitialThreadOnly = false;
3057 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3058 EntryBBED.EncounteredNonLocalSideEffect = true;
3059 ExitED.IsReachingAlignedBarrierOnly = false;
3060 }
3061 }
3062
3063 bool Changed = false;
3064 auto &FnED = BEDMap[nullptr];
3065 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3066 FnED.IsReachedFromAlignedBarrierOnly &
3067 EntryBBED.IsReachedFromAlignedBarrierOnly);
3068 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3069 FnED.IsReachingAlignedBarrierOnly &
3070 ExitED.IsReachingAlignedBarrierOnly);
3071 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3072 EntryBBED.IsExecutedByInitialThreadOnly);
3073 return Changed;
3074}
3075
3076ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3077
3078 bool Changed = false;
3079
3080 // Helper to deal with an aligned barrier encountered during the forward
3081 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3082 // it was encountered.
3083 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3084 Changed |= AlignedBarriers.insert(&CB);
3085 // First, update the barrier ED kept in the separate CEDMap.
3086 auto &CallInED = CEDMap[{&CB, PRE}];
3087 Changed |= mergeInPredecessor(A, CallInED, ED);
3088 CallInED.IsReachingAlignedBarrierOnly = true;
3089 // Next adjust the ED we use for the traversal.
3090 ED.EncounteredNonLocalSideEffect = false;
3091 ED.IsReachedFromAlignedBarrierOnly = true;
3092 // Aligned barrier collection has to come last.
3093 ED.clearAssumeInstAndAlignedBarriers();
3094 ED.addAlignedBarrier(A, CB);
3095 auto &CallOutED = CEDMap[{&CB, POST}];
3096 Changed |= mergeInPredecessor(A, CallOutED, ED);
3097 };
3098
3099 auto *LivenessAA =
3100 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3101
3102 Function *F = getAnchorScope();
3103 BasicBlock &EntryBB = F->getEntryBlock();
3104 bool IsKernel = omp::isOpenMPKernel(*F);
3105
3106 SmallVector<Instruction *> SyncInstWorklist;
3107 for (auto &RIt : *RPOT) {
3108 BasicBlock &BB = *RIt;
3109
3110 bool IsEntryBB = &BB == &EntryBB;
3111 // TODO: We use local reasoning since we don't have a divergence analysis
3112 // running as well. We could basically allow uniform branches here.
3113 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3114 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3115 ExecutionDomainTy ED;
3116 // Propagate "incoming edges" into information about this block.
3117 if (IsEntryBB) {
3118 Changed |= handleCallees(A, ED);
3119 } else {
3120 // For live non-entry blocks we only propagate
3121 // information via live edges.
3122 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3123 continue;
3124
3125 for (auto *PredBB : predecessors(&BB)) {
3126 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3127 continue;
3128 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3129 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3130 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3131 }
3132 }
3133
3134 // Now we traverse the block, accumulate effects in ED and attach
3135 // information to calls.
3136 for (Instruction &I : BB) {
3137 bool UsedAssumedInformation;
3138 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3139 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3140 /* CheckForDeadStore */ true))
3141 continue;
3142
3143 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3144 // first; the former are collected, the latter are ignored.
3145 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3146 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3147 ED.addAssumeInst(A, *AI);
3148 continue;
3149 }
3150 // TODO: Should we also collect and delete lifetime markers?
3151 if (II->isAssumeLikeIntrinsic())
3152 continue;
3153 }
3154
3155 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3156 if (!ED.EncounteredNonLocalSideEffect) {
3157 // An aligned fence without non-local side-effects is a no-op.
3158 if (ED.IsReachedFromAlignedBarrierOnly)
3159 continue;
3160 // A non-aligned fence without non-local side-effects is a no-op
3161 // if the ordering only publishes non-local side-effects (or less).
3162 switch (FI->getOrdering()) {
3163 case AtomicOrdering::NotAtomic:
3164 continue;
3165 case AtomicOrdering::Unordered:
3166 continue;
3167 case AtomicOrdering::Monotonic:
3168 continue;
3169 case AtomicOrdering::Acquire:
3170 break;
3171 case AtomicOrdering::Release:
3172 continue;
3173 case AtomicOrdering::AcquireRelease:
3174 break;
3175 case AtomicOrdering::SequentiallyConsistent:
3176 break;
3177 };
3178 }
3179 NonNoOpFences.insert(FI);
3180 }
3181
3182 auto *CB = dyn_cast<CallBase>(&I);
3183 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3184 bool IsAlignedBarrier =
3185 !IsNoSync && CB &&
3186 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3187
3188 AlignedBarrierLastInBlock &= IsNoSync;
3189 IsExplicitlyAligned &= IsNoSync;
3190
3191 // Next we check for calls. Aligned barriers are handled
3192 // explicitly, everything else is kept for the backward traversal and will
3193 // also affect our state.
3194 if (CB) {
3195 if (IsAlignedBarrier) {
3196 HandleAlignedBarrier(*CB, ED);
3197 AlignedBarrierLastInBlock = true;
3198 IsExplicitlyAligned = true;
3199 continue;
3200 }
3201
3202 // Check the pointer(s) of a memory intrinsic explicitly.
3203 if (isa<MemIntrinsic>(&I)) {
3204 if (!ED.EncounteredNonLocalSideEffect &&
3205 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3206 ED.EncounteredNonLocalSideEffect = true;
3207 if (!IsNoSync) {
3208 ED.IsReachedFromAlignedBarrierOnly = false;
3209 SyncInstWorklist.push_back(&I);
3210 }
3211 continue;
3212 }
3213
3214 // Record how we entered the call, then accumulate the effect of the
3215 // call in ED for potential use by the callee.
3216 auto &CallInED = CEDMap[{CB, PRE}];
3217 Changed |= mergeInPredecessor(A, CallInED, ED);
3218
3219 // If we have a sync-definition we can check if it starts/ends in an
3220 // aligned barrier. If we are unsure we assume any sync breaks
3221 // alignment.
3222 Function *Callee = CB->getCalledFunction();
3223 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3224 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3225 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3226 if (EDAA && EDAA->getState().isValidState()) {
3227 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3228 ED.IsReachedFromAlignedBarrierOnly =
3229 CalleeED.IsReachedFromAlignedBarrierOnly;
3230 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3231 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3232 ED.EncounteredNonLocalSideEffect |=
3233 CalleeED.EncounteredNonLocalSideEffect;
3234 else
3235 ED.EncounteredNonLocalSideEffect =
3236 CalleeED.EncounteredNonLocalSideEffect;
3237 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3238 Changed |=
3239 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3240 SyncInstWorklist.push_back(&I);
3241 }
3242 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3243 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3244 auto &CallOutED = CEDMap[{CB, POST}];
3245 Changed |= mergeInPredecessor(A, CallOutED, ED);
3246 continue;
3247 }
3248 }
3249 if (!IsNoSync) {
3250 ED.IsReachedFromAlignedBarrierOnly = false;
3251 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3252 SyncInstWorklist.push_back(&I);
3253 }
3254 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3255 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3256 auto &CallOutED = CEDMap[{CB, POST}];
3257 Changed |= mergeInPredecessor(A, CallOutED, ED);
3258 }
3259
3260 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3261 continue;
3262
3263 // If we have a callee we try to use fine-grained information to
3264 // determine local side-effects.
3265 if (CB) {
3266 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3267 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3268
3269 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3270 AAMemoryLocation::AccessKind,
3271 AAMemoryLocation::MemoryLocationsKind) {
3272 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3273 };
3274 if (MemAA && MemAA->getState().isValidState() &&
3275 MemAA->checkForAllAccessesToMemoryKind(
3276 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3277 continue;
3278 }
3279
3280 auto &InfoCache = A.getInfoCache();
3281 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3282 continue;
3283
3284 if (auto *LI = dyn_cast<LoadInst>(&I))
3285 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3286 continue;
3287
3288 if (!ED.EncounteredNonLocalSideEffect &&
3289 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3290 ED.EncounteredNonLocalSideEffect = true;
3291 }
3292
3293 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3294 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3295 !BB.getTerminator()->getNumSuccessors()) {
3296
3297 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3298
3299 auto &FnED = BEDMap[nullptr];
3300 if (IsKernel && !IsExplicitlyAligned)
3301 FnED.IsReachingAlignedBarrierOnly = false;
3302 Changed |= mergeInPredecessor(A, FnED, ED);
3303
3304 if (!FnED.IsReachingAlignedBarrierOnly) {
3305 IsEndAndNotReachingAlignedBarriersOnly = true;
3306 SyncInstWorklist.push_back(BB.getTerminator());
3307 auto &BBED = BEDMap[&BB];
3308 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3309 }
3310 }
3311
3312 ExecutionDomainTy &StoredED = BEDMap[&BB];
3313 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3314 !IsEndAndNotReachingAlignedBarriersOnly;
3315
3316 // Check if we computed anything different as part of the forward
3317 // traversal. We do not take assumptions and aligned barriers into account
3318 // as they do not influence the state we iterate. Backward traversal values
3319 // are handled later on.
3320 if (ED.IsExecutedByInitialThreadOnly !=
3321 StoredED.IsExecutedByInitialThreadOnly ||
3322 ED.IsReachedFromAlignedBarrierOnly !=
3323 StoredED.IsReachedFromAlignedBarrierOnly ||
3324 ED.EncounteredNonLocalSideEffect !=
3325 StoredED.EncounteredNonLocalSideEffect)
3326 Changed = true;
3327
3328 // Update the state with the new value.
3329 StoredED = std::move(ED);
3330 }
3331
3332 // Propagate (non-aligned) sync instruction effects backwards until the
3333 // entry is hit or an aligned barrier.
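  // Each sync instruction taken from the worklist invalidates the
  // "reaching an aligned barrier only" property for the calls that precede it
  // in its block and, transitively, for predecessor blocks, stopping as soon
  // as an aligned barrier or a call already known not to reach one is found.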
3334 SmallSetVector<BasicBlock *, 16> Visited;
3335 while (!SyncInstWorklist.empty()) {
3336 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3337 Instruction *CurInst = SyncInst;
3338 bool HitAlignedBarrierOrKnownEnd = false;
3339 while ((CurInst = CurInst->getPrevNode())) {
3340 auto *CB = dyn_cast<CallBase>(CurInst);
3341 if (!CB)
3342 continue;
3343 auto &CallOutED = CEDMap[{CB, POST}];
3344 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3345 auto &CallInED = CEDMap[{CB, PRE}];
3346 HitAlignedBarrierOrKnownEnd =
3347 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3348 if (HitAlignedBarrierOrKnownEnd)
3349 break;
3350 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3351 }
3352 if (HitAlignedBarrierOrKnownEnd)
3353 continue;
3354 BasicBlock *SyncBB = SyncInst->getParent();
3355 for (auto *PredBB : predecessors(SyncBB)) {
3356 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3357 continue;
3358 if (!Visited.insert(PredBB))
3359 continue;
3360 auto &PredED = BEDMap[PredBB];
3361 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3362 Changed = true;
3363 SyncInstWorklist.push_back(PredBB->getTerminator());
3364 }
3365 }
3366 if (SyncBB != &EntryBB)
3367 continue;
3368 Changed |=
3369 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3370 }
3371
3372 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3373}
3374
3375/// Try to replace memory allocation calls executed by a single thread with a
3376/// static buffer of shared memory.
3377struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3378 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3379 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3380
3381 /// Create an abstract attribute view for the position \p IRP.
3382 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3383 Attributor &A);
3384
3385 /// Returns true if HeapToShared conversion is assumed to be possible.
3386 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3387
3388 /// Returns true if HeapToShared conversion is assumed and the CB is a
3389 /// callsite to a free operation to be removed.
3390 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3391
3392 /// See AbstractAttribute::getName().
3393 StringRef getName() const override { return "AAHeapToShared"; }
3394
3395 /// See AbstractAttribute::getIdAddr().
3396 const char *getIdAddr() const override { return &ID; }
3397
3398 /// This function should return true if the type of the \p AA is
3399 /// AAHeapToShared.
3400 static bool classof(const AbstractAttribute *AA) {
3401 return (AA->getIdAddr() == &ID);
3402 }
3403
3404 /// Unique ID (due to the unique address)
3405 static const char ID;
3406};
3407
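/// AAHeapToShared implementation for a function: collect __kmpc_alloc_shared
/// calls with a constant size that are executed by the initial thread only and
/// replace each one with an internal global in the device's shared address
/// space. Conceptually (illustrative IR only, names invented for exposition):
///
///   %p = call ptr @__kmpc_alloc_shared(i64 256)
///   ...
///   call void @__kmpc_free_shared(ptr %p, i64 256)
///
/// becomes uses of
///
///   @x_shared = internal addrspace(3) global [256 x i8] poison
///
/// with both runtime calls deleted.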
3408struct AAHeapToSharedFunction : public AAHeapToShared {
3409 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3410 : AAHeapToShared(IRP, A) {}
3411
3412 const std::string getAsStr(Attributor *) const override {
3413 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3414 " malloc calls eligible.";
3415 }
3416
3417 /// See AbstractAttribute::trackStatistics().
3418 void trackStatistics() const override {}
3419
3420 /// This functions finds free calls that will be removed by the
3421 /// HeapToShared transformation.
3422 void findPotentialRemovedFreeCalls(Attributor &A) {
3423 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3424 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3425
3426 PotentialRemovedFreeCalls.clear();
3427 // Update free call users of found malloc calls.
3428 for (CallBase *CB : MallocCalls) {
3429 SmallVector<CallBase *, 4> FreeCalls;
3430 for (auto *U : CB->users()) {
3431 CallBase *C = dyn_cast<CallBase>(U);
3432 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3433 FreeCalls.push_back(C);
3434 }
3435
3436 if (FreeCalls.size() != 1)
3437 continue;
3438
3439 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3440 }
3441 }
3442
3443 void initialize(Attributor &A) override {
3444 if (DisableOpenMPOptDeglobalization) {
3445 indicatePessimisticFixpoint();
3446 return;
3447 }
3448
3449 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3450 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3451 if (!RFI.Declaration)
3452 return;
3453
3454 Attributor::SimplifictionCallbackTy SCB =
3455 [](const IRPosition &, const AbstractAttribute *,
3456 bool &) -> std::optional<Value *> { return nullptr; };
3457
3458 Function *F = getAnchorScope();
3459 for (User *U : RFI.Declaration->users())
3460 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3461 if (CB->getFunction() != F)
3462 continue;
3463 MallocCalls.insert(CB);
3464 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3465 SCB);
3466 }
3467
3468 findPotentialRemovedFreeCalls(A);
3469 }
3470
3471 bool isAssumedHeapToShared(CallBase &CB) const override {
3472 return isValidState() && MallocCalls.count(&CB);
3473 }
3474
3475 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3476 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3477 }
3478
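  /// Replace each eligible __kmpc_alloc_shared call that still fits into the
  /// SharedMemoryLimit budget with a static shared-memory buffer and delete it
  /// together with its unique free call; allocations already claimed by
  /// AAHeapToStack are skipped.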
3479 ChangeStatus manifest(Attributor &A) override {
3480 if (MallocCalls.empty())
3481 return ChangeStatus::UNCHANGED;
3482
3483 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3484 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3485
3486 Function *F = getAnchorScope();
3487 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3488 DepClassTy::OPTIONAL);
3489
3490 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3491 for (CallBase *CB : MallocCalls) {
3492 // Skip replacing this if HeapToStack has already claimed it.
3493 if (HS && HS->isAssumedHeapToStack(*CB))
3494 continue;
3495
3496 // Find the unique free call to remove it.
3497 SmallVector<CallBase *, 4> FreeCalls;
3498 for (auto *U : CB->users()) {
3499 CallBase *C = dyn_cast<CallBase>(U);
3500 if (C && C->getCalledFunction() == FreeCall.Declaration)
3501 FreeCalls.push_back(C);
3502 }
3503 if (FreeCalls.size() != 1)
3504 continue;
3505
3506 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3507
3508 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3509 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3510 << " with shared memory."
3511 << " Shared memory usage is limited to "
3512 << SharedMemoryLimit << " bytes\n");
3513 continue;
3514 }
3515
3516 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3517 << " with " << AllocSize->getZExtValue()
3518 << " bytes of shared memory\n");
3519
3520 // Create a new shared memory buffer of the same size as the allocation
3521 // and replace all the uses of the original allocation with it.
3522 Module *M = CB->getModule();
3523 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3524 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3525 auto *SharedMem = new GlobalVariable(
3526 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3527 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3528 GlobalValue::NotThreadLocal,
3529 static_cast<unsigned>(AddressSpace::Shared));
3530 auto *NewBuffer = ConstantExpr::getPointerCast(
3531 SharedMem, PointerType::getUnqual(M->getContext()));
3532
3533 auto Remark = [&](OptimizationRemark OR) {
3534 return OR << "Replaced globalized variable with "
3535 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3536 << (AllocSize->isOne() ? " byte " : " bytes ")
3537 << "of shared memory.";
3538 };
3539 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3540
3541 MaybeAlign Alignment = CB->getRetAlign();
3542 assert(Alignment &&
3543 "HeapToShared on allocation without alignment attribute");
3544 SharedMem->setAlignment(*Alignment);
3545
3546 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3547 A.deleteAfterManifest(*CB);
3548 A.deleteAfterManifest(*FreeCalls.front());
3549
3550 SharedMemoryUsed += AllocSize->getZExtValue();
3551 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3552 Changed = ChangeStatus::CHANGED;
3553 }
3554
3555 return Changed;
3556 }
3557
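  /// Prune the candidate set: drop allocations whose size argument is not a
  /// constant or that are not known to be executed by the initial thread only.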
3558 ChangeStatus updateImpl(Attributor &A) override {
3559 if (MallocCalls.empty())
3560 return indicatePessimisticFixpoint();
3561 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3562 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3563 if (!RFI.Declaration)
3564 return ChangeStatus::UNCHANGED;
3565
3566 Function *F = getAnchorScope();
3567
3568 auto NumMallocCalls = MallocCalls.size();
3569
3570 // Only consider malloc calls executed by a single thread with a constant.
3571 for (User *U : RFI.Declaration->users()) {
3572 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3573 if (CB->getCaller() != F)
3574 continue;
3575 if (!MallocCalls.count(CB))
3576 continue;
3577 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3578 MallocCalls.remove(CB);
3579 continue;
3580 }
3581 const auto *ED = A.getAAFor<AAExecutionDomain>(
3582 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3583 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3584 MallocCalls.remove(CB);
3585 }
3586 }
3587
3588 findPotentialRemovedFreeCalls(A);
3589
3590 if (NumMallocCalls != MallocCalls.size())
3591 return ChangeStatus::CHANGED;
3592
3593 return ChangeStatus::UNCHANGED;
3594 }
3595
3596 /// Collection of all malloc calls in a function.
3597 SmallSetVector<CallBase *, 4> MallocCalls;
3598 /// Collection of potentially removed free calls in a function.
3599 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3600 /// The total amount of shared memory that has been used for HeapToShared.
3601 unsigned SharedMemoryUsed = 0;
3602};
3603
3604struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3605 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3606 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3607
3608 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3609 /// unknown callees.
3610 static bool requiresCalleeForCallBase() { return false; }
3611
3612 /// Statistics are tracked as part of manifest for now.
3613 void trackStatistics() const override {}
3614
3615 /// See AbstractAttribute::getAsStr()
3616 const std::string getAsStr(Attributor *) const override {
3617 if (!isValidState())
3618 return "<invalid>";
3619 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3620 : "generic") +
3621 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3622 : "") +
3623 std::string(" #PRs: ") +
3624 (ReachedKnownParallelRegions.isValidState()
3625 ? std::to_string(ReachedKnownParallelRegions.size())
3626 : "<invalid>") +
3627 ", #Unknown PRs: " +
3628 (ReachedUnknownParallelRegions.isValidState()
3629 ? std::to_string(ReachedUnknownParallelRegions.size())
3630 : "<invalid>") +
3631 ", #Reaching Kernels: " +
3632 (ReachingKernelEntries.isValidState()
3633 ? std::to_string(ReachingKernelEntries.size())
3634 : "<invalid>") +
3635 ", #ParLevels: " +
3636 (ParallelLevels.isValidState()
3637 ? std::to_string(ParallelLevels.size())
3638 : "<invalid>") +
3639 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3640 }
3641
3642 /// Create an abstract attribute view for the position \p IRP.
3643 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3644
3645 /// See AbstractAttribute::getName()
3646 StringRef getName() const override { return "AAKernelInfo"; }
3647
3648 /// See AbstractAttribute::getIdAddr()
3649 const char *getIdAddr() const override { return &ID; }
3650
3651 /// This function should return true if the type of the \p AA is AAKernelInfo
3652 static bool classof(const AbstractAttribute *AA) {
3653 return (AA->getIdAddr() == &ID);
3654 }
3655
3656 static const char ID;
3657};
3658
3659/// The function kernel info abstract attribute, basically, what can we say
3660/// about a function with regards to the KernelInfoState.
3661struct AAKernelInfoFunction : AAKernelInfo {
3662 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3663 : AAKernelInfo(IRP, A) {}
3664
3665 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3666
3667 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3668 return GuardedInstructions;
3669 }
3670
3671 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3672 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3673 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3674 assert(NewKernelEnvC && "Failed to create new kernel environment");
3675 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3676 }
3677
3678#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3679 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3680 ConstantStruct *ConfigC = \
3681 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3682 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3683 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3684 assert(NewConfigC && "Failed to create new configuration environment"); \
3685 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3686 }
3687
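  // Instantiate one setter per configuration member of the kernel
  // environment, e.g., setExecModeOfKernelEnvironment(ConstantInt *).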
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3691 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3692 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3694 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3695
3696#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3697
3698 /// See AbstractAttribute::initialize(...).
3699 void initialize(Attributor &A) override {
3700 // This is a high-level transform that might change the constant arguments
3701 // of the init and deinit calls. We need to tell the Attributor about this
3702 // to avoid other parts using the current constant value for simplification.
3703 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3704
3705 Function *Fn = getAnchorScope();
3706
3707 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3708 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3709 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3710 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3711
3712 // For kernels we perform more initialization work, first we find the init
3713 // and deinit calls.
3714 auto StoreCallBase = [](Use &U,
3715 OMPInformationCache::RuntimeFunctionInfo &RFI,
3716 CallBase *&Storage) {
3717 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3718 assert(CB &&
3719 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3720 assert(!Storage &&
3721 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3722 Storage = CB;
3723 return false;
3724 };
3725 InitRFI.foreachUse(
3726 [&](Use &U, Function &) {
3727 StoreCallBase(U, InitRFI, KernelInitCB);
3728 return false;
3729 },
3730 Fn);
3731 DeinitRFI.foreachUse(
3732 [&](Use &U, Function &) {
3733 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3734 return false;
3735 },
3736 Fn);
3737
3738 // Ignore kernels without initializers such as global constructors.
3739 if (!KernelInitCB || !KernelDeinitCB)
3740 return;
3741
3742 // Add itself to the reaching kernel and set IsKernelEntry.
3743 ReachingKernelEntries.insert(Fn);
3744 IsKernelEntry = true;
3745
3746 KernelEnvC =
3747 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3748 GlobalVariable *KernelEnvGV =
3749 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3750
3751 Attributor::GlobalVariableSimplifictionCallbackTy
3752 KernelConfigurationSimplifyCB =
3753 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3754 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3755 if (!isAtFixpoint()) {
3756 if (!AA)
3757 return nullptr;
3758 UsedAssumedInformation = true;
3759 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3760 }
3761 return KernelEnvC;
3762 };
3763
3764 A.registerGlobalVariableSimplificationCallback(
3765 *KernelEnvGV, KernelConfigurationSimplifyCB);
3766
3767 // We cannot change to SPMD mode if the runtime functions aren't available.
3768 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3769 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3770 OMPRTL___kmpc_barrier_simple_spmd});
3771
3772 // Check if we know we are in SPMD-mode already.
3773 ConstantInt *ExecModeC =
3774 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3775 ConstantInt *AssumedExecModeC = ConstantInt::get(
3776 ExecModeC->getIntegerType(),
3777 ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3778 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3779 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3780 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3781 // This is a generic region but SPMDization is disabled so stop
3782 // tracking.
3783 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3784 else
3785 setExecModeOfKernelEnvironment(AssumedExecModeC);
3786
3787 const Triple T(Fn->getParent()->getTargetTriple());
3788 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3789 auto [MinThreads, MaxThreads] =
3790 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3791 if (MinThreads)
3792 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3793 if (MaxThreads)
3794 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3795 auto [MinTeams, MaxTeams] =
3796 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3797 if (MinTeams)
3798 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3799 if (MaxTeams)
3800 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3801
3802 ConstantInt *MayUseNestedParallelismC =
3803 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3804 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3805 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3806 setMayUseNestedParallelismOfKernelEnvironment(
3807 AssumedMayUseNestedParallelismC);
3808
3809 if (!DisableOpenMPOptStateMachineRewrite) {
3810 ConstantInt *UseGenericStateMachineC =
3811 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3812 KernelEnvC);
3813 ConstantInt *AssumedUseGenericStateMachineC =
3814 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3815 setUseGenericStateMachineOfKernelEnvironment(
3816 AssumedUseGenericStateMachineC);
3817 }
3818
3819 // Register virtual uses of functions we might need to preserve.
3820 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3821 Attributor::VirtualUseCallbackTy &CB) {
3822 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3823 return;
3824 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3825 };
3826
3827 // Add a dependence to ensure updates if the state changes.
3828 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3829 const AbstractAttribute *QueryingAA) {
3830 if (QueryingAA) {
3831 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3832 }
3833 return true;
3834 };
3835
3836 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3837 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3838 // Whenever we create a custom state machine we will insert calls to
3839 // __kmpc_get_hardware_num_threads_in_block,
3840 // __kmpc_get_warp_size,
3841 // __kmpc_barrier_simple_generic,
3842 // __kmpc_kernel_parallel, and
3843 // __kmpc_kernel_end_parallel.
3844 // Not needed if we are on track for SPMDzation.
3845 if (SPMDCompatibilityTracker.isValidState())
3846 return AddDependence(A, this, QueryingAA);
3847 // Not needed if we can't rewrite due to an invalid state.
3848 if (!ReachedKnownParallelRegions.isValidState())
3849 return AddDependence(A, this, QueryingAA);
3850 return false;
3851 };
3852
3853 // Not needed if we are pre-runtime merge.
3854 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3855 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3856 CustomStateMachineUseCB);
3857 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3858 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3859 CustomStateMachineUseCB);
3860 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3861 CustomStateMachineUseCB);
3862 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3863 CustomStateMachineUseCB);
3864 }
3865
3866 // If we do not perform SPMDzation we do not need the virtual uses below.
3867 if (SPMDCompatibilityTracker.isAtFixpoint())
3868 return;
3869
3870 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3871 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3872 // Whenever we perform SPMDzation we will insert
3873 // __kmpc_get_hardware_thread_id_in_block calls.
3874 if (!SPMDCompatibilityTracker.isValidState())
3875 return AddDependence(A, this, QueryingAA);
3876 return false;
3877 };
3878 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3879 HWThreadIdUseCB);
3880
3881 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3882 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3883 // Whenever we perform SPMDzation with guarding we will insert
3884 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3885 // nothing to guard, or there are no parallel regions, we don't need
3886 // the calls.
3887 if (!SPMDCompatibilityTracker.isValidState())
3888 return AddDependence(A, this, QueryingAA);
3889 if (SPMDCompatibilityTracker.empty())
3890 return AddDependence(A, this, QueryingAA);
3891 if (!mayContainParallelRegion())
3892 return AddDependence(A, this, QueryingAA);
3893 return false;
3894 };
3895 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3896 }
3897
3898 /// Sanitize the string \p S such that it is a suitable global symbol name.
3899 static std::string sanitizeForGlobalName(std::string S) {
3900 std::replace_if(
3901 S.begin(), S.end(),
3902 [](const char C) {
3903 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3904 (C >= '0' && C <= '9') || C == '_');
3905 },
3906 '.');
3907 return S;
3908 }
3909
3910 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3911 /// finished now.
3912 ChangeStatus manifest(Attributor &A) override {
3913 // If we are not looking at a kernel with __kmpc_target_init and
3914 // __kmpc_target_deinit call we cannot actually manifest the information.
3915 if (!KernelInitCB || !KernelDeinitCB)
3916 return ChangeStatus::UNCHANGED;
3917
3918 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3919
3920 bool HasBuiltStateMachine = true;
3921 if (!changeToSPMDMode(A, Changed)) {
3922 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3923 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3924 else
3925 HasBuiltStateMachine = false;
3926 }
3927
3928 // We need to reset KernelEnvC if specific rewriting is not done.
3929 ConstantStruct *ExistingKernelEnvC =
3930 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3931 ConstantInt *OldUseGenericStateMachineVal =
3932 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3933 ExistingKernelEnvC);
3934 if (!HasBuiltStateMachine)
3935 setUseGenericStateMachineOfKernelEnvironment(
3936 OldUseGenericStateMachineVal);
3937
3938 // Finally, update the KernelEnvC global initializer if it changed.
3939 GlobalVariable *KernelEnvGV =
3940 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3941 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3942 KernelEnvGV->setInitializer(KernelEnvC);
3943 Changed = ChangeStatus::CHANGED;
3944 }
3945
3946 return Changed;
3947 }
3948
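  /// Guard the SPMD-incompatible instructions collected in
  /// SPMDCompatibilityTracker so that only thread 0 executes them; escaping
  /// values are broadcast to the other threads through shared memory and
  /// barriers (see the CFG sketch inside CreateGuardedRegion below).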
3949 void insertInstructionGuardsHelper(Attributor &A) {
3950 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3951
3952 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3953 Instruction *RegionEndI) {
3954 LoopInfo *LI = nullptr;
3955 DominatorTree *DT = nullptr;
3956 MemorySSAUpdater *MSU = nullptr;
3957 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3958
3959 BasicBlock *ParentBB = RegionStartI->getParent();
3960 Function *Fn = ParentBB->getParent();
3961 Module &M = *Fn->getParent();
3962
3963 // Create all the blocks and logic.
3964 // ParentBB:
3965 // goto RegionCheckTidBB
3966 // RegionCheckTidBB:
3967 // Tid = __kmpc_hardware_thread_id()
3968 // if (Tid != 0)
3969 // goto RegionBarrierBB
3970 // RegionStartBB:
3971 // <execute instructions guarded>
3972 // goto RegionEndBB
3973 // RegionEndBB:
3974 // <store escaping values to shared mem>
3975 // goto RegionBarrierBB
3976 // RegionBarrierBB:
3977 // __kmpc_simple_barrier_spmd()
3978 // // second barrier is omitted if lacking escaping values.
3979 // <load escaping values from shared mem>
3980 // __kmpc_simple_barrier_spmd()
3981 // goto RegionExitBB
3982 // RegionExitBB:
3983 // <execute rest of instructions>
3984
3985 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3986 DT, LI, MSU, "region.guarded.end");
3987 BasicBlock *RegionBarrierBB =
3988 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3989 MSU, "region.barrier");
3990 BasicBlock *RegionExitBB =
3991 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3992 DT, LI, MSU, "region.exit");
3993 BasicBlock *RegionStartBB =
3994 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3995
3996 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3997 "Expected a different CFG");
3998
3999 BasicBlock *RegionCheckTidBB = SplitBlock(
4000 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4001
4002 // Register basic blocks with the Attributor.
4003 A.registerManifestAddedBasicBlock(*RegionEndBB);
4004 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4005 A.registerManifestAddedBasicBlock(*RegionExitBB);
4006 A.registerManifestAddedBasicBlock(*RegionStartBB);
4007 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4008
4009 bool HasBroadcastValues = false;
4010 // Find escaping outputs from the guarded region to outside users and
4011 // broadcast their values to them.
4012 for (Instruction &I : *RegionStartBB) {
4013 SmallVector<Use *, 4> OutsideUses;
4014 for (Use &U : I.uses()) {
4015 Instruction &UsrI = *cast<Instruction>(U.getUser());
4016 if (UsrI.getParent() != RegionStartBB)
4017 OutsideUses.push_back(&U);
4018 }
4019
4020 if (OutsideUses.empty())
4021 continue;
4022
4023 HasBroadcastValues = true;
4024
4025 // Emit a global variable in shared memory to store the broadcasted
4026 // value.
4027 auto *SharedMem = new GlobalVariable(
4028 M, I.getType(), /* IsConstant */ false,
4029 GlobalValue::InternalLinkage, PoisonValue::get(I.getType()),
4030 sanitizeForGlobalName(
4031 (I.getName() + ".guarded.output.alloc").str()),
4032 nullptr, GlobalValue::NotThreadLocal,
4033 static_cast<unsigned>(AddressSpace::Shared));
4034
4035 // Emit a store instruction to update the value.
4036 new StoreInst(&I, SharedMem,
4037 RegionEndBB->getTerminator()->getIterator());
4038
4039 LoadInst *LoadI = new LoadInst(
4040 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4041 RegionBarrierBB->getTerminator()->getIterator());
4042
4043 // Emit a load instruction and replace uses of the output value.
4044 for (Use *U : OutsideUses)
4045 A.changeUseAfterManifest(*U, *LoadI);
4046 }
4047
4048 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4049
4050 // Go to tid check BB in ParentBB.
4051 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4052 ParentBB->getTerminator()->eraseFromParent();
4053 OpenMPIRBuilder::LocationDescription Loc(
4054 InsertPointTy(ParentBB, ParentBB->end()), DL);
4055 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4056 uint32_t SrcLocStrSize;
4057 auto *SrcLocStr =
4058 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4059 Value *Ident =
4060 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4061 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4062
4063 // Add check for Tid in RegionCheckTidBB
4064 RegionCheckTidBB->getTerminator()->eraseFromParent();
4065 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4066 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4067 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4068 FunctionCallee HardwareTidFn =
4069 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4070 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4071 CallInst *Tid =
4072 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4073 Tid->setDebugLoc(DL);
4074 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4075 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4076 OMPInfoCache.OMPBuilder.Builder
4077 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4078 ->setDebugLoc(DL);
4079
4080 // First barrier for synchronization, ensures main thread has updated
4081 // values.
4082 FunctionCallee BarrierFn =
4083 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4084 M, OMPRTL___kmpc_barrier_simple_spmd);
4085 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4086 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4087 CallInst *Barrier =
4088 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4089 Barrier->setDebugLoc(DL);
4090 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4091
4092 // Second barrier ensures workers have read broadcast values.
4093 if (HasBroadcastValues) {
4094 CallInst *Barrier =
4095 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4096 RegionBarrierBB->getTerminator()->getIterator());
4097 Barrier->setDebugLoc(DL);
4098 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4099 }
4100 };
4101
4102 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4103 SmallPtrSet<BasicBlock *, 8> Visited;
4104 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4105 BasicBlock *BB = GuardedI->getParent();
4106 if (!Visited.insert(BB).second)
4107 continue;
4108
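      // Best effort: move user-less instructions that need guarding down next
      // to the following guarded side effect so adjacent guarded instructions
      // end up forming a single region.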
4109 SmallVector<std::pair<Instruction *, Instruction *>, 4> Reorders;
4110 Instruction *LastEffect = nullptr;
4111 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4112 while (++IP != IPEnd) {
4113 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4114 continue;
4115 Instruction *I = &*IP;
4116 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4117 continue;
4118 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4119 LastEffect = nullptr;
4120 continue;
4121 }
4122 if (LastEffect)
4123 Reorders.push_back({I, LastEffect});
4124 LastEffect = &*IP;
4125 }
4126 for (auto &Reorder : Reorders)
4127 Reorder.first->moveBefore(Reorder.second->getIterator());
4128 }
4129
4130 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4131
4132 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4133 BasicBlock *BB = GuardedI->getParent();
4134 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4135 IRPosition::function(*GuardedI->getFunction()), nullptr,
4136 DepClassTy::NONE);
4137 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4138 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4139 // Continue if instruction is already guarded.
4140 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4141 continue;
4142
4143 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4144 for (Instruction &I : *BB) {
4145 // If instruction I needs to be guarded, update the guarded region
4146 // bounds.
4147 if (SPMDCompatibilityTracker.contains(&I)) {
4148 CalleeAAFunction.getGuardedInstructions().insert(&I);
4149 if (GuardedRegionStart)
4150 GuardedRegionEnd = &I;
4151 else
4152 GuardedRegionStart = GuardedRegionEnd = &I;
4153
4154 continue;
4155 }
4156
4157 // Instruction I does not need guarding, store
4158 // any region found and reset bounds.
4159 if (GuardedRegionStart) {
4160 GuardedRegions.push_back(
4161 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4162 GuardedRegionStart = nullptr;
4163 GuardedRegionEnd = nullptr;
4164 }
4165 }
4166 }
4167
4168 for (auto &GR : GuardedRegions)
4169 CreateGuardedRegion(GR.first, GR.second);
4170 }
4171
4172 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4173 // Only allow 1 thread per workgroup to continue executing the user code.
4174 //
4175 // InitCB = __kmpc_target_init(...)
4176 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4177 // if (ThreadIdInBlock != 0) return;
4178 // UserCode:
4179 // // user code
4180 //
4181 auto &Ctx = getAnchorValue().getContext();
4182 Function *Kernel = getAssociatedFunction();
4183 assert(Kernel && "Expected an associated function!");
4184
4185 // Create block for user code to branch to from initial block.
4186 BasicBlock *InitBB = KernelInitCB->getParent();
4187 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4188 KernelInitCB->getNextNode(), "main.thread.user_code");
4189 BasicBlock *ReturnBB =
4190 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4191
4192 // Register blocks with attributor:
4193 A.registerManifestAddedBasicBlock(*InitBB);
4194 A.registerManifestAddedBasicBlock(*UserCodeBB);
4195 A.registerManifestAddedBasicBlock(*ReturnBB);
4196
4197 // Debug location:
4198 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4199 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4200 InitBB->getTerminator()->eraseFromParent();
4201
4202 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4203 Module &M = *Kernel->getParent();
4204 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4205 FunctionCallee ThreadIdInBlockFn =
4206 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4207 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4208
4209 // Get thread ID in block.
4210 CallInst *ThreadIdInBlock =
4211 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4212 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4213 ThreadIdInBlock->setDebugLoc(DLoc);
4214
4215 // Eliminate all threads in the block with ID not equal to 0:
4216 Instruction *IsMainThread =
4217 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4218 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4219 "thread.is_main", InitBB);
4220 IsMainThread->setDebugLoc(DLoc);
4221 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4222 }
4223
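  /// Try to convert this generic-mode kernel to SPMD mode. Returns true if the
  /// kernel already is (or has been turned into) an SPMD kernel; otherwise
  /// emits remarks for the offending instructions and returns false.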
4224 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4225 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4226
4227 if (!SPMDCompatibilityTracker.isAssumed()) {
4228 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4229 if (!NonCompatibleI)
4230 continue;
4231
4232 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4233 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4234 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4235 continue;
4236
4237 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4238 ORA << "Value has potential side effects preventing SPMD-mode "
4239 "execution";
4240 if (isa<CallBase>(NonCompatibleI)) {
4241 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4242 "the called function to override";
4243 }
4244 return ORA << ".";
4245 };
4246 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4247 Remark);
4248
4249 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4250 << *NonCompatibleI << "\n");
4251 }
4252
4253 return false;
4254 }
4255
4256 // Get the actual kernel, could be the caller of the anchor scope if we have
4257 // a debug wrapper.
4258 Function *Kernel = getAnchorScope();
4259 if (Kernel->hasLocalLinkage()) {
4260 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4261 auto *CB = cast<CallBase>(Kernel->user_back());
4262 Kernel = CB->getCaller();
4263 }
4264 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4265
4266 // Check if the kernel is already in SPMD mode, if so, return success.
4267 ConstantStruct *ExistingKernelEnvC =
4268 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4269 auto *ExecModeC =
4270 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4271 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4272 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4273 return true;
4274
4275 // We will now unconditionally modify the IR, indicate a change.
4276 Changed = ChangeStatus::CHANGED;
4277
4278 // Do not use instruction guards when no parallel region is present
4279 // inside the target region.
4280 if (mayContainParallelRegion())
4281 insertInstructionGuardsHelper(A);
4282 else
4283 forceSingleThreadPerWorkgroupHelper(A);
4284
4285 // Adjust the global exec mode flag that tells the runtime what mode this
4286 // kernel is executed in.
4287 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4288 "Initially non-SPMD kernel has SPMD exec mode!");
4289 setExecModeOfKernelEnvironment(
4290 ConstantInt::get(ExecModeC->getIntegerType(),
4291 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4292
4293 ++NumOpenMPTargetRegionKernelsSPMD;
4294
4295 auto Remark = [&](OptimizationRemark OR) {
4296 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4297 };
4298 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4299 return true;
4300 };
4301
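  /// Replace the generic worker state machine of a generic-mode kernel with a
  /// specialized if-cascade over the parallel regions known to be reached from
  /// this kernel, falling back to an indirect call for unknown regions.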
4302 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4303 // If we have disabled state machine rewrites, don't make a custom one
4304 if (DisableOpenMPOptStateMachineRewrite)
4305 return false;
4306
4307 // Don't rewrite the state machine if we are not in a valid state.
4308 if (!ReachedKnownParallelRegions.isValidState())
4309 return false;
4310
4311 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4312 if (!OMPInfoCache.runtimeFnsAvailable(
4313 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4314 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4315 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4316 return false;
4317
4318 ConstantStruct *ExistingKernelEnvC =
4319 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4320
4321 // Check if the current configuration is non-SPMD and generic state machine.
4322 // If we already have SPMD mode or a custom state machine we do not need to
4323 // go any further. If it is anything but a constant something is weird and
4324 // we give up.
4325 ConstantInt *UseStateMachineC =
4326 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4327 ExistingKernelEnvC);
4328 ConstantInt *ModeC =
4329 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4330
4331 // If we are stuck with generic mode, try to create a custom device (=GPU)
4332 // state machine which is specialized for the parallel regions that are
4333 // reachable by the kernel.
4334 if (UseStateMachineC->isZero() ||
4335 ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
4336 return false;
4337
4338 Changed = ChangeStatus::CHANGED;
4339
4340 // If not SPMD mode, indicate we use a custom state machine now.
4341 setUseGenericStateMachineOfKernelEnvironment(
4342 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4343
4344 // If we don't actually need a state machine we are done here. This can
4345 // happen if there simply are no parallel regions. In the resulting kernel
4346 // all worker threads will simply exit right away, leaving the main thread
4347 // to do the work alone.
4348 if (!mayContainParallelRegion()) {
4349 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4350
4351 auto Remark = [&](OptimizationRemark OR) {
4352 return OR << "Removing unused state machine from generic-mode kernel.";
4353 };
4354 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4355
4356 return true;
4357 }
4358
4359 // Keep track in the statistics of our new shiny custom state machine.
4360 if (ReachedUnknownParallelRegions.empty()) {
4361 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4362
4363 auto Remark = [&](OptimizationRemark OR) {
4364 return OR << "Rewriting generic-mode kernel with a customized state "
4365 "machine.";
4366 };
4367 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4368 } else {
4369 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4370
4371 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4372 return OR << "Generic-mode kernel is executed with a customized state "
4373 "machine that requires a fallback.";
4374 };
4375 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4376
4377 // Tell the user why we ended up with a fallback.
4378 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4379 if (!UnknownParallelRegionCB)
4380 continue;
4381 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4382 return ORA << "Call may contain unknown parallel regions. Use "
4383 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4384 "override.";
4385 };
4386 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4387 "OMP133", Remark);
4388 }
4389 }
4390
4391 // Create all the blocks:
4392 //
4393 // InitCB = __kmpc_target_init(...)
4394 // BlockHwSize =
4395 // __kmpc_get_hardware_num_threads_in_block();
4396 // WarpSize = __kmpc_get_warp_size();
4397 // BlockSize = BlockHwSize - WarpSize;
4398 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4399 // if (IsWorker) {
4400 // if (InitCB >= BlockSize) return;
4401 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4402 // void *WorkFn;
4403 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4404 // if (!WorkFn) return;
4405 // SMIsActiveCheckBB: if (Active) {
4406 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4407 // ParFn0(...);
4408 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4409 // ParFn1(...);
4410 // ...
4411 // SMIfCascadeCurrentBB: else
4412 // ((WorkFnTy*)WorkFn)(...);
4413 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4414 // }
4415 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4416 // goto SMBeginBB;
4417 // }
4418 // UserCodeEntryBB: // user code
4419 // __kmpc_target_deinit(...)
4420 //
4421 auto &Ctx = getAnchorValue().getContext();
4422 Function *Kernel = getAssociatedFunction();
4423 assert(Kernel && "Expected an associated function!");
4424
4425 BasicBlock *InitBB = KernelInitCB->getParent();
4426 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4427 KernelInitCB->getNextNode(), "thread.user_code.check");
4428 BasicBlock *IsWorkerCheckBB =
4429 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4435 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4436 BasicBlock *StateMachineIfCascadeCurrentBB =
4437 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4438 Kernel, UserCodeEntryBB);
4439 BasicBlock *StateMachineEndParallelBB =
4440 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4441 Kernel, UserCodeEntryBB);
4442 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4443 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(*InitBB);
4445 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4446 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4453
4454 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4455 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4456 InitBB->getTerminator()->eraseFromParent();
4457
4458 Instruction *IsWorker =
4459 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4460 ConstantInt::get(KernelInitCB->getType(), -1),
4461 "thread.is_worker", InitBB);
4462 IsWorker->setDebugLoc(DLoc);
4463 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4464
4465 Module &M = *Kernel->getParent();
4466 FunctionCallee BlockHwSizeFn =
4467 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4469 FunctionCallee WarpSizeFn =
4470 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4471 M, OMPRTL___kmpc_get_warp_size);
4472 CallInst *BlockHwSize =
4473 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4474 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4475 BlockHwSize->setDebugLoc(DLoc);
4476 CallInst *WarpSize =
4477 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4478 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4479 WarpSize->setDebugLoc(DLoc);
4480 Instruction *BlockSize = BinaryOperator::CreateSub(
4481 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4482 BlockSize->setDebugLoc(DLoc);
4483 Instruction *IsMainOrWorker = ICmpInst::Create(
4484 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4485 "thread.is_main_or_worker", IsWorkerCheckBB);
4486 IsMainOrWorker->setDebugLoc(DLoc);
4487 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4488 IsMainOrWorker, IsWorkerCheckBB);
4489
4490 // Create local storage for the work function pointer.
4491 const DataLayout &DL = M.getDataLayout();
4492 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4493 Instruction *WorkFnAI =
4494 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4495 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4496 WorkFnAI->setDebugLoc(DLoc);
4497
4498 OMPInfoCache.OMPBuilder.updateToLocation(
4499 OpenMPIRBuilder::LocationDescription(
4500 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4501 StateMachineBeginBB->end()),
4502 DLoc));
4503
4504 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4505 Value *GTid = KernelInitCB;
4506
4507 FunctionCallee BarrierFn =
4508 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4509 M, OMPRTL___kmpc_barrier_simple_generic);
4510 CallInst *Barrier =
4511 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4512 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4513 Barrier->setDebugLoc(DLoc);
4514
4515 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4516 (unsigned int)AddressSpace::Generic) {
4517 WorkFnAI = new AddrSpaceCastInst(
4518 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4519 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4520 WorkFnAI->setDebugLoc(DLoc);
4521 }
4522
4523 FunctionCallee KernelParallelFn =
4524 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4525 M, OMPRTL___kmpc_kernel_parallel);
4526 CallInst *IsActiveWorker = CallInst::Create(
4527 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4528 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4529 IsActiveWorker->setDebugLoc(DLoc);
4530 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4531 StateMachineBeginBB);
4532 WorkFn->setDebugLoc(DLoc);
4533
4534 FunctionType *ParallelRegionFnTy = FunctionType::get(
4535 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4536 false);
4537
4538 Instruction *IsDone =
4539 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4540 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4541 StateMachineBeginBB);
4542 IsDone->setDebugLoc(DLoc);
4543 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4544 IsDone, StateMachineBeginBB)
4545 ->setDebugLoc(DLoc);
4546
4547 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4548 StateMachineDoneBarrierBB, IsActiveWorker,
4549 StateMachineIsActiveCheckBB)
4550 ->setDebugLoc(DLoc);
4551
4552 Value *ZeroArg =
4553 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4554
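    // __kmpc_parallel_51 receives the outlined parallel-region wrapper as its
    // seventh argument (operand index 6); that is the pointer the state
    // machine compares against the work function handed out by the runtime.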
4555 const unsigned int WrapperFunctionArgNo = 6;
4556
4557 // Now that we have most of the CFG skeleton it is time for the if-cascade
4558 // that checks the function pointer we got from the runtime against the
4559 // parallel regions we expect, if there are any.
4560 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4561 auto *CB = ReachedKnownParallelRegions[I];
4562 auto *ParallelRegion = dyn_cast<Function>(
4563 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4564 BasicBlock *PRExecuteBB = BasicBlock::Create(
4565 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4566 StateMachineEndParallelBB);
4567 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4570 ->setDebugLoc(DLoc);
4571
4572 BasicBlock *PRNextBB =
4573 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4574 Kernel, StateMachineEndParallelBB);
4575 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4576 A.registerManifestAddedBasicBlock(*PRNextBB);
4577
4578 // Check if we need to compare the pointer at all or if we can just
4579 // call the parallel region function.
4580 Value *IsPR;
4581 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4582 Instruction *CmpI = ICmpInst::Create(
4583 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4584 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4585 CmpI->setDebugLoc(DLoc);
4586 IsPR = CmpI;
4587 } else {
4588 IsPR = ConstantInt::getTrue(Ctx);
4589 }
4590
4591 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4592 StateMachineIfCascadeCurrentBB)
4593 ->setDebugLoc(DLoc);
4594 StateMachineIfCascadeCurrentBB = PRNextBB;
4595 }
4596
4597 // At the end of the if-cascade we place the indirect function pointer call
4598 // in case we might need it, that is if there can be parallel regions we
4599 // have not handled in the if-cascade above.
4600 if (!ReachedUnknownParallelRegions.empty()) {
4601 StateMachineIfCascadeCurrentBB->setName(
4602 "worker_state_machine.parallel_region.fallback.execute");
4603 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4604 StateMachineIfCascadeCurrentBB)
4605 ->setDebugLoc(DLoc);
4606 }
4607 BranchInst::Create(StateMachineEndParallelBB,
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610
4611 FunctionCallee EndParallelFn =
4612 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4613 M, OMPRTL___kmpc_kernel_end_parallel);
4614 CallInst *EndParallel =
4615 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4616 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4617 EndParallel->setDebugLoc(DLoc);
4618 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4619 ->setDebugLoc(DLoc);
4620
4621 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4624 ->setDebugLoc(DLoc);
4625
4626 return true;
4627 }
4628
4629 /// Fixpoint iteration update function. Will be called every time a dependence
4630 /// changed its state (and in the beginning).
4631 ChangeStatus updateImpl(Attributor &A) override {
4632 KernelInfoState StateBefore = getState();
4633
4634 // When we leave this function this RAII will make sure the member
4635 // KernelEnvC is updated properly depending on the state. That member is
4636 // used for simplification of values and needs to be up to date at all
4637 // times.
4638 struct UpdateKernelEnvCRAII {
4639 AAKernelInfoFunction &AA;
4640
4641 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4642
4643 ~UpdateKernelEnvCRAII() {
4644 if (!AA.KernelEnvC)
4645 return;
4646
4647 ConstantStruct *ExistingKernelEnvC =
4648 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4649
4650 if (!AA.isValidState()) {
4651 AA.KernelEnvC = ExistingKernelEnvC;
4652 return;
4653 }
4654
4655 if (!AA.ReachedKnownParallelRegions.isValidState())
4656 AA.setUseGenericStateMachineOfKernelEnvironment(
4657 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4658 ExistingKernelEnvC));
4659
4660 if (!AA.SPMDCompatibilityTracker.isValidState())
4661 AA.setExecModeOfKernelEnvironment(
4662 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4663
4664 ConstantInt *MayUseNestedParallelismC =
4665 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4666 AA.KernelEnvC);
4667 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4668 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4669 AA.setMayUseNestedParallelismOfKernelEnvironment(
4670 NewMayUseNestedParallelismC);
4671 }
4672 } RAII(*this);
4673
4674 // Callback to check a read/write instruction.
4675 auto CheckRWInst = [&](Instruction &I) {
4676 // We handle calls later.
4677 if (isa<CallBase>(I))
4678 return true;
4679 // We only care about write effects.
4680 if (!I.mayWriteToMemory())
4681 return true;
4682 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4683 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4684 *this, IRPosition::value(*SI->getPointerOperand()),
4685 DepClassTy::OPTIONAL);
4686 auto *HS = A.getAAFor<AAHeapToStack>(
4687 *this, IRPosition::function(*I.getFunction()),
4688 DepClassTy::OPTIONAL);
4689 if (UnderlyingObjsAA &&
4690 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4691 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4692 return true;
4693 // Check for AAHeapToStack moved objects which must not be
4694 // guarded.
4695 auto *CB = dyn_cast<CallBase>(&Obj);
4696 return CB && HS && HS->isAssumedHeapToStack(*CB);
4697 }))
4698 return true;
4699 }
4700
4701 // Insert instruction that needs guarding.
4702 SPMDCompatibilityTracker.insert(&I);
4703 return true;
4704 };
4705
4706 bool UsedAssumedInformationInCheckRWInst = false;
4707 if (!SPMDCompatibilityTracker.isAtFixpoint())
4708 if (!A.checkForAllReadWriteInstructions(
4709 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4710 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4711
4712 bool UsedAssumedInformationFromReachingKernels = false;
4713 if (!IsKernelEntry) {
4714 updateParallelLevels(A);
4715
4716 bool AllReachingKernelsKnown = true;
4717 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4718 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4719
4720 if (!SPMDCompatibilityTracker.empty()) {
4721 if (!ParallelLevels.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else if (!ReachingKernelEntries.isValidState())
4724 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4725 else {
4726 // Check if all reaching kernels agree on the mode as we can otherwise
4727 // not guard instructions. We might not be sure about the mode, so we
4728 // cannot fix the internal SPMDzation state either.
4729 int SPMD = 0, Generic = 0;
4730 for (auto *Kernel : ReachingKernelEntries) {
4731 auto *CBAA = A.getAAFor<AAKernelInfo>(
4732 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4733 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4734 CBAA->SPMDCompatibilityTracker.isAssumed())
4735 ++SPMD;
4736 else
4737 ++Generic;
4738 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4739 UsedAssumedInformationFromReachingKernels = true;
4740 }
4741 if (SPMD != 0 && Generic != 0)
4742 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4743 }
4744 }
4745 }
4746
4747 // Callback to check a call instruction.
4748 bool AllParallelRegionStatesWereFixed = true;
4749 bool AllSPMDStatesWereFixed = true;
4750 auto CheckCallInst = [&](Instruction &I) {
4751 auto &CB = cast<CallBase>(I);
4752 auto *CBAA = A.getAAFor<AAKernelInfo>(
4753 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4754 if (!CBAA)
4755 return false;
4756 getState() ^= CBAA->getState();
4757 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4760 AllParallelRegionStatesWereFixed &=
4761 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4762 return true;
4763 };
4764
4765 bool UsedAssumedInformationInCheckCallInst = false;
4766 if (!A.checkForAllCallLikeInstructions(
4767 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4768 LLVM_DEBUG(dbgs() << TAG
4769 << "Failed to visit all call-like instructions!\n";);
4770 return indicatePessimisticFixpoint();
4771 }
4772
4773 // If we haven't used any assumed information for the reached parallel
4774 // region states, we can fix them.
4775 if (!UsedAssumedInformationInCheckCallInst &&
4776 AllParallelRegionStatesWereFixed) {
4777 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4778 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4779 }
4780
4781 // If we haven't used any assumed information for the SPMD state, we can
4782 // fix it.
4783 if (!UsedAssumedInformationInCheckRWInst &&
4784 !UsedAssumedInformationInCheckCallInst &&
4785 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4786 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4787
4788 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4789 : ChangeStatus::CHANGED;
4790 }
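  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): the reaching-kernel
  // check above reduces to a small decision. SPMD-ization of this function is
  // only possible if every kernel that may reach it executes in the same mode;
  // a mix of SPMD and generic kernels means instructions cannot be guarded
  // consistently. The helper below restates that over plain counters; its name
  // and signature are invented for illustration only.
  static bool reachingKernelsAgreeOnMode(unsigned SPMDCount,
                                         unsigned GenericCount) {
    // Mixed modes force SPMDCompatibilityTracker into a pessimistic fixpoint.
    return SPMDCount == 0 || GenericCount == 0;
  }
  // --------------------------------------------------------------------------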
4791
4792private:
4793 /// Update info regarding reaching kernels.
4794 void updateReachingKernelEntries(Attributor &A,
4795 bool &AllReachingKernelsKnown) {
4796 auto PredCallSite = [&](AbstractCallSite ACS) {
4797 Function *Caller = ACS.getInstruction()->getFunction();
4798
4799 assert(Caller && "Caller is nullptr");
4800
4801 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4802 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4803 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4804 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4805 return true;
4806 }
4807
4808 // We lost track of the callers of the associated function; any kernel
4809 // could reach it now.
4810 ReachingKernelEntries.indicatePessimisticFixpoint();
4811
4812 return true;
4813 };
4814
4815 if (!A.checkForAllCallSites(PredCallSite, *this,
4816 true /* RequireAllCallSites */,
4817 AllReachingKernelsKnown))
4818 ReachingKernelEntries.indicatePessimisticFixpoint();
4819 }
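  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): conceptually, the
  // update above unions the reaching-kernel sets of all callers into the set
  // of the associated function and gives up (pessimistic fixpoint) as soon as
  // one caller is unknown. Restated without the Attributor machinery, with an
  // invented helper name and the caller list supplied explicitly:
  static bool unionCallerKernels(
      Function *F, ArrayRef<Function *> Callers,
      DenseMap<Function *, SmallPtrSet<Function *, 8>> &ReachingKernels) {
    bool Changed = false;
    for (Function *Caller : Callers) {
      // Copy the caller's set so map growth cannot invalidate what we iterate.
      SmallPtrSet<Function *, 8> CallerKernels = ReachingKernels.lookup(Caller);
      for (Function *Kernel : CallerKernels)
        Changed |= ReachingKernels[F].insert(Kernel).second;
    }
    return Changed; // Rerun over all functions until nothing changes.
  }
  // --------------------------------------------------------------------------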
4820
4821 /// Update info regarding parallel levels.
4822 void updateParallelLevels(Attributor &A) {
4823 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4824 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4825 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4826
4827 auto PredCallSite = [&](AbstractCallSite ACS) {
4828 Function *Caller = ACS.getInstruction()->getFunction();
4829
4830 assert(Caller && "Caller is nullptr");
4831
4832 auto *CAA =
4833 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4834 if (CAA && CAA->ParallelLevels.isValidState()) {
4835 // Any function that is called by `__kmpc_parallel_51` will not be
4836 // folded because the parallel level inside that function is updated.
4837 // Getting this right would tie the analysis to the runtime implementation,
4838 // and any future change to that implementation could silently invalidate
4839 // the analysis. As a consequence, we are simply conservative here.
4840 if (Caller == Parallel51RFI.Declaration) {
4841 ParallelLevels.indicatePessimisticFixpoint();
4842 return true;
4843 }
4844
4845 ParallelLevels ^= CAA->ParallelLevels;
4846
4847 return true;
4848 }
4849
4850 // We lost track of the callers of the associated function; any kernel
4851 // could reach it now.
4852 ParallelLevels.indicatePessimisticFixpoint();
4853
4854 return true;
4855 };
4856
4857 bool AllCallSitesKnown = true;
4858 if (!A.checkForAllCallSites(PredCallSite, *this,
4859 true /* RequireAllCallSites */,
4860 AllCallSitesKnown))
4861 ParallelLevels.indicatePessimisticFixpoint();
4862 }
4863};
4864
4865/// The call site kernel info abstract attribute, basically, what can we say
4866/// about a call site with regards to the KernelInfoState. For now this simply
4867/// forwards the information from the callee.
4868struct AAKernelInfoCallSite : AAKernelInfo {
4869 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4870 : AAKernelInfo(IRP, A) {}
4871
4872 /// See AbstractAttribute::initialize(...).
4873 void initialize(Attributor &A) override {
4874 AAKernelInfo::initialize(A);
4875
4876 CallBase &CB = cast<CallBase>(getAssociatedValue());
4877 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4878 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4879
4880 // Check for SPMD-mode assumptions.
4881 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4882 indicateOptimisticFixpoint();
4883 return;
4884 }
4885
4886 // First weed out calls we do not care about, that is, readonly/readnone
4887 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4888 // parallel region or anything else we are looking for.
4889 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4890 indicateOptimisticFixpoint();
4891 return;
4892 }
4893
4894 // Next we check if we know the callee. If it is a known OpenMP function,
4895 // we will handle it explicitly in the switch below. If it is not, we
4896 // will use an AAKernelInfo object on the callee to gather information and
4897 // merge that into the current state. The latter happens in the updateImpl.
4898 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4899 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4900 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4901 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4902 // Unknown callees or declarations are not analyzable; we give up.
4903 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4904
4905 // Unknown callees might contain parallel regions, except if they have
4906 // an appropriate assumption attached.
4907 if (!AssumptionAA ||
4908 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4909 AssumptionAA->hasAssumption("omp_no_parallelism")))
4910 ReachedUnknownParallelRegions.insert(&CB);
4911
4912 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4913 // idea we can run something unknown in SPMD-mode.
4914 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4915 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4916 SPMDCompatibilityTracker.insert(&CB);
4917 }
4918
4919 // We have updated the state for this unknown call properly; there
4920 // won't be any change, so we indicate a fixpoint.
4921 indicateOptimisticFixpoint();
4922 }
4923 // If the callee is known and can be used in IPO, we will update the
4924 // state based on the callee state in updateImpl.
4925 return;
4926 }
4927 if (NumCallees > 1) {
4928 indicatePessimisticFixpoint();
4929 return;
4930 }
4931
4932 RuntimeFunction RF = It->getSecond();
4933 switch (RF) {
4934 // All the functions we know are compatible with SPMD mode.
4935 case OMPRTL___kmpc_is_spmd_exec_mode:
4936 case OMPRTL___kmpc_distribute_static_fini:
4937 case OMPRTL___kmpc_for_static_fini:
4938 case OMPRTL___kmpc_global_thread_num:
4939 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4940 case OMPRTL___kmpc_get_hardware_num_blocks:
4941 case OMPRTL___kmpc_single:
4942 case OMPRTL___kmpc_end_single:
4943 case OMPRTL___kmpc_master:
4944 case OMPRTL___kmpc_end_master:
4945 case OMPRTL___kmpc_barrier:
4946 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4947 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4948 case OMPRTL___kmpc_error:
4949 case OMPRTL___kmpc_flush:
4950 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4951 case OMPRTL___kmpc_get_warp_size:
4952 case OMPRTL_omp_get_thread_num:
4953 case OMPRTL_omp_get_num_threads:
4954 case OMPRTL_omp_get_max_threads:
4955 case OMPRTL_omp_in_parallel:
4956 case OMPRTL_omp_get_dynamic:
4957 case OMPRTL_omp_get_cancellation:
4958 case OMPRTL_omp_get_nested:
4959 case OMPRTL_omp_get_schedule:
4960 case OMPRTL_omp_get_thread_limit:
4961 case OMPRTL_omp_get_supported_active_levels:
4962 case OMPRTL_omp_get_max_active_levels:
4963 case OMPRTL_omp_get_level:
4964 case OMPRTL_omp_get_ancestor_thread_num:
4965 case OMPRTL_omp_get_team_size:
4966 case OMPRTL_omp_get_active_level:
4967 case OMPRTL_omp_in_final:
4968 case OMPRTL_omp_get_proc_bind:
4969 case OMPRTL_omp_get_num_places:
4970 case OMPRTL_omp_get_num_procs:
4971 case OMPRTL_omp_get_place_proc_ids:
4972 case OMPRTL_omp_get_place_num:
4973 case OMPRTL_omp_get_partition_num_places:
4974 case OMPRTL_omp_get_partition_place_nums:
4975 case OMPRTL_omp_get_wtime:
4976 break;
4977 case OMPRTL___kmpc_distribute_static_init_4:
4978 case OMPRTL___kmpc_distribute_static_init_4u:
4979 case OMPRTL___kmpc_distribute_static_init_8:
4980 case OMPRTL___kmpc_distribute_static_init_8u:
4981 case OMPRTL___kmpc_for_static_init_4:
4982 case OMPRTL___kmpc_for_static_init_4u:
4983 case OMPRTL___kmpc_for_static_init_8:
4984 case OMPRTL___kmpc_for_static_init_8u: {
4985 // Check the schedule and allow static schedule in SPMD mode.
4986 unsigned ScheduleArgOpNo = 2;
4987 auto *ScheduleTypeCI =
4988 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4989 unsigned ScheduleTypeVal =
4990 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4991 switch (OMPScheduleType(ScheduleTypeVal)) {
4992 case OMPScheduleType::UnorderedStatic:
4993 case OMPScheduleType::UnorderedStaticChunked:
4994 case OMPScheduleType::OrderedDistribute:
4995 case OMPScheduleType::OrderedDistributeChunked:
4996 break;
4997 default:
4998 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4999 SPMDCompatibilityTracker.insert(&CB);
5000 break;
5001 };
5002 } break;
5003 case OMPRTL___kmpc_target_init:
5004 KernelInitCB = &CB;
5005 break;
5006 case OMPRTL___kmpc_target_deinit:
5007 KernelDeinitCB = &CB;
5008 break;
5009 case OMPRTL___kmpc_parallel_51:
5010 if (!handleParallel51(A, CB))
5011 indicatePessimisticFixpoint();
5012 return;
5013 case OMPRTL___kmpc_omp_task:
5014 // We do not look into tasks right now, just give up.
5015 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5016 SPMDCompatibilityTracker.insert(&CB);
5017 ReachedUnknownParallelRegions.insert(&CB);
5018 break;
5019 case OMPRTL___kmpc_alloc_shared:
5020 case OMPRTL___kmpc_free_shared:
5021 // Return without setting a fixpoint, to be resolved in updateImpl.
5022 return;
5023 default:
5024 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5025 // generally. However, they do not hide parallel regions.
5026 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5027 SPMDCompatibilityTracker.insert(&CB);
5028 break;
5029 }
5030 // All other OpenMP runtime calls will not reach parallel regions, so they
5031 // can be safely ignored for now. Since it is a known OpenMP runtime call,
5032 // we have now modeled all effects and there is no need for any update.
5033 indicateOptimisticFixpoint();
5034 };
5035
5036 const auto *AACE =
5037 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5038 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5039 CheckCallee(getAssociatedFunction(), 1);
5040 return;
5041 }
5042 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5043 for (auto *Callee : OptimisticEdges) {
5044 CheckCallee(Callee, OptimisticEdges.size());
5045 if (isAtFixpoint())
5046 break;
5047 }
5048 }
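  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): the assumption
  // strings queried above ("ompx_spmd_amenable", "omp_no_openmp",
  // "omp_no_parallelism") come from user annotations on device functions. One
  // spelling Clang has accepted is shown below; the exact attribute syntax is
  // version dependent and the declarations are hypothetical:
  //
  //   // Promise: calling this helper does not block SPMD execution.
  //   [[omp::assume("ompx_spmd_amenable")]] void amenable_helper();
  //
  //   // Promise: this helper contains no nested parallelism.
  //   [[omp::assume("omp_no_parallelism")]] void sequential_helper();
  //
  // With such annotations the call site can be fixed optimistically instead of
  // being treated as a potential unknown parallel region.
  // --------------------------------------------------------------------------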
5049
5050 ChangeStatus updateImpl(Attributor &A) override {
5051 // TODO: Once we have call site specific value information, we can provide
5052 // call site specific liveness information and then it makes
5053 // sense to specialize attributes for call site arguments instead of
5054 // redirecting requests to the callee argument.
5055 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5056 KernelInfoState StateBefore = getState();
5057
5058 auto CheckCallee = [&](Function *F, int NumCallees) {
5059 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5060
5061 // If F is not a runtime function, propagate the AAKernelInfo of the
5062 // callee.
5063 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5064 const IRPosition &FnPos = IRPosition::function(*F);
5065 auto *FnAA =
5066 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5067 if (!FnAA)
5068 return indicatePessimisticFixpoint();
5069 if (getState() == FnAA->getState())
5070 return ChangeStatus::UNCHANGED;
5071 getState() = FnAA->getState();
5072 return ChangeStatus::CHANGED;
5073 }
5074 if (NumCallees > 1)
5075 return indicatePessimisticFixpoint();
5076
5077 CallBase &CB = cast<CallBase>(getAssociatedValue());
5078 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5079 if (!handleParallel51(A, CB))
5080 return indicatePessimisticFixpoint();
5081 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5082 : ChangeStatus::CHANGED;
5083 }
5084
5085 // F is a runtime function that allocates or frees memory, check
5086 // AAHeapToStack and AAHeapToShared.
5087 assert(
5088 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5089 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5090 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5091
5092 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5093 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5095 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5096
5097 RuntimeFunction RF = It->getSecond();
5098
5099 switch (RF) {
5100 // If neither HeapToStack nor HeapToShared assume the call is removed,
5101 // assume SPMD incompatibility.
5102 case OMPRTL___kmpc_alloc_shared:
5103 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5104 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5105 SPMDCompatibilityTracker.insert(&CB);
5106 break;
5107 case OMPRTL___kmpc_free_shared:
5108 if ((!HeapToStackAA ||
5109 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5110 (!HeapToSharedAA ||
5111 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5112 SPMDCompatibilityTracker.insert(&CB);
5113 break;
5114 default:
5115 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5116 SPMDCompatibilityTracker.insert(&CB);
5117 }
5118 return ChangeStatus::CHANGED;
5119 };
5120
5121 const auto *AACE =
5122 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5123 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5124 if (Function *F = getAssociatedFunction())
5125 CheckCallee(F, /*NumCallees=*/1);
5126 } else {
5127 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5128 for (auto *Callee : OptimisticEdges) {
5129 CheckCallee(Callee, OptimisticEdges.size());
5130 if (isAtFixpoint())
5131 break;
5132 }
5133 }
5134
5135 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5136 : ChangeStatus::CHANGED;
5137 }
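  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): __kmpc_alloc_shared
  // and __kmpc_free_shared are emitted for "globalized" locals, i.e. stack
  // variables of a generic-mode kernel that escape into a parallel region and
  // therefore must be visible to the whole team. Hypothetical user code that
  // produces this pattern:
  //
  //   void target_part(int *Out) {
  //     int X = 42;                       // globalized: X is shared below
  //   #pragma omp parallel
  //     Out[omp_get_thread_num()] = X;    // workers read X via the shared copy
  //   }
  //
  // The logic above only treats such calls as SPMD-blocking when neither
  // AAHeapToStack nor AAHeapToShared can prove the allocation goes away.
  // --------------------------------------------------------------------------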
5138
5139 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5140 /// handled; if a problem occurred, false is returned.
5141 bool handleParallel51(Attributor &A, CallBase &CB) {
5142 const unsigned int NonWrapperFunctionArgNo = 5;
5143 const unsigned int WrapperFunctionArgNo = 6;
5144 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5145 ? NonWrapperFunctionArgNo
5146 : WrapperFunctionArgNo;
5147
5148 auto *ParallelRegion = dyn_cast<Function>(
5149 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5150 if (!ParallelRegion)
5151 return false;
5152
5153 ReachedKnownParallelRegions.insert(&CB);
5154 /// Check nested parallelism
5155 auto *FnAA = A.getAAFor<AAKernelInfo>(
5156 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5157 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5158 !FnAA->ReachedKnownParallelRegions.empty() ||
5159 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5160 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5161 !FnAA->ReachedUnknownParallelRegions.empty();
5162 return true;
5163 }
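  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): the interesting
  // payload of a __kmpc_parallel_51 call is the outlined parallel region. As
  // above, argument 5 holds the region function used when running in SPMD
  // mode and argument 6 holds the wrapper the generic-mode state machine
  // invokes. A standalone helper with an invented name:
  static Function *getOutlinedParallelRegion(CallBase &CB, bool IsSPMD) {
    const unsigned ArgNo = IsSPMD ? /*region fn*/ 5 : /*wrapper fn*/ 6;
    return dyn_cast<Function>(CB.getArgOperand(ArgNo)->stripPointerCasts());
  }
  // --------------------------------------------------------------------------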
5164};
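// -----------------------------------------------------------------------------
// Illustrative sketch (not part of the original source): the static-init
// schedule check in AAKernelInfoCallSite::initialize above accepts exactly the
// unordered-static and distribute schedule kinds. Restated as a standalone
// predicate (the function name is invented):
static bool isSPMDCompatibleSchedule(OMPScheduleType ScheduleType) {
  switch (ScheduleType) {
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::OrderedDistribute:
  case OMPScheduleType::OrderedDistributeChunked:
    return true; // Statically divided iteration space: fine in SPMD mode.
  default:
    return false; // Anything else pessimizes the SPMD compatibility state.
  }
}
// -----------------------------------------------------------------------------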
5165
5166struct AAFoldRuntimeCall
5167 : public StateWrapper<BooleanState, AbstractAttribute> {
5168 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5169
5170 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5171
5172 /// Statistics are tracked as part of manifest for now.
5173 void trackStatistics() const override {}
5174
5175 /// Create an abstract attribute view for the position \p IRP.
5176 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5177 Attributor &A);
5178
5179 /// See AbstractAttribute::getName()
5180 StringRef getName() const override { return "AAFoldRuntimeCall"; }
5181
5182 /// See AbstractAttribute::getIdAddr()
5183 const char *getIdAddr() const override { return &ID; }
5184
5185 /// This function should return true if the type of the \p AA is
5186 /// AAFoldRuntimeCall
5187 static bool classof(const AbstractAttribute *AA) {
5188 return (AA->getIdAddr() == &ID);
5189 }
5190
5191 static const char ID;
5192};
5193
5194struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5195 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5196 : AAFoldRuntimeCall(IRP, A) {}
5197
5198 /// See AbstractAttribute::getAsStr()
5199 const std::string getAsStr(Attributor *) const override {
5200 if (!isValidState())
5201 return "<invalid>";
5202
5203 std::string Str("simplified value: ");
5204
5205 if (!SimplifiedValue)
5206 return Str + std::string("none");
5207
5208 if (!*SimplifiedValue)
5209 return Str + std::string("nullptr");
5210
5211 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5212 return Str + std::to_string(CI->getSExtValue());
5213
5214 return Str + std::string("unknown");
5215 }
5216
5217 void initialize(Attributor &A) override {
5218 if (DisableOpenMPOptFolding)
5219 indicatePessimisticFixpoint();
5220
5221 Function *Callee = getAssociatedFunction();
5222
5223 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5224 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5225 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5226 "Expected a known OpenMP runtime function");
5227
5228 RFKind = It->getSecond();
5229
5230 CallBase &CB = cast<CallBase>(getAssociatedValue());
5231 A.registerSimplificationCallback(
5232 IRPosition::callsite_returned(CB),
5233 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5234 bool &UsedAssumedInformation) -> std::optional<Value *> {
5235 assert((isValidState() || SimplifiedValue == nullptr) &&
5236 "Unexpected invalid state!");
5237
5238 if (!isAtFixpoint()) {
5239 UsedAssumedInformation = true;
5240 if (AA)
5241 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5242 }
5243 return SimplifiedValue;
5244 });
5245 }
5246
5247 ChangeStatus updateImpl(Attributor &A) override {
5248 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5249 switch (RFKind) {
5250 case OMPRTL___kmpc_is_spmd_exec_mode:
5251 Changed |= foldIsSPMDExecMode(A);
5252 break;
5253 case OMPRTL___kmpc_parallel_level:
5254 Changed |= foldParallelLevel(A);
5255 break;
5256 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5257 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5258 break;
5259 case OMPRTL___kmpc_get_hardware_num_blocks:
5260 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5261 break;
5262 default:
5263 llvm_unreachable("Unhandled OpenMP runtime function!");
5264 }
5265
5266 return Changed;
5267 }
5268
5269 ChangeStatus manifest(Attributor &A) override {
5270 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5271
5272 if (SimplifiedValue && *SimplifiedValue) {
5273 Instruction &I = *getCtxI();
5274 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5275 A.deleteAfterManifest(I);
5276
5277 CallBase *CB = dyn_cast<CallBase>(&I);
5278 auto Remark = [&](OptimizationRemark OR) {
5279 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5280 return OR << "Replacing OpenMP runtime call "
5281 << CB->getCalledFunction()->getName() << " with "
5282 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5283 return OR << "Replacing OpenMP runtime call "
5284 << CB->getCalledFunction()->getName() << ".";
5285 };
5286
5287 if (CB && EnableVerboseRemarks)
5288 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5289
5290 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5291 << **SimplifiedValue << "\n");
5292
5293 Changed = ChangeStatus::CHANGED;
5294 }
5295
5296 return Changed;
5297 }
5298
5299 ChangeStatus indicatePessimisticFixpoint() override {
5300 SimplifiedValue = nullptr;
5301 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5302 }
5303
5304private:
5305 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5306 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5307 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5308
5309 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5310 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5311 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5312 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5313
5314 if (!CallerKernelInfoAA ||
5315 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5316 return indicatePessimisticFixpoint();
5317
5318 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5319 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5320 DepClassTy::REQUIRED);
5321
5322 if (!AA || !AA->isValidState()) {
5323 SimplifiedValue = nullptr;
5324 return indicatePessimisticFixpoint();
5325 }
5326
5327 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5328 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5329 ++KnownSPMDCount;
5330 else
5331 ++AssumedSPMDCount;
5332 } else {
5333 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5334 ++KnownNonSPMDCount;
5335 else
5336 ++AssumedNonSPMDCount;
5337 }
5338 }
5339
5340 if ((AssumedSPMDCount + KnownSPMDCount) &&
5341 (AssumedNonSPMDCount + KnownNonSPMDCount))
5342 return indicatePessimisticFixpoint();
5343
5344 auto &Ctx = getAnchorValue().getContext();
5345 if (KnownSPMDCount || AssumedSPMDCount) {
5346 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5347 "Expected only SPMD kernels!");
5348 // All reaching kernels are in SPMD mode. Update all function calls to
5349 // __kmpc_is_spmd_exec_mode to 1.
5350 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5351 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5352 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5353 "Expected only non-SPMD kernels!");
5354 // All reaching kernels are in non-SPMD mode. Update all function
5355 // calls to __kmpc_is_spmd_exec_mode to 0.
5356 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5357 } else {
5358 // We have empty reaching kernels; therefore, we cannot tell whether the
5359 // associated call site can be folded. At this moment, SimplifiedValue
5360 // must be none.
5361 assert(!SimplifiedValue && "SimplifiedValue should be none");
5362 }
5363
5364 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5365 : ChangeStatus::CHANGED;
5366 }
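  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): the folding rule
  // above, stripped of the Attributor plumbing. Known and assumed kernels are
  // treated alike, and a mix of SPMD and generic reaching kernels (or none at
  // all) means no constant can be produced. Names are invented:
  static std::optional<bool> foldedIsSPMDValue(unsigned SPMDKernels,
                                               unsigned GenericKernels) {
    if (SPMDKernels && GenericKernels)
      return std::nullopt; // Disagreement: never foldable (pessimistic).
    if (SPMDKernels)
      return true;         // __kmpc_is_spmd_exec_mode folds to 1.
    if (GenericKernels)
      return false;        // Folds to 0.
    return std::nullopt;   // No reaching kernels known (yet).
  }
  // --------------------------------------------------------------------------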
5367
5368 /// Fold __kmpc_parallel_level into a constant if possible.
5369 ChangeStatus foldParallelLevel(Attributor &A) {
5370 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5371
5372 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5373 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5374
5375 if (!CallerKernelInfoAA ||
5376 !CallerKernelInfoAA->ParallelLevels.isValidState())
5377 return indicatePessimisticFixpoint();
5378
5379 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5380 return indicatePessimisticFixpoint();
5381
5382 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5383 assert(!SimplifiedValue &&
5384 "SimplifiedValue should keep none at this point");
5385 return ChangeStatus::UNCHANGED;
5386 }
5387
5388 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5389 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5390 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5391 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5392 DepClassTy::REQUIRED);
5393 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5394 return indicatePessimisticFixpoint();
5395
5396 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5397 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5398 ++KnownSPMDCount;
5399 else
5400 ++AssumedSPMDCount;
5401 } else {
5402 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5403 ++KnownNonSPMDCount;
5404 else
5405 ++AssumedNonSPMDCount;
5406 }
5407 }
5408
5409 if ((AssumedSPMDCount + KnownSPMDCount) &&
5410 (AssumedNonSPMDCount + KnownNonSPMDCount))
5411 return indicatePessimisticFixpoint();
5412
5413 auto &Ctx = getAnchorValue().getContext();
5414 // If the caller can only be reached by SPMD kernel entries, the parallel
5415 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5416 // kernel entries, it is 0.
5417 if (AssumedSPMDCount || KnownSPMDCount) {
5418 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5419 "Expected only SPMD kernels!");
5420 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5421 } else {
5422 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5423 "Expected only non-SPMD kernels!");
5424 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5425 }
5426 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5427 : ChangeStatus::CHANGED;
5428 }
5429
5430 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5431 // Specialize only if all the calls agree with the attribute constant value
5432 int32_t CurrentAttrValue = -1;
5433 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5434
5435 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5436 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5437
5438 if (!CallerKernelInfoAA ||
5439 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5440 return indicatePessimisticFixpoint();
5441
5442 // Iterate over the kernels that reach this function
5443 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5444 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5445
5446 if (NextAttrVal == -1 ||
5447 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5448 return indicatePessimisticFixpoint();
5449 CurrentAttrValue = NextAttrVal;
5450 }
5451
5452 if (CurrentAttrValue != -1) {
5453 auto &Ctx = getAnchorValue().getContext();
5454 SimplifiedValue =
5455 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5456 }
5457 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5458 : ChangeStatus::CHANGED;
5459 }
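  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original source): the walk above only
  // folds when every reaching kernel carries the same integer attribute value,
  // e.g. "omp_target_thread_limit" for the hardware-thread-count query. A
  // standalone version over an explicit kernel list (invented name):
  static std::optional<int32_t> commonKernelAttrValue(ArrayRef<Kernel> Kernels,
                                                      StringRef Attr) {
    std::optional<int32_t> Common;
    for (Kernel K : Kernels) {
      int32_t Val = K->getFnAttributeAsParsedInteger(Attr, -1);
      if (Val == -1 || (Common && *Common != Val))
        return std::nullopt; // Missing or conflicting values: cannot fold.
      Common = Val;
    }
    return Common;
  }
  // --------------------------------------------------------------------------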
5460
5461 /// An optional value the associated value is assumed to fold to. That is, we
5462 /// assume the associated value (which is a call) can be replaced by this
5463 /// simplified value.
5464 std::optional<Value *> SimplifiedValue;
5465
5466 /// The runtime function kind of the callee of the associated call site.
5467 RuntimeFunction RFKind;
5468};
5469
5470} // namespace
5471
5472/// Register folding callsite
5473void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5474 auto &RFI = OMPInfoCache.RFIs[RF];
5475 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5476 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5477 if (!CI)
5478 return false;
5479 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5480 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5481 DepClassTy::NONE, /* ForceUpdate */ false,
5482 /* UpdateAfterInit */ false);
5483 return false;
5484 });
5485}
5486
5487void OpenMPOpt::registerAAs(bool IsModulePass) {
5488 if (SCC.empty())
5489 return;
5490
5491 if (IsModulePass) {
5492 // Ensure we create the AAKernelInfo AAs first and without triggering an
5493 // update. This will make sure we register all value simplification
5494 // callbacks before any other AA has the chance to create an AAValueSimplify
5495 // or similar.
5496 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5497 A.getOrCreateAAFor<AAKernelInfo>(
5498 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5499 DepClassTy::NONE, /* ForceUpdate */ false,
5500 /* UpdateAfterInit */ false);
5501 return false;
5502 };
5503 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5504 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5505 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5506
5507 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5508 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5510 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5511 }
5512
5513 // Create CallSite AA for all Getters.
5514 if (DeduceICVValues) {
5515 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5516 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5517
5518 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5519
5520 auto CreateAA = [&](Use &U, Function &Caller) {
5521 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5522 if (!CI)
5523 return false;
5524
5525 auto &CB = cast<CallBase>(*CI);
5526
5527 IRPosition CBPos = IRPosition::callsite_function(CB);
5528 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5529 return false;
5530 };
5531
5532 GetterRFI.foreachUse(SCC, CreateAA);
5533 }
5534 }
5535
5536 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5537 // every function if there is a device kernel.
5538 if (!isOpenMPDevice(M))
5539 return;
5540
5541 for (auto *F : SCC) {
5542 if (F->isDeclaration())
5543 continue;
5544
5545 // We look at internal functions only on-demand, but if any use is not a
5546 // direct call or is outside the current set of analyzed functions, we have
5547 // to do it eagerly.
5548 if (F->hasLocalLinkage()) {
5549 if (llvm::all_of(F->uses(), [this](const Use &U) {
5550 const auto *CB = dyn_cast<CallBase>(U.getUser());
5551 return CB && CB->isCallee(&U) &&
5552 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5553 }))
5554 continue;
5555 }
5556 registerAAsForFunction(A, *F);
5557 }
5558}
5559
5560 void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5561 if (!DisableOpenMPOptDeglobalization)
5562 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5563 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5564 if (!DisableOpenMPOptDeglobalization)
5565 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5566 if (F.hasFnAttribute(Attribute::Convergent))
5567 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5568
5569 for (auto &I : instructions(F)) {
5570 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5571 bool UsedAssumedInformation = false;
5572 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5573 UsedAssumedInformation, AA::Interprocedural);
5574 A.getOrCreateAAFor<AAAddressSpace>(
5575 IRPosition::value(*LI->getPointerOperand()));
5576 continue;
5577 }
5578 if (auto *CI = dyn_cast<CallBase>(&I)) {
5579 if (CI->isIndirectCall())
5580 A.getOrCreateAAFor<AAIndirectCallInfo>(
5581 IRPosition::callsite_function(*CI));
5582 }
5583 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5584 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5585 A.getOrCreateAAFor<AAAddressSpace>(
5586 IRPosition::value(*SI->getPointerOperand()));
5587 continue;
5588 }
5589 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5590 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5591 continue;
5592 }
5593 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5594 if (II->getIntrinsicID() == Intrinsic::assume) {
5595 A.getOrCreateAAFor<AAPotentialValues>(
5596 IRPosition::value(*II->getArgOperand(0)));
5597 continue;
5598 }
5599 }
5600 }
5601}
5602
5603const char AAICVTracker::ID = 0;
5604const char AAKernelInfo::ID = 0;
5605const char AAExecutionDomain::ID = 0;
5606const char AAHeapToShared::ID = 0;
5607const char AAFoldRuntimeCall::ID = 0;
5608
5609AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5610 Attributor &A) {
5611 AAICVTracker *AA = nullptr;
5612 switch (IRP.getPositionKind()) {
5613 case IRPosition::IRP_INVALID:
5614 case IRPosition::IRP_FLOAT:
5615 case IRPosition::IRP_ARGUMENT:
5616 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5617 llvm_unreachable("ICVTracker can only be created for function position!");
5618 case IRPosition::IRP_RETURNED:
5619 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5620 break;
5621 case IRPosition::IRP_CALL_SITE_RETURNED:
5622 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5623 break;
5624 case IRPosition::IRP_CALL_SITE:
5625 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5626 break;
5627 case IRPosition::IRP_FUNCTION:
5628 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5629 break;
5630 }
5631
5632 return *AA;
5633}
5634
5635 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5636 Attributor &A) {
5637 AAExecutionDomainFunction *AA = nullptr;
5638 switch (IRP.getPositionKind()) {
5639 case IRPosition::IRP_INVALID:
5640 case IRPosition::IRP_FLOAT:
5641 case IRPosition::IRP_ARGUMENT:
5642 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5643 case IRPosition::IRP_RETURNED:
5644 case IRPosition::IRP_CALL_SITE_RETURNED:
5645 case IRPosition::IRP_CALL_SITE:
5646 llvm_unreachable(
5647 "AAExecutionDomain can only be created for function position!");
5648 case IRPosition::IRP_FUNCTION:
5649 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5650 break;
5651 }
5652
5653 return *AA;
5654}
5655
5656AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5657 Attributor &A) {
5658 AAHeapToSharedFunction *AA = nullptr;
5659 switch (IRP.getPositionKind()) {
5660 case IRPosition::IRP_INVALID:
5661 case IRPosition::IRP_FLOAT:
5662 case IRPosition::IRP_ARGUMENT:
5663 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5664 case IRPosition::IRP_RETURNED:
5665 case IRPosition::IRP_CALL_SITE_RETURNED:
5666 case IRPosition::IRP_CALL_SITE:
5667 llvm_unreachable(
5668 "AAHeapToShared can only be created for function position!");
5669 case IRPosition::IRP_FUNCTION:
5670 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5671 break;
5672 }
5673
5674 return *AA;
5675}
5676
5677AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5678 Attributor &A) {
5679 AAKernelInfo *AA = nullptr;
5680 switch (IRP.getPositionKind()) {
5681 case IRPosition::IRP_INVALID:
5682 case IRPosition::IRP_FLOAT:
5683 case IRPosition::IRP_ARGUMENT:
5684 case IRPosition::IRP_RETURNED:
5685 case IRPosition::IRP_CALL_SITE_RETURNED:
5686 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5687 llvm_unreachable("KernelInfo can only be created for function position!");
5688 case IRPosition::IRP_CALL_SITE:
5689 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5690 break;
5691 case IRPosition::IRP_FUNCTION:
5692 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5693 break;
5694 }
5695
5696 return *AA;
5697}
5698
5699AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5700 Attributor &A) {
5701 AAFoldRuntimeCall *AA = nullptr;
5702 switch (IRP.getPositionKind()) {
5703 case IRPosition::IRP_INVALID:
5704 case IRPosition::IRP_FLOAT:
5705 case IRPosition::IRP_ARGUMENT:
5706 case IRPosition::IRP_RETURNED:
5707 case IRPosition::IRP_FUNCTION:
5708 case IRPosition::IRP_CALL_SITE:
5709 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5710 llvm_unreachable("KernelInfo can only be created for call site position!");
5711 case IRPosition::IRP_CALL_SITE_RETURNED:
5712 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5713 break;
5714 }
5715
5716 return *AA;
5717}
5718
5719 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5720 if (!containsOpenMP(M))
5721 return PreservedAnalyses::all();
5722 if (DisableOpenMPOptimizations)
5723 return PreservedAnalyses::all();
5724
5725 FunctionAnalysisManager &FAM =
5726 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5727 KernelSet Kernels = getDeviceKernels(M);
5728
5729 if (PrintModuleBeforeOptimizations)
5730 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5731
5732 auto IsCalled = [&](Function &F) {
5733 if (Kernels.contains(&F))
5734 return true;
5735 return !F.use_empty();
5736 };
5737
5738 auto EmitRemark = [&](Function &F) {
5739 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5740 ORE.emit([&]() {
5741 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5742 return ORA << "Could not internalize function. "
5743 << "Some optimizations may not be possible. [OMP140]";
5744 });
5745 };
5746
5747 bool Changed = false;
5748
5749 // Create internal copies of each function if this is a kernel module. This
5750 // allows interprocedural passes to see every call edge.
5751 DenseMap<Function *, Function *> InternalizedMap;
5752 if (isOpenMPDevice(M)) {
5753 SmallPtrSet<Function *, 16> InternalizeFns;
5754 for (Function &F : M)
5755 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5756 !DisableInternalization) {
5757 if (Attributor::isInternalizable(F)) {
5758 InternalizeFns.insert(&F);
5759 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5760 EmitRemark(F);
5761 }
5762 }
5763
5764 Changed |=
5765 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5766 }
5767
5768 // Look at every function in the Module unless it was internalized.
5769 SetVector<Function *> Functions;
5770 SmallVector<Function *, 16> SCC;
5771 for (Function &F : M)
5772 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5773 SCC.push_back(&F);
5774 Functions.insert(&F);
5775 }
5776
5777 if (SCC.empty())
5778 return PreservedAnalyses::all();
5779
5780 AnalysisGetter AG(FAM);
5781
5782 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5783 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5784 };
5785
5786 BumpPtrAllocator Allocator;
5787 CallGraphUpdater CGUpdater;
5788
5789 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5790 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5791
5792 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5793
5794 unsigned MaxFixpointIterations =
5795 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5796
5797 AttributorConfig AC(CGUpdater);
5798 AC.DefaultInitializeLiveInternals = false;
5799 AC.IsModulePass = true;
5800 AC.RewriteSignatures = false;
5801 AC.MaxFixpointIterations = MaxFixpointIterations;
5802 AC.OREGetter = OREGetter;
5803 AC.PassName = DEBUG_TYPE;
5804 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5805 AC.IPOAmendableCB = [](const Function &F) {
5806 return F.hasFnAttribute("kernel");
5807 };
5808
5809 Attributor A(Functions, InfoCache, AC);
5810
5811 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5812 Changed |= OMPOpt.run(true);
5813
5814 // Optionally inline device functions for potentially better performance.
5815 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5816 for (Function &F : M)
5817 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5818 !F.hasFnAttribute(Attribute::NoInline))
5819 F.addFnAttr(Attribute::AlwaysInline);
5820
5821 if (PrintModuleAfterOptimizations)
5822 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5823
5824 if (Changed)
5825 return PreservedAnalyses::none();
5826
5827 return PreservedAnalyses::all();
5828}
5829
5830 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5831 CGSCCAnalysisManager &AM,
5832 LazyCallGraph &CG,
5833 CGSCCUpdateResult &UR) {
5834 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5835 return PreservedAnalyses::all();
5836 if (DisableOpenMPOptimizations)
5837 return PreservedAnalyses::all();
5838
5839 SmallVector<Function *, 16> SCC;
5840 // If there are kernels in the module, we have to run on all SCCs.
5840 // If there are kernels in the module, we have to run on all SCC's.
5841 for (LazyCallGraph::Node &N : C) {
5842 Function *Fn = &N.getFunction();
5843 SCC.push_back(Fn);
5844 }
5845
5846 if (SCC.empty())
5847 return PreservedAnalyses::all();
5848
5849 Module &M = *C.begin()->getFunction().getParent();
5850
5851 if (PrintModuleBeforeOptimizations)
5852 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5853
5854 FunctionAnalysisManager &FAM =
5855 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5856
5857 AnalysisGetter AG(FAM);
5858
5859 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5860 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5861 };
5862
5863 BumpPtrAllocator Allocator;
5864 CallGraphUpdater CGUpdater;
5865 CGUpdater.initialize(CG, C, AM, UR);
5866
5867 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5868 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5869
5870 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5871 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5872 /*CGSCC*/ &Functions, PostLink);
5873
5874 unsigned MaxFixpointIterations =
5875 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5876
5877 AttributorConfig AC(CGUpdater);
5878 AC.DefaultInitializeLiveInternals = false;
5879 AC.IsModulePass = false;
5880 AC.RewriteSignatures = false;
5881 AC.MaxFixpointIterations = MaxFixpointIterations;
5882 AC.OREGetter = OREGetter;
5883 AC.PassName = DEBUG_TYPE;
5884 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5885
5886 Attributor A(Functions, InfoCache, AC);
5887
5888 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5889 bool Changed = OMPOpt.run(false);
5890
5891 if (PrintModuleAfterOptimizations)
5892 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5893
5894 if (Changed)
5895 return PreservedAnalyses::none();
5896
5897 return PreservedAnalyses::all();
5898}
5899
5900 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5901 return Fn.hasFnAttribute("kernel");
5902}
5903
5904 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5905 KernelSet Kernels;
5906
5907 for (Function &F : M)
5908 if (F.hasKernelCallingConv()) {
5909 // We are only interested in OpenMP target regions. Others, such as
5910 // kernels generated by CUDA but linked together, are not interesting to
5911 // this pass.
5912 if (isOpenMPKernel(F)) {
5913 ++NumOpenMPTargetRegionKernels;
5914 Kernels.insert(&F);
5915 } else
5916 ++NumNonOpenMPTargetRegionKernels;
5917 }
5918
5919 return Kernels;
5920}
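// -----------------------------------------------------------------------------
// Illustrative sketch (not part of the original source): a typical consumer of
// getDeviceKernels just iterates the returned set, e.g. to collect per-kernel
// launch bounds. The helper name is invented and -1 is used as a "not set"
// marker, mirroring the attribute handling earlier in this file:
static void collectThreadLimits(Module &M,
                                DenseMap<Function *, int32_t> &Limits) {
  for (Function *K : getDeviceKernels(M))
    Limits[K] = K->getFnAttributeAsParsedInteger("omp_target_thread_limit", -1);
}
// -----------------------------------------------------------------------------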
5921
5923 Metadata *MD = M.getModuleFlag("openmp");
5924 if (!MD)
5925 return false;
5926
5927 return true;
5928}
5929
5931 Metadata *MD = M.getModuleFlag("openmp-device");
5932 if (!MD)
5933 return false;
5934
5935 return true;
5936}
@ Generic
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
This file defines the DenseSet and SmallDenseSet classes.
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
Loop::LoopBounds::Direction Direction
Definition LoopInfo.cpp:243
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
Machine Check Debug Module
This file provides utility analysis objects describing memory locations.
#define T
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition OpenMPOpt.cpp:67
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
std::pair< BasicBlock *, BasicBlock * > Edge
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition Value.cpp:480
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:167
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static const int BlockSize
Definition TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
iterator end()
Definition BasicBlock.h:472
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:475
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
LLVM_ABI BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:477
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
void setCallingConv(CallingConv::ID CC)
bool arg_empty() const
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
unsigned arg_size() const
AttributeList getAttributes() const
Return the attributes for this call.
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
bool isArgOperand(const Use *U) const
bool hasOperandBundles() const
Return true if this User has any operand bundles.
LLVM_ABI Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition InstrTypes.h:707
@ ICMP_NE
not equal
Definition InstrTypes.h:700
static LLVM_ABI Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
static LLVM_ABI Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition Constants.h:193
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition Constants.h:169
This is an important base class in LLVM.
Definition Constant.h:43
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:187
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:214
LLVM_ABI Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
A proxy from a FunctionAnalysisManager to an SCC.
const BasicBlock & getEntryBlock() const
Definition Function.h:807
const BasicBlock & front() const
Definition Function.h:858
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:359
Argument * getArg(unsigned i) const
Definition Function.h:884
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition Function.cpp:727
LLVM_ABI bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition Globals.cpp:316
bool hasLocalLinkage() const
Module * getParent()
Get the module that this global value is contained inside of...
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition GlobalValue.h:61
@ InternalLinkage
Rename collisions when linking (static functions).
Definition GlobalValue.h:60
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition Globals.cpp:511
LLVM_ABI bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
LLVM_ABI bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
LLVM_ABI bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
LLVM_ABI void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
LLVM_ABI void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
LLVM_ABI StringRef getName() const
Return the name of the corresponding LLVM basic block, or an empty string.
Root of the metadata hierarchy.
Definition Metadata.h:63
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
const Triple & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition Module.h:281
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition SetVector.h:59
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:104
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:279
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:168
size_type size() const
Definition SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
iterator begin() const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:269
Triple - Helper class for working with autoconf configuration names.
Definition Triple.h:47
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition User.cpp:21
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
Definition Value.cpp:390
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
iterator_range< user_iterator > users()
Definition Value.h:426
User * user_back()
Definition Value.h:412
LLVM_ABI const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition Value.cpp:701
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:134
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:359
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Abstract Attribute helper functions.
Definition Attributor.h:165
LLVM_ABI bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
LLVM_ABI bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
@ Interprocedural
Definition Attributor.h:184
LLVM_ABI bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
initializer< Ty > init(const Ty &Val)
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
InternalControlVar
IDs for all Internal Control Variables (ICVs).
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
SetVector< Kernel > KernelSet
Set of kernels in the module.
Definition OpenMPOpt.h:24
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition OpenMPOpt.h:21
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
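A minimal sketch (assumed usage, not the pass's actual logic) of how these helpers gate the optimizations:
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
using namespace llvm;
using namespace omp;
// Only bother with modules that contain OpenMP; on the device side,
// require at least one offloading kernel.
static bool worthOptimizing(Module &M) {
  if (!containsOpenMP(M))
    return false;
  if (!isOpenMPDevice(M))
    return true;
  KernelSet Kernels = getDeviceKernels(M);
  return !Kernels.empty();
}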
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
bool empty() const
Definition BasicBlock.h:101
iterator end() const
Definition BasicBlock.h:89
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:330
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1727
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1685
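An illustrative use of the range helpers above (the predicate is arbitrary):
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Function.h"
using namespace llvm;
// Check a property on every block except the entry block.
static bool nonEntryBlocksAreEmpty(Function &F) {
  if (F.empty())
    return true; // no blocks at all
  return all_of(drop_begin(F),
                [](const BasicBlock &BB) { return BB.empty(); });
}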
bool succ_empty(const Instruction *I)
Definition CFG.h:256
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
FunctionAddr VTableAddr uintptr_t uintptr_t Int32Ty
Definition InstrProf.h:296
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2113
constexpr from_range_t from_range
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
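A short sketch (hypothetical helper) of the decomposition this function performs:
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
using namespace llvm;
// True if Ptr was peeled down to a different base or a non-zero offset.
static bool splitsIntoBasePlusOffset(Value *Ptr, const DataLayout &DL) {
  int64_t Offset = 0;
  Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, DL);
  return Base != Ptr || Offset != 0;
}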
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
AnalysisManager< LazyCallGraph::SCC, LazyCallGraph & > CGSCCAnalysisManager
The CGSCC analysis manager.
@ ThinLTOPostLink
ThinLTO postlink (backend compile) phase.
Definition Pass.h:83
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
Definition Pass.h:87
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
Definition Pass.h:81
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:759
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
BumpPtrAllocatorImpl BumpPtrAllocator
The standard BumpPtrAllocator which just uses the default template parameters.
Definition Allocator.h:383
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition Error.h:769
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:1956
ArrayRef(const T &OneElt) -> ArrayRef< T >
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
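The standard casting idioms referenced above, shown on a call-site check (illustrative):
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/Casting.h"
using namespace llvm;
// dyn_cast yields nullptr on a type mismatch; cast asserts; isa only tests.
static Function *getDirectCallee(Value *V) {
  if (auto *CB = dyn_cast<CallBase>(V))
    return CB->getCalledFunction(); // still null for indirect calls
  return nullptr;
}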
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
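A hedged sketch of splitting a block at a given instruction while keeping the dominator tree up to date; splitAt is a hypothetical helper.
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
// The new block receives SplitPt and everything after it; DT is updated.
static BasicBlock *splitAt(Instruction *SplitPt, DominatorTree &DT) {
  return SplitBlock(SplitPt->getParent(), SplitPt->getIterator(), &DT);
}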
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition Attributor.h:496
LLVM_ABI Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
@ OPTIONAL
The target may be valid if the source is not.
Definition Attributor.h:508
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
#define N
static LLVM_ABI AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
AAExecutionDomain(const IRPosition &IRP, Attributor &A)
static LLVM_ABI const char ID
Unique ID (due to the unique address)
AccessKind
Simple enum to distinguish read/write/read-write accesses.
StateType::base_t MemoryLocationsKind
static LLVM_ABI bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
Base struct for all "concrete attribute" deductions.
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Wrapper for FunctionAnalysisManager.
Configuration for the Attributor.
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
const char * PassName
OptimizationRemarkGetter OREGetter
IPOAmendableCBTy IPOAmendableCB
bool IsModulePass
Is the user of the Attributor a module pass or not.
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
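A hedged sketch of populating the configuration fields listed above; the values are illustrative, not the pass's actual settings.
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
using namespace llvm;
// Illustrative configuration; real passes tune these knobs differently.
static AttributorConfig makeConfig(CallGraphUpdater &CGUpdater) {
  AttributorConfig AC(CGUpdater);
  AC.IsModulePass = true;                 // driven from a module pass
  AC.DefaultInitializeLiveInternals = false;
  AC.RewriteSignatures = false;           // keep function signatures stable
  AC.MaxFixpointIterations = 32;          // cap the fixpoint iteration count
  AC.PassName = "openmp-opt";             // used in remarks and statistics
  return AC;
}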
The fixpoint analysis framework that orchestrates the attribute deduction.
static LLVM_ABI bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Value * >( const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
std::function< std::optional< Constant * >( const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
static LLVM_ABI bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
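An illustrative driver (not this file's logic) for the two internalization helpers above:
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/Attributor.h"
using namespace llvm;
// Copy every internalizable function with internal linkage; FnMap records
// the original-to-copy mapping.
static bool internalizeEligible(Module &M) {
  SmallPtrSet<Function *, 16> FnSet;
  for (Function &F : M)
    if (Attributor::isInternalizable(F))
      FnSet.insert(&F);
  DenseMap<Function *, Function *> FnMap;
  return Attributor::internalizeFunctions(FnSet, FnMap);
}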
Simple wrapper for a single bit (boolean) state.
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition Attributor.h:593
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition Attributor.h:661
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition Attributor.h:643
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition Attributor.h:617
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition Attributor.h:629
@ IRP_ARGUMENT
An attribute for a function argument.
Definition Attributor.h:607
@ IRP_RETURNED
An attribute for the function return value.
Definition Attributor.h:603
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition Attributor.h:606
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition Attributor.h:604
@ IRP_FUNCTION
An attribute for a function (scope).
Definition Attributor.h:605
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition Attributor.h:601
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition Attributor.h:608
@ IRP_INVALID
An invalid position.
Definition Attributor.h:600
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition Attributor.h:636
Kind getPositionKind() const
Return the associated position kind.
Definition Attributor.h:889
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition Attributor.h:656
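A small sketch of the factory methods above; describePositions is a hypothetical helper.
#include "llvm/IR/InstrTypes.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include <cassert>
using namespace llvm;
// One call site can be described from several angles.
static void describePositions(CallBase &CB, Function &F) {
  const IRPosition CSFnPos = IRPosition::callsite_function(CB);
  const IRPosition CSRetPos = IRPosition::callsite_returned(CB);
  const IRPosition FnPos = IRPosition::function(F);
  const IRPosition RetPos = IRPosition::returned(F);
  const IRPosition ValPos = IRPosition::value(CB); // CB viewed as a Value
  assert(CSRetPos.getPositionKind() == IRPosition::IRP_CALL_SITE_RETURNED);
  (void)CSFnPos; (void)FnPos; (void)RetPos; (void)ValPos;
}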
Data structure to hold cached (LLVM-IR) information.
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...