//===-- AMDGPURegBankLegalizeRules.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Definitions of RegBankLegalize Rules for all opcodes.
/// Implementation of container for all the Rules and search.
/// Fast search for most common case when Rule.Predicate checks LLT and
/// uniformity of register in operand 0.
//
//===----------------------------------------------------------------------===//

#include "AMDGPURegBankLegalizeRules.h"
#include "AMDGPUInstrInfo.h"
#include "GCNSubtarget.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/Support/AMDGPUAddrSpace.h"

#define DEBUG_TYPE "amdgpu-regbanklegalize"

using namespace llvm;
using namespace AMDGPU;

bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
  return Ty.isPointer() && Ty.getSizeInBits() == Width;
}

RegBankLLTMapping::RegBankLLTMapping(
    std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
    std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
    LoweringMethodID LoweringMethod)
    : DstOpMapping(DstOpMappingList), SrcOpMapping(SrcOpMappingList),
      LoweringMethod(LoweringMethod) {}

PredicateMapping::PredicateMapping(
    std::initializer_list<UniformityLLTOpPredicateID> OpList,
    std::function<bool(const MachineInstr &)> TestFunc)
    : OpUniformityAndTypes(OpList), TestFunc(TestFunc) {}

bool AMDGPU::matchUniformityAndLLT(Register Reg,
                                   UniformityLLTOpPredicateID UniID,
                                   const MachineUniformityInfo &MUI,
                                   const MachineRegisterInfo &MRI) {
  switch (UniID) {
  case S1:
    return MRI.getType(Reg) == LLT::scalar(1);
  case S16:
    return MRI.getType(Reg) == LLT::scalar(16);
  case S32:
    return MRI.getType(Reg) == LLT::scalar(32);
  case S64:
    return MRI.getType(Reg) == LLT::scalar(64);
  case S128:
    return MRI.getType(Reg) == LLT::scalar(128);
  case P0:
    return MRI.getType(Reg) == LLT::pointer(0, 64);
  case P1:
    return MRI.getType(Reg) == LLT::pointer(1, 64);
  case P3:
    return MRI.getType(Reg) == LLT::pointer(3, 32);
  case P4:
    return MRI.getType(Reg) == LLT::pointer(4, 64);
  case P5:
    return MRI.getType(Reg) == LLT::pointer(5, 32);
  case Ptr32:
    return isAnyPtr(MRI.getType(Reg), 32);
  case Ptr64:
    return isAnyPtr(MRI.getType(Reg), 64);
  case Ptr128:
    return isAnyPtr(MRI.getType(Reg), 128);
  case V2S32:
    return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
  case V4S32:
    return MRI.getType(Reg) == LLT::fixed_vector(4, 32);
  case B32:
    return MRI.getType(Reg).getSizeInBits() == 32;
  case B64:
    return MRI.getType(Reg).getSizeInBits() == 64;
  case B96:
    return MRI.getType(Reg).getSizeInBits() == 96;
  case B128:
    return MRI.getType(Reg).getSizeInBits() == 128;
  case B256:
    return MRI.getType(Reg).getSizeInBits() == 256;
  case B512:
    return MRI.getType(Reg).getSizeInBits() == 512;
  case UniS1:
    return MRI.getType(Reg) == LLT::scalar(1) && MUI.isUniform(Reg);
  case UniS16:
    return MRI.getType(Reg) == LLT::scalar(16) && MUI.isUniform(Reg);
  case UniS32:
    return MRI.getType(Reg) == LLT::scalar(32) && MUI.isUniform(Reg);
  case UniS64:
    return MRI.getType(Reg) == LLT::scalar(64) && MUI.isUniform(Reg);
  case UniS128:
    return MRI.getType(Reg) == LLT::scalar(128) && MUI.isUniform(Reg);
  case UniP0:
    return MRI.getType(Reg) == LLT::pointer(0, 64) && MUI.isUniform(Reg);
  case UniP1:
    return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isUniform(Reg);
  case UniP3:
    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isUniform(Reg);
  case UniP4:
    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
  case UniP5:
    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
  case UniPtr32:
    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
  case UniPtr64:
    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
  case UniPtr128:
    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
  case UniV2S16:
    return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
  case UniB32:
    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
  case UniB64:
    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isUniform(Reg);
  case UniB96:
    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isUniform(Reg);
  case UniB128:
    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isUniform(Reg);
  case UniB256:
    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isUniform(Reg);
  case UniB512:
    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isUniform(Reg);
  case DivS1:
    return MRI.getType(Reg) == LLT::scalar(1) && MUI.isDivergent(Reg);
  case DivS16:
    return MRI.getType(Reg) == LLT::scalar(16) && MUI.isDivergent(Reg);
  case DivS32:
    return MRI.getType(Reg) == LLT::scalar(32) && MUI.isDivergent(Reg);
  case DivS64:
    return MRI.getType(Reg) == LLT::scalar(64) && MUI.isDivergent(Reg);
  case DivS128:
    return MRI.getType(Reg) == LLT::scalar(128) && MUI.isDivergent(Reg);
  case DivP0:
    return MRI.getType(Reg) == LLT::pointer(0, 64) && MUI.isDivergent(Reg);
  case DivP1:
    return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isDivergent(Reg);
  case DivP3:
    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isDivergent(Reg);
  case DivP4:
    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
  case DivP5:
    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
  case DivPtr32:
    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
  case DivPtr64:
    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
  case DivPtr128:
    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
  case DivV2S16:
    return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
  case DivB32:
    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
  case DivB64:
    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isDivergent(Reg);
  case DivB96:
    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isDivergent(Reg);
  case DivB128:
    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isDivergent(Reg);
  case DivB256:
    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isDivergent(Reg);
  case DivB512:
    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isDivergent(Reg);
  case _:
    return true;
  default:
    llvm_unreachable("missing matchUniformityAndLLT");
  }
}

bool PredicateMapping::match(const MachineInstr &MI,
                             const MachineUniformityInfo &MUI,
                             const MachineRegisterInfo &MRI) const {
  // Check LLT signature.
  for (unsigned i = 0; i < OpUniformityAndTypes.size(); ++i) {
    if (OpUniformityAndTypes[i] == _) {
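      // '_' only matches operands that are not registers (e.g. immediates,
      // intrinsic IDs or basic blocks); a register operand never matches it.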
      if (MI.getOperand(i).isReg())
        return false;
      continue;
    }

    // Remaining IDs check registers.
    if (!MI.getOperand(i).isReg())
      return false;

    if (!matchUniformityAndLLT(MI.getOperand(i).getReg(),
                               OpUniformityAndTypes[i], MUI, MRI))
      return false;
  }

  // More complex check.
  if (TestFunc)
    return TestFunc(MI);

  return true;
}

SetOfRulesForOpcode::SetOfRulesForOpcode(FastRulesTypes FastTypes)
    : FastTypes(FastTypes) {}

UniformityLLTOpPredicateID LLTToId(LLT Ty) {
  if (Ty == LLT::scalar(16))
    return S16;
  if (Ty == LLT::scalar(32))
    return S32;
  if (Ty == LLT::scalar(64))
    return S64;
  if (Ty == LLT::fixed_vector(2, 16))
    return V2S16;
  if (Ty == LLT::fixed_vector(2, 32))
    return V2S32;
  if (Ty == LLT::fixed_vector(3, 32))
    return V3S32;
  if (Ty == LLT::fixed_vector(4, 32))
    return V4S32;
  return _;
}

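// The "B" predicates only check the total size in bits, so several LLTs fall
// into the same bucket (e.g. S32, V2S16 and any 32-bit pointer all map to B32).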
UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
  if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
      isAnyPtr(Ty, 32))
    return B32;
  if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
      Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
    return B64;
  if (Ty == LLT::fixed_vector(3, 32))
    return B96;
  if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
    return B128;
  return _;
}

const RegBankLLTMapping &
SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
                                      const MachineRegisterInfo &MRI,
                                      const MachineUniformityInfo &MUI) const {
  // Search in "Fast Rules".
  // Note: if fast rules are enabled, a RegBankLLTMapping must be added to each
  // slot that could match the fast predicate. If it is missing, InvalidMapping
  // is returned, which results in failure; the "Slow Rules" are not searched.
  if (FastTypes != NoFastRules) {
    Register Reg = MI.getOperand(0).getReg();
    int Slot;
    if (FastTypes == StandardB)
      Slot = getFastPredicateSlot(LLTToBId(MRI.getType(Reg)));
    else
      Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));

    if (Slot != -1)
      return MUI.isUniform(Reg) ? Uni[Slot] : Div[Slot];
  }

  // Slow search for more complex rules.
  for (const RegBankLegalizeRule &Rule : Rules) {
    if (Rule.Predicate.match(MI, MUI, MRI))
      return Rule.OperandMapping;
  }

  LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
  llvm_unreachable("None of the rules defined for MI's opcode matched MI");
}

void SetOfRulesForOpcode::addRule(RegBankLegalizeRule Rule) {
  Rules.push_back(Rule);
}

void SetOfRulesForOpcode::addFastRuleDivergent(UniformityLLTOpPredicateID Ty,
                                               RegBankLLTMapping RuleApplyIDs) {
  int Slot = getFastPredicateSlot(Ty);
  assert(Slot != -1 && "Ty unsupported in this FastRulesTypes");
  Div[Slot] = RuleApplyIDs;
}

void SetOfRulesForOpcode::addFastRuleUniform(UniformityLLTOpPredicateID Ty,
                                             RegBankLLTMapping RuleApplyIDs) {
  int Slot = getFastPredicateSlot(Ty);
  assert(Slot != -1 && "Ty unsupported in this FastRulesTypes");
  Uni[Slot] = RuleApplyIDs;
}

int SetOfRulesForOpcode::getFastPredicateSlot(
    UniformityLLTOpPredicateID Ty) const {
  switch (FastTypes) {
  case Standard: {
    switch (Ty) {
    case S32:
      return 0;
    case S16:
      return 1;
    case S64:
      return 2;
    case V2S16:
      return 3;
    default:
      return -1;
    }
  }
  case StandardB: {
    switch (Ty) {
    case B32:
      return 0;
    case B64:
      return 1;
    case B96:
      return 2;
    case B128:
      return 3;
    default:
      return -1;
    }
  }
  case Vector: {
    switch (Ty) {
    case S32:
      return 0;
    case V2S32:
      return 1;
    case V3S32:
      return 2;
    case V4S32:
      return 3;
    default:
      return -1;
    }
  }
  default:
    return -1;
  }
}

RegBankLegalizeRules::RuleSetInitializer
RegBankLegalizeRules::addRulesForGOpcs(std::initializer_list<unsigned> OpcList,
                                       FastRulesTypes FastTypes) {
  return RuleSetInitializer(OpcList, GRulesAlias, GRules, FastTypes);
}

RegBankLegalizeRules::RuleSetInitializer
RegBankLegalizeRules::addRulesForIOpcs(std::initializer_list<unsigned> OpcList,
                                       FastRulesTypes FastTypes) {
  return RuleSetInitializer(OpcList, IRulesAlias, IRules, FastTypes);
}

const SetOfRulesForOpcode &
RegBankLegalizeRules::getRulesForOpc(MachineInstr &MI) const {
  unsigned Opc = MI.getOpcode();
  if (Opc == AMDGPU::G_INTRINSIC || Opc == AMDGPU::G_INTRINSIC_CONVERGENT ||
      Opc == AMDGPU::G_INTRINSIC_W_SIDE_EFFECTS ||
      Opc == AMDGPU::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS) {
    unsigned IntrID = cast<GIntrinsic>(MI).getIntrinsicID();
    auto IRAIt = IRulesAlias.find(IntrID);
    if (IRAIt == IRulesAlias.end()) {
      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
      llvm_unreachable("No rules defined for intrinsic opcode");
    }
    return IRules.at(IRAIt->second);
  }

  auto GRAIt = GRulesAlias.find(Opc);
  if (GRAIt == GRulesAlias.end()) {
    LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
    llvm_unreachable("No rules defined for generic opcode");
  }
  return GRules.at(GRAIt->second);
}

// Syntactic sugar wrapper for predicate lambda that enables '&&', '||' and '!'.
class Predicate {
private:
  struct Elt {
    // Save formula composed of Pred, '&&', '||' and '!' as a jump table.
    // Sink '!' into Pred. For example !((A && !B) || C) -> (!A || B) && !C.
    // Sequences of && and || are represented by jumps, for example:
    // (A && B && ... X) or (A && B && ... X) || Y
    //   A == true: jump to B
    //   A == false: jump to end or Y, result is A (false) or Y
    // (A || B || ... X) or (A || B || ... X) && Y
    //   A == true: jump to end or Y, result is A (true) or Y
    //   A == false: jump to B
    // Notice that when negating an expression, we simply flip Neg on each Pred
    // and swap TJumpOffset and FJumpOffset (&& becomes ||, || becomes &&).
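    // For example, (A && B) || C is stored as:
    //   A: TJumpOffset = 1 (to B),   FJumpOffset = 2 (to C)
    //   B: TJumpOffset = 2 (to end), FJumpOffset = 1 (to C)
    //   C: TJumpOffset = 1 (to end), FJumpOffset = 1 (to end)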
    std::function<bool(const MachineInstr &)> Pred;
    bool Neg; // Neg of Pred is calculated before jump.
    unsigned TJumpOffset;
    unsigned FJumpOffset;
  };

  SmallVector<Elt, 8> Expression;

  Predicate(SmallVectorImpl<Elt> &&Expr) { Expression.swap(Expr); };

public:
  Predicate(std::function<bool(const MachineInstr &)> Pred) {
    Expression.push_back({Pred, false, 1, 1});
  };

  bool operator()(const MachineInstr &MI) const {
    unsigned Idx = 0;
    unsigned ResultIdx = Expression.size();
    bool Result;
    do {
      Result = Expression[Idx].Pred(MI);
      Result = Expression[Idx].Neg ? !Result : Result;
      if (Result) {
        Idx += Expression[Idx].TJumpOffset;
      } else {
        Idx += Expression[Idx].FJumpOffset;
      }
    } while ((Idx != ResultIdx));

    return Result;
  };

  Predicate operator!() const {
    SmallVector<Elt, 8> NegExpression;
    for (const Elt &ExprElt : Expression) {
      NegExpression.push_back({ExprElt.Pred, !ExprElt.Neg, ExprElt.FJumpOffset,
                               ExprElt.TJumpOffset});
    }
    return Predicate(std::move(NegExpression));
  };

  Predicate operator&&(const Predicate &RHS) const {
    SmallVector<Elt, 8> AndExpression = Expression;

    unsigned RHSSize = RHS.Expression.size();
    unsigned ResultIdx = Expression.size();
    for (unsigned i = 0; i < ResultIdx; ++i) {
      // LHS results in false, whole expression results in false.
      if (i + AndExpression[i].FJumpOffset == ResultIdx)
        AndExpression[i].FJumpOffset += RHSSize;
    }

    AndExpression.append(RHS.Expression);

    return Predicate(std::move(AndExpression));
  }

  Predicate operator||(const Predicate &RHS) const {
    SmallVector<Elt, 8> OrExpression = Expression;

    unsigned RHSSize = RHS.Expression.size();
    unsigned ResultIdx = Expression.size();
    for (unsigned i = 0; i < ResultIdx; ++i) {
      // LHS results in true, whole expression results in true.
      if (i + OrExpression[i].TJumpOffset == ResultIdx)
        OrExpression[i].TJumpOffset += RHSSize;
    }

    OrExpression.append(RHS.Expression);

    return Predicate(std::move(OrExpression));
  }
};

// Initialize rules.
RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
                                           MachineRegisterInfo &_MRI)
    : ST(&_ST), MRI(&_MRI) {

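  // Each RegBankLLTMapping lists the mapping IDs for the definitions first and
  // for the uses second, with an optional lowering method. Uni()/Div() match on
  // the LLT and uniformity of operand 0; Any() matches a full PredicateMapping
  // over all operands.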
  addRulesForGOpcs({G_ADD, G_SUB}, Standard)
      .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});

  addRulesForGOpcs({G_MUL}, Standard).Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});

  addRulesForGOpcs({G_XOR, G_OR, G_AND}, StandardB)
      .Any({{DivS1}, {{Vcc}, {Vcc, Vcc}}})
      .Any({{UniS16}, {{Sgpr16}, {Sgpr16, Sgpr16}}})
      .Any({{DivS16}, {{Vgpr16}, {Vgpr16, Vgpr16}}})
      .Uni(B32, {{SgprB32}, {SgprB32, SgprB32}})
      .Div(B32, {{VgprB32}, {VgprB32, VgprB32}})
      .Uni(B64, {{SgprB64}, {SgprB64, SgprB64}})
      .Div(B64, {{VgprB64}, {VgprB64, VgprB64}, SplitTo32});

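  // Uniform 16-bit shifts are done in 32 bits: the sources are extended
  // (Sgpr32AExt/Sgpr32ZExt/Sgpr32SExt) and the result is truncated back
  // (Sgpr32Trunc).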
  addRulesForGOpcs({G_SHL}, Standard)
      .Uni(S16, {{Sgpr32Trunc}, {Sgpr32AExt, Sgpr32ZExt}})
      .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
      .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
      .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
      .Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});

  addRulesForGOpcs({G_LSHR}, Standard)
      .Uni(S16, {{Sgpr32Trunc}, {Sgpr32ZExt, Sgpr32ZExt}})
      .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
      .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
      .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
      .Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});

  addRulesForGOpcs({G_ASHR}, Standard)
      .Uni(S16, {{Sgpr32Trunc}, {Sgpr32SExt, Sgpr32ZExt}})
      .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
      .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
      .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
      .Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});

  addRulesForGOpcs({G_FRAME_INDEX}).Any({{UniP5, _}, {{SgprP5}, {None}}});

  addRulesForGOpcs({G_UBFX, G_SBFX}, Standard)
      .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32, Sgpr32}, S_BFE})
      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32, Vgpr32}})
      .Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32, Sgpr32}, S_BFE})
      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32, Vgpr32}, V_BFE});

  // Note: we only write S1 rules for G_IMPLICIT_DEF, G_CONSTANT, G_FCONSTANT
  // and G_FREEZE here; the rest is trivially regbankselected earlier.
  addRulesForGOpcs({G_IMPLICIT_DEF}).Any({{UniS1}, {{Sgpr32Trunc}, {}}});
  addRulesForGOpcs({G_CONSTANT})
      .Any({{UniS1, _}, {{Sgpr32Trunc}, {None}, UniCstExt}});
  addRulesForGOpcs({G_FREEZE}).Any({{DivS1}, {{Vcc}, {Vcc}}});

  addRulesForGOpcs({G_ICMP})
      .Any({{UniS1, _, S32}, {{Sgpr32Trunc}, {None, Sgpr32, Sgpr32}}})
      .Any({{DivS1, _, S32}, {{Vcc}, {None, Vgpr32, Vgpr32}}})
      .Any({{DivS1, _, S64}, {{Vcc}, {None, Vgpr64, Vgpr64}}});

  addRulesForGOpcs({G_FCMP})
      .Any({{UniS1, _, S32}, {{UniInVcc}, {None, Vgpr32, Vgpr32}}})
      .Any({{DivS1, _, S32}, {{Vcc}, {None, Vgpr32, Vgpr32}}});

  addRulesForGOpcs({G_BRCOND})
      .Any({{UniS1}, {{}, {Sgpr32AExtBoolInReg}}})
      .Any({{DivS1}, {{}, {Vcc}}});

  addRulesForGOpcs({G_BR}).Any({{_}, {{}, {None}}});

  addRulesForGOpcs({G_SELECT}, StandardB)
      .Any({{DivS16}, {{Vgpr16}, {Vcc, Vgpr16, Vgpr16}}})
      .Div(B32, {{VgprB32}, {Vcc, VgprB32, VgprB32}});

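  // A divergent S1 source lives in Vcc; extending it is lowered to a select on
  // the condition (VccExtToSel).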
  addRulesForGOpcs({G_ANYEXT})
      .Any({{UniS16, S1}, {{None}, {None}}}) // should be combined away
      .Any({{UniS32, S1}, {{None}, {None}}}) // should be combined away
      .Any({{UniS64, S1}, {{None}, {None}}}) // should be combined away
      .Any({{DivS16, S1}, {{Vgpr16}, {Vcc}, VccExtToSel}})
      .Any({{DivS32, S1}, {{Vgpr32}, {Vcc}, VccExtToSel}})
      .Any({{DivS64, S1}, {{Vgpr64}, {Vcc}, VccExtToSel}})
      .Any({{UniS64, S32}, {{Sgpr64}, {Sgpr32}, Ext32To64}})
      .Any({{DivS64, S32}, {{Vgpr64}, {Vgpr32}, Ext32To64}})
      .Any({{UniS32, S16}, {{Sgpr32}, {Sgpr16}}})
      .Any({{DivS32, S16}, {{Vgpr32}, {Vgpr16}}});

  // In GlobalISel, an in-register G_TRUNC is treated as a no-op and instruction
  // selected into a COPY. It is up to the user to deal with the truncated bits.
  addRulesForGOpcs({G_TRUNC})
      .Any({{UniS1, UniS16}, {{None}, {None}}}) // should be combined away
      .Any({{UniS1, UniS32}, {{None}, {None}}}) // should be combined away
      .Any({{UniS1, UniS64}, {{None}, {None}}}) // should be combined away
      .Any({{UniS16, S32}, {{Sgpr16}, {Sgpr32}}})
      .Any({{DivS16, S32}, {{Vgpr16}, {Vgpr32}}})
      .Any({{UniS32, S64}, {{Sgpr32}, {Sgpr64}}})
      .Any({{DivS32, S64}, {{Vgpr32}, {Vgpr64}}})
      .Any({{UniV2S16, V2S32}, {{SgprV2S16}, {SgprV2S32}}})
      .Any({{DivV2S16, V2S32}, {{VgprV2S16}, {VgprV2S32}}})
      // This is non-trivial. VgprToVccCopy is done using a compare instruction.
      .Any({{DivS1, DivS16}, {{Vcc}, {Vgpr16}, VgprToVccCopy}})
      .Any({{DivS1, DivS32}, {{Vcc}, {Vgpr32}, VgprToVccCopy}})
      .Any({{DivS1, DivS64}, {{Vcc}, {Vgpr64}, VgprToVccCopy}});

  addRulesForGOpcs({G_ZEXT})
      .Any({{DivS16, S1}, {{Vgpr16}, {Vcc}, VccExtToSel}})
      .Any({{DivS32, S1}, {{Vgpr32}, {Vcc}, VccExtToSel}})
      .Any({{DivS64, S1}, {{Vgpr64}, {Vcc}, VccExtToSel}})
      .Any({{UniS64, S32}, {{Sgpr64}, {Sgpr32}, Ext32To64}})
      .Any({{DivS64, S32}, {{Vgpr64}, {Vgpr32}, Ext32To64}})
      // Not extending S16 to S32 is questionable.
      .Any({{UniS64, S16}, {{Sgpr64}, {Sgpr32ZExt}, Ext32To64}})
      .Any({{DivS64, S16}, {{Vgpr64}, {Vgpr32ZExt}, Ext32To64}})
      .Any({{UniS32, S16}, {{Sgpr32}, {Sgpr16}}})
      .Any({{DivS32, S16}, {{Vgpr32}, {Vgpr16}}});

  addRulesForGOpcs({G_SEXT})
      .Any({{DivS16, S1}, {{Vgpr16}, {Vcc}, VccExtToSel}})
      .Any({{DivS32, S1}, {{Vgpr32}, {Vcc}, VccExtToSel}})
      .Any({{DivS64, S1}, {{Vgpr64}, {Vcc}, VccExtToSel}})
      .Any({{UniS64, S32}, {{Sgpr64}, {Sgpr32}, Ext32To64}})
      .Any({{DivS64, S32}, {{Vgpr64}, {Vgpr32}, Ext32To64}})
      // Not extending S16 to S32 is questionable.
      .Any({{UniS64, S16}, {{Sgpr64}, {Sgpr32SExt}, Ext32To64}})
      .Any({{DivS64, S16}, {{Vgpr64}, {Vgpr32SExt}, Ext32To64}})
      .Any({{UniS32, S16}, {{Sgpr32}, {Sgpr16}}})
      .Any({{DivS32, S16}, {{Vgpr32}, {Vgpr16}}});

  addRulesForGOpcs({G_SEXT_INREG})
      .Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}})
      .Any({{DivS32, S32}, {{Vgpr32}, {Vgpr32}}})
      .Any({{UniS64, S64}, {{Sgpr64}, {Sgpr64}}});

  bool hasUnalignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
  bool hasSMRDSmall = ST->hasScalarSubwordLoads();

  Predicate isAlign16([](const MachineInstr &MI) -> bool {
    return (*MI.memoperands_begin())->getAlign() >= Align(16);
  });

  Predicate isAlign4([](const MachineInstr &MI) -> bool {
    return (*MI.memoperands_begin())->getAlign() >= Align(4);
  });

  Predicate isAtomicMMO([](const MachineInstr &MI) -> bool {
    return (*MI.memoperands_begin())->isAtomic();
  });

  Predicate isUniMMO([](const MachineInstr &MI) -> bool {
    return AMDGPU::isUniformMMO(*MI.memoperands_begin());
  });

  Predicate isConst([](const MachineInstr &MI) -> bool {
    // The address space in the MMO can differ from the address space on the
    // pointer.
    const MachineMemOperand *MMO = *MI.memoperands_begin();
    const unsigned AS = MMO->getAddrSpace();
    return AS == AMDGPUAS::CONSTANT_ADDRESS ||
           AS == AMDGPUAS::CONSTANT_ADDRESS_32BIT;
  });

  Predicate isVolatileMMO([](const MachineInstr &MI) -> bool {
    return (*MI.memoperands_begin())->isVolatile();
  });

  Predicate isInvMMO([](const MachineInstr &MI) -> bool {
    return (*MI.memoperands_begin())->isInvariant();
  });

  Predicate isNoClobberMMO([](const MachineInstr &MI) -> bool {
    return (*MI.memoperands_begin())->getFlags() & MONoClobber;
  });

  Predicate isNaturalAlignedSmall([](const MachineInstr &MI) -> bool {
    const MachineMemOperand *MMO = *MI.memoperands_begin();
    const unsigned MemSize = 8 * MMO->getSize().getValue();
    return (MemSize == 16 && MMO->getAlign() >= Align(2)) ||
           (MemSize == 8 && MMO->getAlign() >= Align(1));
  });

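  // isUL ("uniform load"): not atomic, uniform MMO, and either from a constant
  // address space or both non-volatile and known not to be clobbered (invariant
  // or MONoClobber).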
  auto isUL = !isAtomicMMO && isUniMMO && (isConst || !isVolatileMMO) &&
              (isConst || isInvMMO || isNoClobberMMO);

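  // Load rules are grouped by pointer address space. Suitably aligned uniform
  // loads that satisfy isUL keep an Sgpr pointer and result (scalar load);
  // everything else goes through a Vgpr pointer, with UniInVgpr* marking
  // uniform results produced in VGPRs.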
  // clang-format off
  addRulesForGOpcs({G_LOAD})
      .Any({{DivB32, DivP0}, {{VgprB32}, {VgprP0}}})
      .Any({{DivB32, UniP0}, {{VgprB32}, {VgprP0}}})

      .Any({{DivB32, DivP1}, {{VgprB32}, {VgprP1}}})
      .Any({{{UniB256, UniP1}, isAlign4 && isUL}, {{SgprB256}, {SgprP1}}})
      .Any({{{UniB512, UniP1}, isAlign4 && isUL}, {{SgprB512}, {SgprP1}}})
      .Any({{{UniB32, UniP1}, !isAlign4 || !isUL}, {{UniInVgprB32}, {SgprP1}}})
      .Any({{{UniB64, UniP1}, !isAlign4 || !isUL}, {{UniInVgprB64}, {SgprP1}}})
      .Any({{{UniB96, UniP1}, !isAlign4 || !isUL}, {{UniInVgprB96}, {SgprP1}}})
      .Any({{{UniB128, UniP1}, !isAlign4 || !isUL}, {{UniInVgprB128}, {SgprP1}}})
      .Any({{{UniB256, UniP1}, !isAlign4 || !isUL}, {{UniInVgprB256}, {VgprP1}, SplitLoad}})
      .Any({{{UniB512, UniP1}, !isAlign4 || !isUL}, {{UniInVgprB512}, {VgprP1}, SplitLoad}})

      .Any({{DivB32, UniP3}, {{VgprB32}, {VgprP3}}})
      .Any({{{UniB32, UniP3}, isAlign4 && isUL}, {{SgprB32}, {SgprP3}}})
      .Any({{{UniB32, UniP3}, !isAlign4 || !isUL}, {{UniInVgprB32}, {VgprP3}}})

      .Any({{{DivB256, DivP4}}, {{VgprB256}, {VgprP4}, SplitLoad}})
      .Any({{{UniB32, UniP4}, isNaturalAlignedSmall && isUL}, {{SgprB32}, {SgprP4}}}, hasSMRDSmall) // i8 and i16 load
      .Any({{{UniB32, UniP4}, isAlign4 && isUL}, {{SgprB32}, {SgprP4}}})
      .Any({{{UniB96, UniP4}, isAlign16 && isUL}, {{SgprB96}, {SgprP4}, WidenLoad}}, !hasUnalignedLoads)
      .Any({{{UniB96, UniP4}, isAlign4 && !isAlign16 && isUL}, {{SgprB96}, {SgprP4}, SplitLoad}}, !hasUnalignedLoads)
      .Any({{{UniB96, UniP4}, isAlign4 && isUL}, {{SgprB96}, {SgprP4}}}, hasUnalignedLoads)
      .Any({{{UniB128, UniP4}, isAlign4 && isUL}, {{SgprB128}, {SgprP4}}})
      .Any({{{UniB256, UniP4}, isAlign4 && isUL}, {{SgprB256}, {SgprP4}}})
      .Any({{{UniB512, UniP4}, isAlign4 && isUL}, {{SgprB512}, {SgprP4}}})
      .Any({{{UniB32, UniP4}, !isNaturalAlignedSmall || !isUL}, {{UniInVgprB32}, {VgprP4}}}, hasSMRDSmall) // i8 and i16 load
      .Any({{{UniB32, UniP4}, !isAlign4 || !isUL}, {{UniInVgprB32}, {VgprP4}}})
      .Any({{{UniB256, UniP4}, !isAlign4 || !isUL}, {{UniInVgprB256}, {VgprP4}, SplitLoad}})
      .Any({{{UniB512, UniP4}, !isAlign4 || !isUL}, {{UniInVgprB512}, {VgprP4}, SplitLoad}})

      .Any({{DivB32, P5}, {{VgprB32}, {VgprP5}}});

  addRulesForGOpcs({G_ZEXTLOAD}) // i8 and i16 zero-extending loads
      .Any({{{UniB32, UniP3}, !isAlign4 || !isUL}, {{UniInVgprB32}, {VgprP3}}})
      .Any({{{UniB32, UniP4}, !isAlign4 || !isUL}, {{UniInVgprB32}, {VgprP4}}});
  // clang-format on

  addRulesForGOpcs({G_AMDGPU_BUFFER_LOAD}, StandardB);

  addRulesForGOpcs({G_STORE})
      .Any({{S32, P0}, {{}, {Vgpr32, VgprP0}}})
      .Any({{S32, P1}, {{}, {Vgpr32, VgprP1}}})
      .Any({{S64, P1}, {{}, {Vgpr64, VgprP1}}})
      .Any({{V4S32, P1}, {{}, {VgprV4S32, VgprP1}}});

  addRulesForGOpcs({G_AMDGPU_BUFFER_STORE})
      .Any({{S32}, {{}, {Vgpr32, SgprV4S32, Vgpr32, Vgpr32, Sgpr32}}});

  addRulesForGOpcs({G_PTR_ADD})
      .Any({{UniPtr32}, {{SgprPtr32}, {SgprPtr32, Sgpr32}}})
      .Any({{DivPtr32}, {{VgprPtr32}, {VgprPtr32, Vgpr32}}})
      .Any({{UniPtr64}, {{SgprPtr64}, {SgprPtr64, Sgpr64}}})
      .Any({{DivPtr64}, {{VgprPtr64}, {VgprPtr64, Vgpr64}}});

  addRulesForGOpcs({G_INTTOPTR})
      .Any({{UniPtr32}, {{SgprPtr32}, {Sgpr32}}})
      .Any({{DivPtr32}, {{VgprPtr32}, {Vgpr32}}})
      .Any({{UniPtr64}, {{SgprPtr64}, {Sgpr64}}})
      .Any({{DivPtr64}, {{VgprPtr64}, {Vgpr64}}})
      .Any({{UniPtr128}, {{SgprPtr128}, {Sgpr128}}})
      .Any({{DivPtr128}, {{VgprPtr128}, {Vgpr128}}});

  addRulesForGOpcs({G_PTRTOINT})
      .Any({{UniS32}, {{Sgpr32}, {SgprPtr32}}})
      .Any({{DivS32}, {{Vgpr32}, {VgprPtr32}}})
      .Any({{UniS64}, {{Sgpr64}, {SgprPtr64}}})
      .Any({{DivS64}, {{Vgpr64}, {VgprPtr64}}})
      .Any({{UniS128}, {{Sgpr128}, {SgprPtr128}}})
      .Any({{DivS128}, {{Vgpr128}, {VgprPtr128}}});

  addRulesForGOpcs({G_ABS}, Standard).Uni(S16, {{Sgpr32Trunc}, {Sgpr32SExt}});

  bool hasSALUFloat = ST->hasSALUFloatInsts();

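  // Without SALU float instructions, uniform FP operations are executed on the
  // VALU; UniInVgprS32 marks the uniform result that is produced in a VGPR.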
  addRulesForGOpcs({G_FADD}, Standard)
      .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}, hasSALUFloat)
      .Uni(S32, {{UniInVgprS32}, {Vgpr32, Vgpr32}}, !hasSALUFloat)
      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});

  addRulesForGOpcs({G_FPTOUI})
      .Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}}, hasSALUFloat)
      .Any({{UniS32, S32}, {{UniInVgprS32}, {Vgpr32}}}, !hasSALUFloat);

  addRulesForGOpcs({G_UITOFP})
      .Any({{DivS32, S32}, {{Vgpr32}, {Vgpr32}}})
      .Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}}, hasSALUFloat)
      .Any({{UniS32, S32}, {{UniInVgprS32}, {Vgpr32}}}, !hasSALUFloat);

  using namespace Intrinsic;

  addRulesForIOpcs({amdgcn_s_getpc}).Any({{UniS64, _}, {{Sgpr64}, {None}}});

  // This is an "intrinsic lane mask"; it was set to i32/i64 in LLVM IR.
  addRulesForIOpcs({amdgcn_end_cf}).Any({{_, S32}, {{}, {None, Sgpr32}}});

  addRulesForIOpcs({amdgcn_if_break}, Standard)
      .Uni(S32, {{Sgpr32}, {IntrId, Vcc, Sgpr32}});

  addRulesForIOpcs({amdgcn_mbcnt_lo, amdgcn_mbcnt_hi}, Standard)
      .Div(S32, {{}, {Vgpr32, None, Vgpr32, Vgpr32}});

  addRulesForIOpcs({amdgcn_readfirstlane})
      .Any({{UniS32, _, DivS32}, {{}, {Sgpr32, None, Vgpr32}}})
      // This should not exist in the first place; it comes from call lowering,
      // which readfirstlanes just in case the register is not in an SGPR.
      .Any({{UniS32, _, UniS32}, {{}, {Sgpr32, None, Vgpr32}}});

} // end initialize rules