LLVM 22.0.0git
NVPTXForwardParams.cpp
Go to the documentation of this file.
1//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// PTX supports 2 methods of accessing device function parameters:
10//
11// - "simple" case: If a parameter is only loaded, and all loads can address
12// the parameter via a constant offset, then the parameter may be loaded via
13// the ".param" address space. This case is not possible if the parameter
14// is stored to or has its address taken. This method is preferable when
15// possible. Ex:
16//
17// ld.param.u32 %r1, [foo_param_1];
18// ld.param.u32 %r2, [foo_param_1+4];
19//
20// - "move param" case: For more complex cases the address of the param may be
21// placed in a register via a "mov" instruction. This "mov" also implicitly
22// moves the param to the ".local" address space and allows for it to be
23// written to. This essentially defers the responsibility of the byval copy
24// to the PTX calling convention.
25//
26// mov.b64 %rd1, foo_param_0;
27// st.local.u32 [%rd1], 42;
28// add.u64 %rd3, %rd1, %rd2;
29// ld.local.u32 %r2, [%rd3];
30//
31// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
32// parameters will use the "move param" case and the local address space. This
33// pass is responsible for switching to the "simple" case when possible, as it
34// is more efficient.
35//
36// We do this by simply traversing uses of the param "mov" instructions and
37// trivially checking if they are all loads.
38//
39//===----------------------------------------------------------------------===//
40
41#include "NVPTX.h"
49
50using namespace llvm;
51
55 switch (U.getOpcode()) {
56 case NVPTX::LD_i16:
57 case NVPTX::LD_i32:
58 case NVPTX::LD_i64:
59 case NVPTX::LDV_i16_v2:
60 case NVPTX::LDV_i16_v4:
61 case NVPTX::LDV_i32_v2:
62 case NVPTX::LDV_i32_v4:
63 case NVPTX::LDV_i64_v2:
64 case NVPTX::LDV_i64_v4: {
65 LoadInsts.push_back(&U);
66 return true;
67 }
68 case NVPTX::cvta_local:
69 case NVPTX::cvta_local_64:
70 case NVPTX::cvta_to_local:
71 case NVPTX::cvta_to_local_64: {
72 for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
73 if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
74 return false;
75
76 RemoveList.push_back(&U);
77 return true;
78 }
79 default:
80 return false;
81 }
82}
83
86 SmallVector<MachineInstr *, 16> MaybeRemoveList;
88
89 for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
90 if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
91 return false;
92
93 RemoveList.append(MaybeRemoveList);
94 RemoveList.push_back(&Mov);
95
96 const MachineOperand *ParamSymbol = Mov.uses().begin();
97 assert(ParamSymbol->isSymbol());
98
99 constexpr unsigned LDInstBasePtrOpIdx = 5;
100 constexpr unsigned LDInstAddrSpaceOpIdx = 2;
101 for (auto *LI : LoadInsts) {
102 (LI->uses().begin() + LDInstBasePtrOpIdx)
103 ->ChangeToES(ParamSymbol->getSymbolName());
104 (LI->uses().begin() + LDInstAddrSpaceOpIdx)
105 ->ChangeToImmediate(NVPTX::AddressSpace::Param);
106 }
107 return true;
108}
109
111 const auto &MRI = MF.getRegInfo();
112
113 bool Changed = false;
115 for (auto &MI : make_early_inc_range(*MF.begin()))
116 if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
117 MI.getOpcode() == NVPTX::MOV64_PARAM)
118 Changed |= eliminateMove(MI, MRI, RemoveList);
119
120 for (auto *MI : RemoveList)
121 MI->eraseFromParent();
122
123 return Changed;
124}
125
126/// ----------------------------------------------------------------------------
127/// Pass (Manager) Boilerplate
128/// ----------------------------------------------------------------------------
129
// The pass class lives in an anonymous namespace, so the type itself is local
// to this translation unit.
130namespace {
/// MachineFunctionPass wrapper for the param-forwarding transformation.
/// The struct only declares the pass-manager hooks; runOnMachineFunction is
/// defined out of line later in the file.
131struct NVPTXForwardParamsPass : public MachineFunctionPass {
// Static member whose address-bearing object is handed to the base-class
// constructor below; its definition appears after the namespace closes.
132 static char ID;
// Default constructor: forwards ID to MachineFunctionPass.
133 NVPTXForwardParamsPass() : MachineFunctionPass(ID) {}
134
// Per-function entry point required by MachineFunctionPass (pure virtual in
// the base class).
135 bool runOnMachineFunction(MachineFunction &MF) override;
136
// Analysis-usage hook.
// NOTE(review): listing line 138 is absent from this view, so the body shown
// here is incomplete. Overrides of getAnalysisUsage are expected to call the
// MachineFunctionPass base implementation -- presumably that call is the
// missing line; confirm against the real source.
137 void getAnalysisUsage(AnalysisUsage &AU) const override {
139 }
140};
141} // namespace
142
// Out-of-class definition of the pass's static ID member, initialised to 0.
143char NVPTXForwardParamsPass::ID = 0;
144
// INITIALIZE_PASS(passName, arg, name, cfg, analysis): here the argument
// string is "nvptx-forward-params", the display name is "NVPTX Forward
// Params", and both the cfg and analysis flags are false.
145INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
146 "NVPTX Forward Params", false, false)
147
/// Pass entry point for one machine function.
///
/// Delegates all work to forwardDeviceParams(MF) and returns its result
/// unchanged; that helper returns its `Changed` flag.
148bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
149 return forwardDeviceParams(MF);
150}
151
153 return new NVPTXForwardParamsPass();
154}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
IRTranslator LLVM IR MI
static bool forwardDeviceParams(MachineFunction &MF)
static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI, SmallVectorImpl< MachineInstr * > &RemoveList, SmallVectorImpl< MachineInstr * > &LoadInsts)
static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI, SmallVectorImpl< MachineInstr * > &RemoveList)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:56
This file defines the SmallVector class.
Represent the analysis usage information of a pass.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Representation of each machine instruction.
Definition: MachineInstr.h:72
mop_iterator operands_begin()
Definition: MachineInstr.h:687
mop_range uses()
Returns all operands which may be register uses.
Definition: MachineInstr.h:731
MachineOperand class - Representation of each machine instruction operand.
bool isSymbol() const
isSymbol - Tests if this is a MO_ExternalSymbol operand.
const char * getSymbolName() const
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:574
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:684
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
IteratorT begin() const
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:663
MachineFunctionPass * createNVPTXForwardParamsPass()