LLVM 22.0.0git
SMEPeepholeOpt.cpp
Go to the documentation of this file.
1//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This pass tries to remove back-to-back (smstart, smstop) and
9// (smstop, smstart) sequences. The pass is conservative when it cannot
10// determine that it is safe to remove these sequences.
11//===----------------------------------------------------------------------===//
12
13#include "AArch64InstrInfo.h"
15#include "AArch64Subtarget.h"
21
22using namespace llvm;
23
24#define DEBUG_TYPE "aarch64-sme-peephole-opt"
25
26namespace {
27
28struct SMEPeepholeOpt : public MachineFunctionPass {
29 static char ID;
30
31 SMEPeepholeOpt() : MachineFunctionPass(ID) {}
32
33 bool runOnMachineFunction(MachineFunction &MF) override;
34
35 StringRef getPassName() const override {
36 return "SME Peephole Optimization pass";
37 }
38
39 void getAnalysisUsage(AnalysisUsage &AU) const override {
40 AU.setPreservesCFG();
42 }
43
44 bool optimizeStartStopPairs(MachineBasicBlock &MBB,
45 bool &HasRemovedAllSMChanges) const;
46 bool visitRegSequence(MachineInstr &MI);
47};
48
49char SMEPeepholeOpt::ID = 0;
50
51} // end anonymous namespace
52
54 return MI->getOpcode() == AArch64::MSRpstatePseudo;
55}
56
58 const MachineInstr *MI2) {
59 // We only consider the same type of streaming mode change here, i.e.
60 // start/stop SM, or start/stop ZA pairs.
61 if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
62 return false;
63
64 // One must be 'start', the other must be 'stop'
65 if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
66 return false;
67
68 bool IsConditional = isConditionalStartStop(MI2);
69 if (isConditionalStartStop(MI1) != IsConditional)
70 return false;
71
72 if (!IsConditional)
73 return true;
74
75 // Check to make sure the conditional start/stop pairs are identical.
76 if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
77 return false;
78
79 // Ensure reg masks are identical.
80 if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
81 return false;
82
83 // Only consider conditional start/stop pairs which read the same register
84 // holding the original value of pstate.sm. This is somewhat over conservative
85 // as all conditional streaming mode changes only look at the state on entry
86 // to the function.
87 if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
88 Register Reg1 = MI1->getOperand(3).getReg();
89 Register Reg2 = MI2->getOperand(3).getReg();
90 if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
91 return false;
92 }
93
94 return true;
95}
96
98 assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
99 MI->getOpcode() == AArch64::MSRpstatePseudo) &&
100 "Expected MI to be a smstart/smstop instruction");
101 return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
102 MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
103}
104
107 const MachineOperand &MO) {
108 if (!MO.isReg())
109 return false;
110
111 Register R = MO.getReg();
112 if (R.isPhysical())
113 return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
114 return AArch64::ZPRRegClass.contains(SR) ||
115 AArch64::PPRRegClass.contains(SR);
116 });
117
118 const TargetRegisterClass *RC = MRI.getRegClass(R);
119 return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
120 TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
121}
122
123bool SMEPeepholeOpt::optimizeStartStopPairs(
124 MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
126 const TargetRegisterInfo &TRI =
128
129 bool Changed = false;
130 MachineInstr *Prev = nullptr;
131
132 // Walk through instructions in the block trying to find pairs of smstart
133 // and smstop nodes that cancel each other out. We only permit a limited
134 // set of instructions to appear between them, otherwise we reset our
135 // tracking.
136 unsigned NumSMChanges = 0;
137 unsigned NumSMChangesRemoved = 0;
139 switch (MI.getOpcode()) {
140 case AArch64::MSRpstatesvcrImm1:
141 case AArch64::MSRpstatePseudo: {
143 NumSMChanges++;
144
145 if (!Prev)
146 Prev = &MI;
147 else if (isMatchingStartStopPair(Prev, &MI)) {
148 // If they match, we can remove them, and possibly any instructions
149 // that we marked for deletion in between.
150 Prev->eraseFromParent();
151 MI.eraseFromParent();
152 Prev = nullptr;
153 Changed = true;
154 NumSMChangesRemoved += 2;
155 } else {
156 Prev = &MI;
157 }
158 continue;
159 }
160 default:
161 if (!Prev)
162 // Avoid doing expensive checks when Prev is nullptr.
163 continue;
164 break;
165 }
166
167 // Test if the instructions in between the start/stop sequence are agnostic
168 // of streaming mode. If not, the algorithm should reset.
169 switch (MI.getOpcode()) {
170 default:
171 Prev = nullptr;
172 break;
173 case AArch64::COALESCER_BARRIER_FPR16:
174 case AArch64::COALESCER_BARRIER_FPR32:
175 case AArch64::COALESCER_BARRIER_FPR64:
176 case AArch64::COALESCER_BARRIER_FPR128:
177 case AArch64::COPY:
178 // These instructions should be safe when executed on their own, but
179 // the code remains conservative when SVE registers are used. There may
180 // exist subtle cases where executing a COPY in a different mode results
181 // in different behaviour, even if we can't yet come up with any
182 // concrete example/test-case.
183 if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
184 isSVERegOp(TRI, MRI, MI.getOperand(1)))
185 Prev = nullptr;
186 break;
187 case AArch64::ADJCALLSTACKDOWN:
188 case AArch64::ADJCALLSTACKUP:
189 case AArch64::ANDXri:
190 case AArch64::ADDXri:
191 // We permit these as they don't generate SVE/NEON instructions.
192 break;
193 case AArch64::MSRpstatesvcrImm1:
194 case AArch64::MSRpstatePseudo:
195 llvm_unreachable("Should have been handled");
196 }
197 }
198
199 HasRemovedAllSMChanges =
200 NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
201 return Changed;
202}
203
204// Using the FORM_TRANSPOSED_REG_TUPLE pseudo can improve register allocation
205// of multi-vector intrinsics. However, the pseudo should only be emitted if
206// the input registers of the REG_SEQUENCE are copy nodes where the source
207// register is in a StridedOrContiguous class. For example:
208//
209// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
210// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
211// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
212// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
213// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
214// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
215// %9:zpr2mul2 = REG_SEQUENCE %5:zpr, %subreg.zsub0, %8:zpr, %subreg.zsub1
216//
217// -> %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
218//
219bool SMEPeepholeOpt::visitRegSequence(MachineInstr &MI) {
220 assert(MI.getMF()->getRegInfo().isSSA() && "Expected to be run on SSA form!");
221
222 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
223 switch (MRI.getRegClass(MI.getOperand(0).getReg())->getID()) {
224 case AArch64::ZPR2RegClassID:
225 case AArch64::ZPR4RegClassID:
226 case AArch64::ZPR2Mul2RegClassID:
227 case AArch64::ZPR4Mul4RegClassID:
228 break;
229 default:
230 return false;
231 }
232
233 // The first operand is the register class created by the REG_SEQUENCE.
234 // Each operand pair after this consists of a vreg + subreg index, so
235 // for example a sequence of 2 registers will have a total of 5 operands.
236 if (MI.getNumOperands() != 5 && MI.getNumOperands() != 9)
237 return false;
238
240 for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
241 MachineOperand &MO = MI.getOperand(I);
242
243 MachineOperand *Def = MRI.getOneDef(MO.getReg());
244 if (!Def || !Def->getParent()->isCopy())
245 return false;
246
247 const MachineOperand &CopySrc = Def->getParent()->getOperand(1);
248 unsigned OpSubReg = CopySrc.getSubReg();
250 SubReg = OpSubReg;
251
252 MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
253 if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
254 CopySrcOp->getReg().isPhysical())
255 return false;
256
257 const TargetRegisterClass *CopySrcClass =
258 MRI.getRegClass(CopySrcOp->getReg());
259 if (CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
260 CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass)
261 return false;
262 }
263
264 unsigned Opc = MI.getNumOperands() == 5
265 ? AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO
266 : AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
267
268 const TargetInstrInfo *TII =
269 MI.getMF()->getSubtarget<AArch64Subtarget>().getInstrInfo();
270 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
271 TII->get(Opc), MI.getOperand(0).getReg());
272 for (unsigned I = 1; I < MI.getNumOperands(); I += 2)
273 MIB.addReg(MI.getOperand(I).getReg());
274
275 MI.eraseFromParent();
276 return true;
277}
278
279INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
280 "SME Peephole Optimization", false, false)
281
282bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
283 if (skipFunction(MF.getFunction()))
284 return false;
285
286 if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
287 return false;
288
289 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
290
291 bool Changed = false;
292 bool FunctionHasAllSMChangesRemoved = false;
293
294 // Even if the block lives in a function with no SME attributes attached we
295 // still have to analyze all the blocks because we may call a streaming
296 // function that requires smstart/smstop pairs.
297 for (MachineBasicBlock &MBB : MF) {
298 bool BlockHasAllSMChangesRemoved;
299 Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
300 FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
301
302 if (MF.getSubtarget<AArch64Subtarget>().isStreaming()) {
304 if (MI.getOpcode() == AArch64::REG_SEQUENCE)
305 Changed |= visitRegSequence(MI);
306 }
307 }
308
309 AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
310 if (FunctionHasAllSMChangesRemoved)
311 AFI->setHasStreamingModeChanges(false);
312
313 return Changed;
314}
315
316FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
unsigned SubReg
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
Register const TargetRegisterInfo * TRI
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:56
static bool isSVERegOp(const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, const MachineOperand &MO)
static bool isMatchingStartStopPair(const MachineInstr *MI1, const MachineInstr *MI2)
static bool isConditionalStartStop(const MachineInstr *MI)
static bool ChangesStreamingMode(const MachineInstr *MI)
This file defines the SmallVector class.
AArch64FunctionInfo - This class is derived from MachineFunctionInfo and contains private AArch64-specific information for each MachineFunction.
void setHasStreamingModeChanges(bool HasChanges)
bool isStreaming() const
Returns true if the function has a streaming body.
Represent the analysis usage information of a pass.
LLVM_ABI void setPreservesCFG()
This function should be called by the pass iff it does not: (1) add or remove basic blocks from the function, or (2) modify terminator instructions in any way.
Definition: Pass.cpp:270
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
Wrapper class representing physical registers. Should be passed by value.
Definition: MCRegister.h:33
static constexpr unsigned NoRegister
Definition: MCRegister.h:52
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of passes that operate on the MachineFunction representation.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformation or analysis.
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
const MachineInstrBuilder & addReg(Register RegNo, unsigned flags=0, unsigned SubReg=0) const
Add a new virtual register operand.
Representation of each machine instruction.
Definition: MachineInstr.h:72
LLVM_ABI void eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:595
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
const uint32_t * getRegMask() const
getRegMask - Returns a bit mask of registers preserved by this RegMask operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers, including vreg register classes, use/def chains for registers, etc.
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:85
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition: Register.h:78
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition: StringRef.h:55
TargetInstrInfo - Interface to description of machine instruction set.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDesc objects that represent all of the machine registers that the target has.
virtual const TargetRegisterInfo * getRegisterInfo() const =0
Return the target's register information.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
NodeAddr< DefNode * > Def
Definition: RDFGraph.h:384
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
Definition: STLExtras.h:663
FunctionPass * createSMEPeepholeOptPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1751