LLVM 22.0.0git
TileShapeInfo.h
Go to the documentation of this file.
1//===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row count and a
/// column size. In the AMX intrinsics interface the shape is passed as the 1st
/// and 2nd parameters, which are lowered to the 1st and 2nd machine operands
/// of the AMX pseudo instructions. The ShapeT class wraps those row/column
/// machine operands to facilitate tile configuration and register allocation.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20#define LLVM_CODEGEN_TILESHAPEINFO_H
21
26
27namespace llvm {
28
29class ShapeT {
30public:
32 const MachineRegisterInfo *MRI = nullptr)
33 : Row(Row), Col(Col) {
34 if (MRI)
36 }
37 // When ShapeT has multiple shapes, we only use Shapes (never use Row and Col)
38 // and ImmShapes. Due to the most case is only one shape (just simply use
39 // Shape.Row or Shape.Col), so here we don't merge Row and Col into vector
40 // Shapes to keep the speed and code simplicity.
41 // TODO: The upper solution is a temporary way to minimize current tile
42 // register allocation code changes. It can not handle both Reg shape and
43 // Imm shape for different shapes (e.g. shape 1 is reg shape while shape 2
44 // is imm shape). Refine me when we have more multi-tile shape instructions!
46 const MachineRegisterInfo *MRI = nullptr)
47 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
48 ColImm(InvalidImmShape) {
49 assert(ShapesOperands.size() % 2 == 0 && "Miss row or col!");
50
51 llvm::append_range(Shapes, ShapesOperands);
52
53 if (MRI)
55 }
57 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
58 ColImm(InvalidImmShape) {}
59 // TODO: We need to extern cmp operator for multi-shapes if
60 // we have requirement in the future.
61 bool operator==(const ShapeT &Shape) const {
62 MachineOperand *R = Shape.Row;
63 MachineOperand *C = Shape.Col;
64 if (!R || !C)
65 return false;
66 if (!Row || !Col)
67 return false;
68 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
69 return true;
70 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
71 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
72 return false;
73 }
74
75 bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
76
77 MachineOperand *getRow(unsigned I = 0) const {
78 if (Shapes.empty())
79 return Row;
80 assert(Shapes.size() / 2 >= I && "Get invalid row from id!");
81 return Shapes[I * 2];
82 }
83
84 MachineOperand *getCol(unsigned I = 0) const {
85 if (Shapes.empty())
86 return Col;
87 assert(Shapes.size() / 2 >= I && "Get invalid col from id!");
88 return Shapes[I * 2 + 1];
89 }
90
91 int64_t getRowImm(unsigned I = 0) const {
92 if (ImmShapes.empty())
93 return RowImm;
94 assert(ImmShapes.size() / 2 >= I && "Get invalid imm row from id!");
95 return ImmShapes[I * 2];
96 }
97
98 int64_t getColImm(unsigned I = 0) const {
99 if (ImmShapes.empty())
100 return ColImm;
101 assert(ImmShapes.size() / 2 >= I && "Get invalid imm col from id!");
102 return ImmShapes[I * 2 + 1];
103 }
104
105 unsigned getShapeNum() {
106 if (Shapes.empty())
107 return isValid() ? 1 : 0;
108 else
109 return Shapes.size() / 2;
110 }
111
112 bool isValid() { return (Row != nullptr) && (Col != nullptr); }
113
115 // All def must be the same value, otherwise it is invalid MIs.
116 // Find the immediate.
117 // TODO copy propagation.
118 auto GetImm = [&](Register Reg) {
119 int64_t Imm = InvalidImmShape;
120 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
121 const auto *MI = DefMO.getParent();
122 if (MI->isMoveImmediate()) {
123 assert(MI->getNumOperands() == 2 &&
124 "Unsupported number of operands in instruction for setting "
125 "row/column.");
126 if (MI->getOperand(1).isImm()) {
127 Imm = MI->getOperand(1).getImm();
128 } else {
129 assert(MI->getOperand(1).isImplicit() &&
130 "Operand 1 is assumed to be implicit.");
131 Imm = 0;
132 }
133 break;
134 }
135 }
136 return Imm;
137 };
138 if (Shapes.empty()) { // Single Shape
139 RowImm = GetImm(Row->getReg());
140 ColImm = GetImm(Col->getReg());
141 // The number of rows of 2nd destination buffer is assigned by the one of
142 // 1st destination buffer. If the column size is equal to zero, the row
143 // size should be reset to zero too.
144 if (ColImm == 0)
145 Row = Col;
146 } else { // Multiple Shapes
147 for (auto *Shape : Shapes) {
148 int64_t ImmShape = GetImm(Shape->getReg());
149 ImmShapes.push_back(ImmShape);
150 }
151 }
152 }
153
154private:
155 static constexpr int64_t InvalidImmShape = -1;
156 MachineOperand *Row;
157 MachineOperand *Col;
158 int64_t RowImm = -1;
159 int64_t ColImm = -1;
160 // Multiple Shapes
162 SmallVector<int64_t, 0> ImmShapes;
163};
164
165} // namespace llvm
166
167#endif
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
Register Reg
static bool GetImm(MachineInstr *MI, unsigned Op, int64_t &Imm)
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:147
MachineOperand class - Representation of each machine instruction operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
int64_t getColImm(unsigned I=0) const
Definition: TileShapeInfo.h:98
void deduceImm(const MachineRegisterInfo *MRI)
ShapeT(ArrayRef< MachineOperand * > ShapesOperands, const MachineRegisterInfo *MRI=nullptr)
Definition: TileShapeInfo.h:45
int64_t getRowImm(unsigned I=0) const
Definition: TileShapeInfo.h:91
ShapeT(MachineOperand *Row, MachineOperand *Col, const MachineRegisterInfo *MRI=nullptr)
Definition: TileShapeInfo.h:31
bool operator!=(const ShapeT &Shape) const
Definition: TileShapeInfo.h:75
MachineOperand * getRow(unsigned I=0) const
Definition: TileShapeInfo.h:77
bool operator==(const ShapeT &Shape) const
Definition: TileShapeInfo.h:61
MachineOperand * getCol(unsigned I=0) const
Definition: TileShapeInfo.h:84
unsigned getShapeNum()
bool empty() const
Definition: SmallVector.h:82
size_t size() const
Definition: SmallVector.h:79
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2155