LLVM 22.0.0git
RISCVOptWInstrs.cpp
Go to the documentation of this file.
1//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass does some optimizations for *W instructions at the MI level.
10//
11// First it removes unneeded sext.w instructions. Either because the sign
12// extended bits aren't consumed or because the input was already sign extended
13// by an earlier instruction.
14//
15// Then:
// 1. Unless explicitly disabled or the target prefers instructions with a W
// suffix, it removes the W suffix from opw instructions whenever all users
// depend only on the lower word of the result of the instruction.
19// The cases handled are:
20// * addw because c.add has a larger register encoding than c.addw.
21// * addiw because it helps reduce test differences between RV32 and RV64
22// w/o being a pessimization.
23// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24// * slliw because c.slliw doesn't exist and c.slli does
25//
// 2. Otherwise, if explicitly enabled or the target prefers instructions with
// a W suffix, it adds the W suffix to the instruction whenever all users
// depend only on the lower word of the result of the instruction.
// (LWU is converted to LW under that condition regardless of the preference.)
29// The cases handled are:
30// * add/addi/sub/mul.
31// * slli with imm < 32.
32// * ld/lwu.
33//===---------------------------------------------------------------------===//
34
35#include "RISCV.h"
37#include "RISCVSubtarget.h"
38#include "llvm/ADT/SmallSet.h"
39#include "llvm/ADT/Statistic.h"
42
using namespace llvm;

// Debug-output tag for this pass (used by LLVM_DEBUG / -debug-only=).
#define DEBUG_TYPE "riscv-opt-w-instrs"
// Human-readable pass name, returned by getPassName() and used in
// INITIALIZE_PASS.
#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"

// Incremented once per sext.w erased by removeSExtWInstrs.
STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
// Incremented once per non-W instruction rewritten to its W form.
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");
// Incremented once per W instruction rewritten to its non-W form.
STATISTIC(NumTransformedToNonWInstrs,
          "Number of instructions transformed to non-W-ops");

// When set, the sext.w-removal step of this pass is skipped entirely.
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
                                         cl::desc("Disable removal of sext.w"),
                                         cl::init(false), cl::Hidden);
// When set, W instructions are never rewritten to their non-W forms.
static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
                                         cl::desc("Disable strip W suffix"),
                                         cl::init(false), cl::Hidden);
60
namespace {

// Machine-function pass (RV64 only) that first removes redundant sext.w
// instructions and then canonicalizes the W suffix on eligible instructions.
class RISCVOptWInstrs : public MachineFunctionPass {
public:
  static char ID;

  RISCVOptWInstrs() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  // Erases sext.w (ADDIW rd, rs1, 0) instructions that are provably
  // redundant. Returns true if anything changed.
  bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
  // Adds or strips the W suffix depending on subtarget preference.
  // Returns true if anything changed.
  bool canonicalizeWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                             const RISCVSubtarget &ST,
                             MachineRegisterInfo &MRI);

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only opcodes/registers are rewritten; no blocks or edges change.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
};

} // end anonymous namespace
85
char RISCVOptWInstrs::ID = 0;
// Register the pass; the trailing arguments are (cfg = false, analysis = false).
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)

// Factory used by the RISC-V target to add this pass to the pipeline.
FunctionPass *llvm::createRISCVOptWInstrsPass() {
  return new RISCVOptWInstrs();
}
93
95 unsigned Bits) {
96 const MachineInstr &MI = *UserOp.getParent();
97 unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
98
99 if (!MCOpcode)
100 return false;
101
102 const MCInstrDesc &MCID = MI.getDesc();
103 const uint64_t TSFlags = MCID.TSFlags;
104 if (!RISCVII::hasSEWOp(TSFlags))
105 return false;
106 assert(RISCVII::hasVLOp(TSFlags));
107 const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
108
109 if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
110 return false;
111
112 auto NumDemandedBits =
113 RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
114 return NumDemandedBits && Bits >= *NumDemandedBits;
115}
116
// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
//
// Walks forward through the def-use graph starting at OrigMI. Each worklist
// entry pairs an instruction with the number of its low result bits that are
// still known to be the only ones that matter. Users that pass low bits
// through (COPY, PHI, ADD, AND, ...) are followed transitively, possibly with
// an adjusted bit count (shifts, BREV8/ORC_B). Any user that is not
// understood, or that may read higher bits, makes the function return false.
// Only instructions with exactly one explicit def into a virtual register can
// be followed.
// TODO: handle multiple interdependent transformations
static bool hasAllNBitUsers(const MachineInstr &OrigMI,
                            const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI, unsigned OrigBits) {

  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

  Worklist.emplace_back(&OrigMI, OrigBits);

  while (!Worklist.empty()) {
    auto P = Worklist.pop_back_val();
    const MachineInstr *MI = P.first;
    unsigned Bits = P.second;

    // The same instruction may be revisited with a different bit count, so
    // the (instruction, bits) pair is the visited key.
    if (!Visited.insert(P).second)
      continue;

    // Only handle instructions with one def.
    if (MI->getNumExplicitDefs() != 1)
      return false;

    Register DestReg = MI->getOperand(0).getReg();
    if (!DestReg.isVirtual())
      return false;

    for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
      const MachineInstr *UserMI = UserOp.getParent();
      unsigned OpIdx = UserOp.getOperandNo();

      switch (UserMI->getOpcode()) {
      default:
        // Unknown scalar opcodes are rejected; RVV pseudos may still read
        // only the low bits of a scalar operand.
        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
          break;
        return false;

      // These read only bits 31:0 of their register inputs.
      case RISCV::ADDIW:
      case RISCV::ADDW:
      case RISCV::DIVUW:
      case RISCV::DIVW:
      case RISCV::MULW:
      case RISCV::REMUW:
      case RISCV::REMW:
      case RISCV::SLLW:
      case RISCV::SRAIW:
      case RISCV::SRAW:
      case RISCV::SRLIW:
      case RISCV::SRLW:
      case RISCV::SUBW:
      case RISCV::ROLW:
      case RISCV::RORW:
      case RISCV::RORIW:
      case RISCV::CLZW:
      case RISCV::CTZW:
      case RISCV::CPOPW:
      case RISCV::SLLI_UW:
      case RISCV::FMV_W_X:
      case RISCV::FCVT_H_W:
      case RISCV::FCVT_H_W_INX:
      case RISCV::FCVT_H_WU:
      case RISCV::FCVT_H_WU_INX:
      case RISCV::FCVT_S_W:
      case RISCV::FCVT_S_W_INX:
      case RISCV::FCVT_S_WU:
      case RISCV::FCVT_S_WU_INX:
      case RISCV::FCVT_D_W:
      case RISCV::FCVT_D_W_INX:
      case RISCV::FCVT_D_WU:
      case RISCV::FCVT_D_WU_INX:
        if (Bits >= 32)
          break;
        return false;

      // These read only the low byte.
      case RISCV::SEXT_B:
      case RISCV::PACKH:
        if (Bits >= 8)
          break;
        return false;
      // These read only the low halfword.
      case RISCV::SEXT_H:
      case RISCV::FMV_H_X:
      case RISCV::ZEXT_H_RV32:
      case RISCV::ZEXT_H_RV64:
      case RISCV::PACKW:
        if (Bits >= 16)
          break;
        return false;

      // PACK reads the low XLEN/2 bits of each operand.
      case RISCV::PACK:
        if (Bits >= (ST.getXLen() / 2))
          break;
        return false;

      case RISCV::SRLI: {
        // If we are shifting right by less than Bits, and users don't demand
        // any bits that were shifted into [Bits-1:0], then we can consider this
        // as an N-Bit user.
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits > ShAmt) {
          Worklist.emplace_back(UserMI, Bits - ShAmt);
          break;
        }
        return false;
      }

      // Left shifts: result bit i comes from input bit i - ShAmt. If Bits
      // already covers every input bit that survives the shift, we are done;
      // otherwise the shift's users must demand at most Bits + ShAmt bits.
      case RISCV::SLLI: {
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits >= (ST.getXLen() - ShAmt))
          break;
        Worklist.emplace_back(UserMI, Bits + ShAmt);
        break;
      }
      case RISCV::SLLIW: {
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits >= 32 - ShAmt)
          break;
        Worklist.emplace_back(UserMI, Bits + ShAmt);
        break;
      }

      // ANDI masks off every input bit at or above bit_width(Imm).
      case RISCV::ANDI: {
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width(Imm))
          break;
        Worklist.emplace_back(UserMI, Bits);
        break;
      }
      // ORI forces to 1 every bit at or above bit_width(~Imm).
      case RISCV::ORI: {
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
          break;
        Worklist.emplace_back(UserMI, Bits);
        break;
      }

      case RISCV::SLL:
      case RISCV::BSET:
      case RISCV::BCLR:
      case RISCV::BINV:
        // Operand 2 is the shift amount which uses log2(xlen) bits.
        if (OpIdx == 2) {
          if (Bits >= Log2_32(ST.getXLen()))
            break;
          return false;
        }
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::SRA:
      case RISCV::SRL:
      case RISCV::ROL:
      case RISCV::ROR:
        // Operand 2 is the shift amount which uses log2(xlen) bits. The
        // shifted operand (1) can move high bits down, so it is rejected.
        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
          break;
        return false;

      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Operand 1 is implicitly zero extended.
        if (OpIdx == 1 && Bits >= 32)
          break;
        Worklist.emplace_back(UserMI, Bits);
        break;

      // BEXTI reads a single bit at the immediate position.
      case RISCV::BEXTI:
        if (UserMI->getOperand(2).getImm() >= Bits)
          return false;
        break;

      case RISCV::SB:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 8)
          break;
        return false;
      case RISCV::SH:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 16)
          break;
        return false;
      case RISCV::SW:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 32)
          break;
        return false;

      // For these, lower word of output in these operations, depends only on
      // the lower word of input. So, we check all uses only read lower word.
      case RISCV::COPY:
      case RISCV::PHI:

      case RISCV::ADD:
      case RISCV::ADDI:
      case RISCV::AND:
      case RISCV::MUL:
      case RISCV::OR:
      case RISCV::SUB:
      case RISCV::XOR:
      case RISCV::XORI:

      case RISCV::ANDN:
      case RISCV::CLMUL:
      case RISCV::ORN:
      case RISCV::SH1ADD:
      case RISCV::SH2ADD:
      case RISCV::SH3ADD:
      case RISCV::XNOR:
      case RISCV::BSETI:
      case RISCV::BCLRI:
      case RISCV::BINVI:
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::BREV8:
      case RISCV::ORC_B:
        // BREV8 and ORC_B work on bytes. Round Bits down to the nearest byte.
        Worklist.emplace_back(UserMI, alignDown(Bits, 8));
        break;

      case RISCV::PseudoCCMOVGPR:
      case RISCV::PseudoCCMOVGPRNoX0:
        // Either operand 4 or operand 5 is returned by this instruction. If
        // only the lower word of the result is used, then only the lower word
        // of operand 4 and 5 is used.
        if (OpIdx != 4 && OpIdx != 5)
          return false;
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::CZERO_EQZ:
      case RISCV::CZERO_NEZ:
      case RISCV::VT_MASKC:
      case RISCV::VT_MASKCN:
        // Operand 1 is the value passed through (or zeroed); the condition
        // operand is rejected.
        if (OpIdx != 1)
          return false;
        Worklist.emplace_back(UserMI, Bits);
        break;
      }
    }
  }

  return true;
}
365
366static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
367 const MachineRegisterInfo &MRI) {
368 return hasAllNBitUsers(OrigMI, ST, MRI, 32);
369}
370
371// This function returns true if the machine instruction always outputs a value
372// where bits 63:32 match bit 31.
373static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) {
374 uint64_t TSFlags = MI.getDesc().TSFlags;
375
376 // Instructions that can be determined from opcode are marked in tablegen.
378 return true;
379
380 // Special cases that require checking operands.
381 switch (MI.getOpcode()) {
382 // shifting right sufficiently makes the value 32-bit sign-extended
383 case RISCV::SRAI:
384 return MI.getOperand(2).getImm() >= 32;
385 case RISCV::SRLI:
386 return MI.getOperand(2).getImm() > 32;
387 // The LI pattern ADDI rd, X0, imm is sign extended.
388 case RISCV::ADDI:
389 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
390 // An ANDI with an 11 bit immediate will zero bits 63:11.
391 case RISCV::ANDI:
392 return isUInt<11>(MI.getOperand(2).getImm());
393 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
394 case RISCV::ORI:
395 return !isUInt<11>(MI.getOperand(2).getImm());
396 // A bseti with X0 is sign extended if the immediate is less than 31.
397 case RISCV::BSETI:
398 return MI.getOperand(2).getImm() < 31 &&
399 MI.getOperand(1).getReg() == RISCV::X0;
400 // Copying from X0 produces zero.
401 case RISCV::COPY:
402 return MI.getOperand(1).getReg() == RISCV::X0;
403 // Ignore the scratch register destination.
404 case RISCV::PseudoAtomicLoadNand32:
405 return OpNo == 0;
406 case RISCV::PseudoVMV_X_S: {
407 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
408 int64_t Log2SEW = MI.getOperand(2).getImm();
409 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
410 return Log2SEW <= 5;
411 }
412 }
413
414 return false;
415}
416
417static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
420 SmallSet<Register, 4> Visited;
422
423 auto AddRegToWorkList = [&](Register SrcReg) {
424 if (!SrcReg.isVirtual())
425 return false;
426 Worklist.push_back(SrcReg);
427 return true;
428 };
429
430 if (!AddRegToWorkList(SrcReg))
431 return false;
432
433 while (!Worklist.empty()) {
434 Register Reg = Worklist.pop_back_val();
435
436 // If we already visited this register, we don't need to check it again.
437 if (!Visited.insert(Reg).second)
438 continue;
439
440 MachineInstr *MI = MRI.getVRegDef(Reg);
441 if (!MI)
442 continue;
443
444 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
445 assert(OpNo != -1 && "Couldn't find register");
446
447 // If this is a sign extending operation we don't need to look any further.
448 if (isSignExtendingOpW(*MI, OpNo))
449 continue;
450
451 // Is this an instruction that propagates sign extend?
452 switch (MI->getOpcode()) {
453 default:
454 // Unknown opcode, give up.
455 return false;
456 case RISCV::COPY: {
457 const MachineFunction *MF = MI->getMF();
458 const RISCVMachineFunctionInfo *RVFI =
460
461 // If this is the entry block and the register is livein, see if we know
462 // it is sign extended.
463 if (MI->getParent() == &MF->front()) {
464 Register VReg = MI->getOperand(0).getReg();
465 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
466 continue;
467 }
468
469 Register CopySrcReg = MI->getOperand(1).getReg();
470 if (CopySrcReg == RISCV::X10) {
471 // For a method return value, we check the ZExt/SExt flags in attribute.
472 // We assume the following code sequence for method call.
473 // PseudoCALL @bar, ...
474 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
475 // %0:gpr = COPY $x10
476 //
477 // We use the PseudoCall to look up the IR function being called to find
478 // its return attributes.
479 const MachineBasicBlock *MBB = MI->getParent();
480 auto II = MI->getIterator();
481 if (II == MBB->instr_begin() ||
482 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
483 return false;
484
485 const MachineInstr &CallMI = *(--II);
486 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
487 return false;
488
489 auto *CalleeFn =
490 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
491 if (!CalleeFn)
492 return false;
493
494 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
495 if (!IntTy)
496 return false;
497
498 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
499 unsigned BitWidth = IntTy->getBitWidth();
500 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
501 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
502 continue;
503 }
504
505 if (!AddRegToWorkList(CopySrcReg))
506 return false;
507
508 break;
509 }
510
511 // For these, we just need to check if the 1st operand is sign extended.
512 case RISCV::BCLRI:
513 case RISCV::BINVI:
514 case RISCV::BSETI:
515 if (MI->getOperand(2).getImm() >= 31)
516 return false;
517 [[fallthrough]];
518 case RISCV::REM:
519 case RISCV::ANDI:
520 case RISCV::ORI:
521 case RISCV::XORI:
522 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
523 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
524 // Logical operations use a sign extended 12-bit immediate.
525 if (!AddRegToWorkList(MI->getOperand(1).getReg()))
526 return false;
527
528 break;
529 case RISCV::PseudoCCADDW:
530 case RISCV::PseudoCCADDIW:
531 case RISCV::PseudoCCSUBW:
532 case RISCV::PseudoCCSLLW:
533 case RISCV::PseudoCCSRLW:
534 case RISCV::PseudoCCSRAW:
535 case RISCV::PseudoCCSLLIW:
536 case RISCV::PseudoCCSRLIW:
537 case RISCV::PseudoCCSRAIW:
538 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
539 // need to check if operand 4 is sign extended.
540 if (!AddRegToWorkList(MI->getOperand(4).getReg()))
541 return false;
542 break;
543 case RISCV::REMU:
544 case RISCV::AND:
545 case RISCV::OR:
546 case RISCV::XOR:
547 case RISCV::ANDN:
548 case RISCV::ORN:
549 case RISCV::XNOR:
550 case RISCV::MAX:
551 case RISCV::MAXU:
552 case RISCV::MIN:
553 case RISCV::MINU:
554 case RISCV::PseudoCCMOVGPR:
555 case RISCV::PseudoCCMOVGPRNoX0:
556 case RISCV::PseudoCCAND:
557 case RISCV::PseudoCCOR:
558 case RISCV::PseudoCCXOR:
559 case RISCV::PHI: {
560 // If all incoming values are sign-extended, the output of AND, OR, XOR,
561 // MIN, MAX, or PHI is also sign-extended.
562
563 // The input registers for PHI are operand 1, 3, ...
564 // The input registers for PseudoCCMOVGPR(NoX0) are 4 and 5.
565 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
566 // The input registers for others are operand 1 and 2.
567 unsigned B = 1, E = 3, D = 1;
568 switch (MI->getOpcode()) {
569 case RISCV::PHI:
570 E = MI->getNumOperands();
571 D = 2;
572 break;
573 case RISCV::PseudoCCMOVGPR:
574 case RISCV::PseudoCCMOVGPRNoX0:
575 B = 4;
576 E = 6;
577 break;
578 case RISCV::PseudoCCAND:
579 case RISCV::PseudoCCOR:
580 case RISCV::PseudoCCXOR:
581 B = 4;
582 E = 7;
583 break;
584 }
585
586 for (unsigned I = B; I != E; I += D) {
587 if (!MI->getOperand(I).isReg())
588 return false;
589
590 if (!AddRegToWorkList(MI->getOperand(I).getReg()))
591 return false;
592 }
593
594 break;
595 }
596
597 case RISCV::CZERO_EQZ:
598 case RISCV::CZERO_NEZ:
599 case RISCV::VT_MASKC:
600 case RISCV::VT_MASKCN:
601 // Instructions return zero or operand 1. Result is sign extended if
602 // operand 1 is sign extended.
603 if (!AddRegToWorkList(MI->getOperand(1).getReg()))
604 return false;
605 break;
606
607 case RISCV::ADDI: {
608 if (MI->getOperand(1).isReg() && MI->getOperand(1).getReg().isVirtual()) {
609 if (MachineInstr *SrcMI = MRI.getVRegDef(MI->getOperand(1).getReg())) {
610 if (SrcMI->getOpcode() == RISCV::LUI &&
611 SrcMI->getOperand(1).isImm()) {
612 uint64_t Imm = SrcMI->getOperand(1).getImm();
613 Imm = SignExtend64<32>(Imm << 12);
614 Imm += (uint64_t)MI->getOperand(2).getImm();
615 if (isInt<32>(Imm))
616 continue;
617 }
618 }
619 }
620
621 if (hasAllWUsers(*MI, ST, MRI)) {
622 FixableDef.insert(MI);
623 break;
624 }
625 return false;
626 }
627
628 // With these opcode, we can "fix" them with the W-version
629 // if we know all users of the result only rely on bits 31:0
630 case RISCV::SLLI:
631 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
632 if (MI->getOperand(2).getImm() >= 32)
633 return false;
634 [[fallthrough]];
635 case RISCV::ADD:
636 case RISCV::LD:
637 case RISCV::LWU:
638 case RISCV::MUL:
639 case RISCV::SUB:
640 if (hasAllWUsers(*MI, ST, MRI)) {
641 FixableDef.insert(MI);
642 break;
643 }
644 return false;
645 }
646 }
647
648 // If we get here, then every node we visited produces a sign extended value
649 // or propagated sign extended values. So the result must be sign extended.
650 return true;
651}
652
653static unsigned getWOp(unsigned Opcode) {
654 switch (Opcode) {
655 case RISCV::ADDI:
656 return RISCV::ADDIW;
657 case RISCV::ADD:
658 return RISCV::ADDW;
659 case RISCV::LD:
660 case RISCV::LWU:
661 return RISCV::LW;
662 case RISCV::MUL:
663 return RISCV::MULW;
664 case RISCV::SLLI:
665 return RISCV::SLLIW;
666 case RISCV::SUB:
667 return RISCV::SUBW;
668 default:
669 llvm_unreachable("Unexpected opcode for replacement with W variant");
670 }
671}
672
673bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
674 const RISCVInstrInfo &TII,
675 const RISCVSubtarget &ST,
678 return false;
679
680 bool MadeChange = false;
681 for (MachineBasicBlock &MBB : MF) {
683 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
684 if (!RISCVInstrInfo::isSEXT_W(MI))
685 continue;
686
687 Register SrcReg = MI.getOperand(1).getReg();
688
690
691 // If all users only use the lower bits, this sext.w is redundant.
692 // Or if all definitions reaching MI sign-extend their output,
693 // then sext.w is redundant.
694 if (!hasAllWUsers(MI, ST, MRI) &&
695 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
696 continue;
697
698 Register DstReg = MI.getOperand(0).getReg();
699 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
700 continue;
701
702 // Convert Fixable instructions to their W versions.
703 for (MachineInstr *Fixable : FixableDefs) {
704 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
705 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
706 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
707 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
708 Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
709 LLVM_DEBUG(dbgs() << " with " << *Fixable);
710 ++NumTransformedToWInstrs;
711 }
712
713 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
714 MRI.replaceRegWith(DstReg, SrcReg);
715 MRI.clearKillFlags(SrcReg);
716 MI.eraseFromParent();
717 ++NumRemovedSExtW;
718 MadeChange = true;
719 }
720 }
721
722 return MadeChange;
723}
724
// Strips or adds W suffixes to eligible instructions depending on the
// subtarget preferences.
//
// Stripping (ADDW/ADDIW/MULW/SLLIW/SUBW -> non-W) is done unless
// DisableStripWSuffix is set or the subtarget prefers W instructions.
// Adding (ADD/ADDI/SUB/MUL/SLLI/LD -> W) is done when the subtarget prefers W
// instructions. LWU -> LW is attempted regardless of preference. Every rewrite
// requires that all users read only bits 31:0 of the result. Returns true if
// any instruction was changed.
bool RISCVOptWInstrs::canonicalizeWSuffixes(MachineFunction &MF,
                                            const RISCVInstrInfo &TII,
                                            const RISCVSubtarget &ST,
                                            MachineRegisterInfo &MRI) {
  bool ShouldStripW = !(DisableStripWSuffix || ST.preferWInst());
  bool ShouldPreferW = ST.preferWInst();
  bool MadeChange = false;

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      // At most one of these is set by the switch below: the replacement
      // opcode for adding (WOpc) or stripping (NonWOpc) the W suffix.
      std::optional<unsigned> WOpc;
      std::optional<unsigned> NonWOpc;
      unsigned OrigOpc = MI.getOpcode();
      switch (OrigOpc) {
      default:
        continue;
      case RISCV::ADDW:
        NonWOpc = RISCV::ADD;
        break;
      case RISCV::ADDIW:
        NonWOpc = RISCV::ADDI;
        break;
      case RISCV::MULW:
        NonWOpc = RISCV::MUL;
        break;
      case RISCV::SLLIW:
        NonWOpc = RISCV::SLLI;
        break;
      case RISCV::SUBW:
        NonWOpc = RISCV::SUB;
        break;
      case RISCV::ADD:
        WOpc = RISCV::ADDW;
        break;
      case RISCV::ADDI:
        WOpc = RISCV::ADDIW;
        break;
      case RISCV::SUB:
        WOpc = RISCV::SUBW;
        break;
      case RISCV::MUL:
        WOpc = RISCV::MULW;
        break;
      case RISCV::SLLI:
        // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits.
        if (MI.getOperand(2).getImm() >= 32)
          continue;
        WOpc = RISCV::SLLIW;
        break;
      case RISCV::LD:
      case RISCV::LWU:
        WOpc = RISCV::LW;
        break;
      }

      if (ShouldStripW && NonWOpc.has_value() && hasAllWUsers(MI, ST, MRI)) {
        LLVM_DEBUG(dbgs() << "Replacing " << MI);
        MI.setDesc(TII.get(NonWOpc.value()));
        LLVM_DEBUG(dbgs() << " with " << MI);
        ++NumTransformedToNonWInstrs;
        MadeChange = true;
        continue;
      }
      // LWU is always converted to LW when possible as 1) LW is compressible
      // and 2) it helps minimise differences vs RV32.
      if ((ShouldPreferW || OrigOpc == RISCV::LWU) && WOpc.has_value() &&
          hasAllWUsers(MI, ST, MRI)) {
        LLVM_DEBUG(dbgs() << "Replacing " << MI);
        MI.setDesc(TII.get(WOpc.value()));
        // Drop nsw/nuw/exact; presumably they described the original 64-bit
        // operation and need not hold for the W form.
        MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
        MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
        MI.clearFlag(MachineInstr::MIFlag::IsExact);
        LLVM_DEBUG(dbgs() << " with " << MI);
        ++NumTransformedToWInstrs;
        MadeChange = true;
        continue;
      }
    }
  }
  return MadeChange;
}
808
809bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
810 if (skipFunction(MF.getFunction()))
811 return false;
812
815 const RISCVInstrInfo &TII = *ST.getInstrInfo();
816
817 if (!ST.is64Bit())
818 return false;
819
820 bool MadeChange = false;
821 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
822 MadeChange |= canonicalizeWSuffixes(MF, TII, ST, MRI);
823 return MadeChange;
824}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
MachineInstr unsigned OpIdx
uint64_t IntrinsicInst * II
#define P(N)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:56
static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, const MachineRegisterInfo &MRI, SmallPtrSetImpl< MachineInstr * > &FixableDef)
static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, const MachineRegisterInfo &MRI)
static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo)
static cl::opt< bool > DisableStripWSuffix("riscv-disable-strip-w-suffix", cl::desc("Disable strip W suffix"), cl::init(false), cl::Hidden)
static bool hasAllNBitUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, const MachineRegisterInfo &MRI, unsigned OrigBits)
#define RISCV_OPT_W_INSTRS_NAME
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, unsigned Bits)
static cl::opt< bool > DisableSExtWRemoval("riscv-disable-sextw-removal", cl::desc("Disable removal of sext.w"), cl::init(false), cl::Hidden)
#define DEBUG_TYPE
static unsigned getWOp(unsigned Opcode)
This file defines the SmallSet class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
#define LLVM_DEBUG(...)
Definition: Debug.h:119
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:247
Represent the analysis usage information of a pass.
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:270
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
Describe properties that are true of each instruction in the target description file.
Definition: MCInstrDesc.h:199
instr_iterator instr_begin()
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineBasicBlock & front() const
Representation of each machine instruction.
Definition: MachineInstr.h:72
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
Definition: MachineInstr.h:587
const MachineBasicBlock * getParent() const
Definition: MachineInstr.h:359
bool isCall(QueryType Type=AnyInBundle) const
Definition: MachineInstr.h:948
unsigned getOperandNo(const_mop_iterator I) const
Returns the number of the operand iterator I points to.
Definition: MachineInstr.h:773
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:595
MachineOperand class - Representation of each machine instruction operand.
LLVM_ABI unsigned getOperandNo() const
Returns the index of this operand in the instruction that it belongs to.
const GlobalValue * getGlobal() const
int64_t getImm() const
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
bool isGlobal() const
isGlobal - Tests if this is a MO_GlobalAddress operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI bool isLiveIn(Register Reg) const
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:85
RISCVMachineFunctionInfo - This class is derived from MachineFunctionInfo and contains private RISCV-...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
Definition: Register.h:74
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:380
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:401
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:541
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:134
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:182
bool empty() const
Definition: SmallVector.h:82
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:938
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:55
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
static unsigned getVLOpNum(const MCInstrDesc &Desc)
static bool hasVLOp(uint64_t TSFlags)
static unsigned getSEWOpNum(const MCInstrDesc &Desc)
static bool hasSEWOp(uint64_t TSFlags)
unsigned getRVVMCOpcode(unsigned RVVPseudoOpcode)
std::optional< unsigned > getVectorLowDemandedScalarBits(unsigned Opcode, unsigned Log2SEW)
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:444
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
int bit_width(T Value)
Returns the number of bits needed to represent Value if Value is nonzero.
Definition: bit.h:270
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:663
constexpr T alignDown(U Value, V Align, W Skew=0)
Returns the largest unsigned integer less than or equal to Value and is Skew mod Align.
Definition: MathExtras.h:551
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition: MathExtras.h:336
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:207
FunctionPass * createRISCVOptWInstrsPass()
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:223