150#include "llvm/IR/IntrinsicsNVPTX.h"
160#define DEBUG_TYPE "nvptx-lower-args"
172 return "Lower pointer arguments of CUDA kernels";
180char NVPTXLowerArgsLegacyPass::ID = 1;
183 "Lower arguments (NVPTX)",
false,
false)
211 bool IsGridConstant) {
212 Instruction *
I = dyn_cast<Instruction>(OldUse->getUser());
213 assert(
I &&
"OldUse must be in an instruction");
222 auto CloneInstInParamAS = [HasCvtaParam,
223 IsGridConstant](
const IP &
I) ->
Value * {
224 if (
auto *LI = dyn_cast<LoadInst>(
I.OldInstruction)) {
225 LI->setOperand(0,
I.NewParam);
228 if (
auto *
GEP = dyn_cast<GetElementPtrInst>(
I.OldInstruction)) {
231 GEP->getSourceElementType(),
I.NewParam, Indices,
GEP->getName(),
233 NewGEP->setIsInBounds(
GEP->isInBounds());
236 if (
auto *BC = dyn_cast<BitCastInst>(
I.OldInstruction)) {
238 return BitCastInst::Create(BC->getOpcode(),
I.NewParam, NewBCType,
239 BC->getName(), BC->getIterator());
241 if (
auto *ASC = dyn_cast<AddrSpaceCastInst>(
I.OldInstruction)) {
247 if (
auto *
MI = dyn_cast<MemTransferInst>(
I.OldInstruction)) {
248 if (
MI->getRawSource() ==
I.OldUse->get()) {
254 ID,
MI->getRawDest(),
MI->getDestAlign(),
I.NewParam,
255 MI->getSourceAlign(),
MI->getLength(),
MI->isVolatile());
256 for (
unsigned I : {0, 1})
257 if (
uint64_t Bytes =
MI->getParamDereferenceableBytes(
I))
258 B->addDereferenceableParamAttr(
I, Bytes);
266 auto GetParamAddrCastToGeneric =
272 auto *ParamInGenericAS =
273 GetParamAddrCastToGeneric(
I.NewParam,
I.OldInstruction);
276 if (
auto *
PHI = dyn_cast<PHINode>(
I.OldInstruction)) {
278 if (V.get() ==
I.OldUse->get())
279 PHI->setIncomingValue(
Idx, ParamInGenericAS);
282 if (
auto *SI = dyn_cast<SelectInst>(
I.OldInstruction)) {
283 if (SI->getTrueValue() ==
I.OldUse->get())
284 SI->setTrueValue(ParamInGenericAS);
285 if (SI->getFalseValue() ==
I.OldUse->get())
286 SI->setFalseValue(ParamInGenericAS);
291 if (IsGridConstant) {
292 if (
auto *CI = dyn_cast<CallInst>(
I.OldInstruction)) {
293 I.OldUse->set(ParamInGenericAS);
296 if (
auto *SI = dyn_cast<StoreInst>(
I.OldInstruction)) {
298 if (SI->getValueOperand() ==
I.OldUse->get())
299 SI->setOperand(0, ParamInGenericAS);
302 if (
auto *PI = dyn_cast<PtrToIntInst>(
I.OldInstruction)) {
303 if (PI->getPointerOperand() ==
I.OldUse->get())
304 PI->setOperand(0, ParamInGenericAS);
315 while (!ItemsToConvert.
empty()) {
317 Value *NewInst = CloneInstInParamAS(
I);
319 if (NewInst && NewInst !=
I.OldInstruction) {
323 for (
Use &U :
I.OldInstruction->uses())
324 ItemsToConvert.
push_back({&U, cast<Instruction>(U.getUser()), NewInst});
326 InstructionsToDelete.
push_back(
I.OldInstruction);
338 I->eraseFromParent();
353 const Align NewArgAlign =
357 if (CurArgAlign >= NewArgAlign)
361 <<
" instead of " << CurArgAlign.
value() <<
" for " << *Arg
380 std::queue<LoadContext> Worklist;
381 Worklist.push({ArgInParamAS, 0});
383 while (!Worklist.empty()) {
384 LoadContext Ctx = Worklist.
front();
387 for (
User *CurUser : Ctx.InitialVal->users()) {
388 if (
auto *
I = dyn_cast<LoadInst>(CurUser))
390 else if (isa<BitCastInst>(CurUser) || isa<AddrSpaceCastInst>(CurUser))
391 Worklist.push({cast<Instruction>(CurUser), Ctx.Offset});
392 else if (
auto *
I = dyn_cast<GetElementPtrInst>(CurUser)) {
393 APInt OffsetAccumulated =
396 if (!
I->accumulateConstantOffset(
DL, OffsetAccumulated))
401 assert(
Offset != OffsetLimit &&
"Expect Offset less than UINT64_MAX");
403 Worklist.push({
I, Ctx.Offset +
Offset});
408 for (Load &CurLoad : Loads) {
409 Align NewLoadAlign(std::gcd(NewArgAlign.
value(), CurLoad.Offset));
410 Align CurLoadAlign = CurLoad.Inst->getAlign();
411 CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
422 &Arg, {}, Arg.
getName() +
".param");
439 ArgUseChecker(
const DataLayout &
DL,
bool IsGridConstant)
443 assert(
A.getType()->isPointerTy());
444 IntegerType *IntIdxTy = cast<IntegerType>(
DL.getIndexType(
A.getType()));
445 IsOffsetKnown =
false;
448 Conditionals.
clear();
457 while (!(Worklist.empty() || PI.isAborted())) {
458 UseToVisit ToVisit = Worklist.pop_back_val();
459 U = ToVisit.UseAndIsOffsetKnown.getPointer();
461 if (isa<PHINode>(
I) || isa<SelectInst>(
I))
467 LLVM_DEBUG(
dbgs() <<
"Argument pointer escaped: " << *PI.getEscapingInst()
469 else if (PI.isAborted())
470 LLVM_DEBUG(
dbgs() <<
"Pointer use needs a copy: " << *PI.getAbortingInst()
473 <<
" conditionals\n");
479 if (
U->get() ==
SI.getValueOperand())
480 return PI.setEscapedAndAborted(&SI);
484 return PI.setAborted(&SI);
490 return PI.setEscapedAndAborted(&ASC);
500 assert(isa<PHINode>(
I) || isa<SelectInst>(
I));
507 if (*U ==
II.getRawDest() && !IsGridConstant)
540 IRB.CreateMemCpy(AllocA, AllocA->
getAlign(), ArgInParam, AllocA->
getAlign(),
548 const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
555 ArgUseChecker AUC(
DL, IsGridConstant);
556 ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
557 bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
559 if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
567 for (
Use *U : UsesToUpdate)
572 cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
584 if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
585 LLVM_DEBUG(
dbgs() <<
"Using non-copy pointer to " << *Arg <<
"\n");
605 ParamSpaceArg->setOperand(0, Arg);
607 copyByValParam(*Func, *Arg);
621 InsertPt = ++cast<Instruction>(
Ptr)->getIterator();
622 assert(InsertPt != InsertPt->getParent()->end() &&
623 "We don't call this function with Ptr being a terminator.");
627 Ptr, PointerType::get(
Ptr->getContext(), AS),
Ptr->getName(), InsertPt);
629 Ptr->getName(), InsertPt);
631 Ptr->replaceAllUsesWith(PtrInGeneric);
647 auto HandleIntToPtr = [](
Value &V) {
648 if (
llvm::all_of(V.users(), [](
User *U) { return isa<IntToPtrInst>(U); })) {
650 for (
User *U : UsersToUpdate)
654 if (TM.getDrvInterface() == NVPTX::CUDA) {
658 if (
LoadInst *LI = dyn_cast<LoadInst>(&
I)) {
659 if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
661 if (
Argument *Arg = dyn_cast<Argument>(UO)) {
664 if (LI->getType()->isPointerTy())
676 LLVM_DEBUG(
dbgs() <<
"Lowering kernel args of " <<
F.getName() <<
"\n");
681 TM.getDrvInterface() == NVPTX::CUDA) {
690 LLVM_DEBUG(
dbgs() <<
"Lowering function args of " <<
F.getName() <<
"\n");
693 cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
707bool NVPTXLowerArgsLegacyPass::runOnFunction(
Function &
F) {
712 return new NVPTXLowerArgsLegacyPass();
716 LLVM_DEBUG(
dbgs() <<
"Creating a copy of byval args of " <<
F.getName()
718 bool Changed =
false;
723 copyByValParam(
F, Arg);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
NVPTX address space definition.
static bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F)
nvptx lower Lower arguments(NVPTX)"
static CallInst * createNVVMInternalAddrspaceWrap(IRBuilder<> &IRB, Argument &Arg)
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS, const NVPTXTargetLowering *TLI)
static bool copyFunctionByValArgs(Function &F)
static void markPointerAsAS(Value *Ptr, const unsigned AS)
nvptx lower Lower static false void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam, bool IsGridConstant)
static bool processFunction(Function &F, NVPTXTargetMachine &TM)
static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F)
static void markPointerAsGlobal(Value *Ptr)
static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file provides a collection of visitors which walk the (instruction) uses of a pointer.
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
This class represents a conversion between pointers from one address space to another.
unsigned getDestAddressSpace() const
Returns the address space of the result.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
LLVM_ABI std::optional< TypeSize > getAllocationSize(const DataLayout &DL) const
Get allocation size in bytes.
void setAlignment(Align Align)
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
This class represents an incoming formal argument to a Function.
LLVM_ABI void addAttr(Attribute::AttrKind Kind)
LLVM_ABI bool hasByValAttr() const
Return true if this argument has the byval attribute.
LLVM_ABI void removeAttr(Attribute::AttrKind Kind)
Remove attributes from an argument.
const Function * getParent() const
LLVM_ABI Type * getParamByValType() const
If this is a byval argument, return its type.
LLVM_ABI MaybeAlign getParamAlign() const
If this is a byval or inalloca argument, return its alignment.
static LLVM_ABI Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
iterator begin()
Instruction iterator methods.
InstListType::iterator iterator
Instruction iterators...
void addRetAttr(Attribute::AttrKind Kind)
Adds the attribute to the return value.
This class represents a function call, abstracting a target machine's calling convention.
A parsed version of the target data layout string in and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
const BasicBlock & getEntryBlock() const
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
LLVM_ABI CallInst * CreateMemTransferInst(Intrinsic::ID IntrID, Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, Value *Size, bool isVolatile=false, const AAMDNodes &AAInfo=AAMDNodes())
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
void visitPtrToIntInst(PtrToIntInst &I)
void visit(Iterator Start, Iterator End)
void visitPHINode(PHINode &I)
void visitAddrSpaceCastInst(AddrSpaceCastInst &I)
void visitMemTransferInst(MemTransferInst &I)
void visitMemSetInst(MemSetInst &I)
void visitSelectInst(SelectInst &I)
Class to represent integer types.
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
An instruction for reading from memory.
This class wraps the llvm.memset and llvm.memset.inline intrinsics.
This class wraps the llvm.memcpy/memmove intrinsics.
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL) const
getFunctionParamOptimizedAlign - since function arguments are passed via .param space,...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents a cast from a pointer to an integer.
A base class for visitors over the uses of a pointer value.
void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC)
void visitStoreInst(StoreInst &SI)
void visitPtrToIntInst(PtrToIntInst &I)
This class represents the LLVM 'select' instruction.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Class to represent struct types.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isPointerTy() const
True if this is an instance of PointerType.
bool isIntegerTy() const
True if this is an instance of IntegerType.
A Use represents the edge between a Value definition and its users.
void setOperand(unsigned i, Value *Val)
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
iterator_range< use_iterator > uses()
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
void enqueueUsers(Value &I)
Enqueue the users of this instruction in the visit worklist.
PtrInfo PI
The info collected about the pointer being visited thus far.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
FunctionPass * createNVPTXLowerArgsPass()
auto reverse(ContainerTy &&C)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool isParamGridConstant(const Argument &Arg)
bool isKernelFunction(const Function &F)
iterator_range< pointer_iterator< WrappedIteratorT > > make_pointer_range(RangeT &&Range)
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
This struct is a compact representation of a valid (non-zero power of two) alignment.
uint64_t value() const
This is a hole in the type system and should not be abused.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)