X86LowerAMXType.cpp
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11/// provides simple operations on x86_amx. Basic elementwise operations
12/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need to transform
14/// <256 x i32> load/store instructions to AMX load/store. If the bitcast
15/// cannot be combined with the load/store, we transform the bitcast to an AMX
16/// load/store plus a <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end does (e.g. "Clang -O2
19/// -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
20/// volatile, because that is necessary for AMX fast register allocation. (In
21/// fast register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for AMX to identify the step in a spill.)
23/// volatileTileData() handles this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transform to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
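// For illustration (a hypothetical IR sketch, not taken from a test case;
// %row and %col come from the shape operands of the AMX intrinsic that
// consumes the tile, see combineLoadBitcast below):
//   %src = load <256 x i32>, <256 x i32>* %addr, align 64
//   %2 = bitcast <256 x i32> %src to x86_amx
// -->
//   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    i8* %addr, i64 64)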
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/Analysis/TargetLibraryInfo.h"
45#include "llvm/Analysis/TargetTransformInfo.h"
46#include "llvm/CodeGen/Passes.h"
47#include "llvm/CodeGen/TargetPassConfig.h"
48#include "llvm/CodeGen/ValueTypes.h"
49#include "llvm/IR/DataLayout.h"
50#include "llvm/IR/Function.h"
51#include "llvm/IR/IRBuilder.h"
52#include "llvm/IR/Instructions.h"
53#include "llvm/IR/IntrinsicInst.h"
54#include "llvm/IR/IntrinsicsX86.h"
55#include "llvm/IR/PatternMatch.h"
56#include "llvm/InitializePasses.h"
57#include "llvm/Pass.h"
58#include "llvm/Target/TargetMachine.h"
59#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
60#include "llvm/Transforms/Utils/Local.h"
61
62#include <map>
63
64using namespace llvm;
65using namespace PatternMatch;
66
67#define DEBUG_TYPE "lower-amx-type"
68
69static bool isAMXCast(Instruction *II) {
70 return match(II,
71 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
72 match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
73}
74
75// Some instructions may return more than one tile.
76// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal
77static unsigned getNumDefTiles(IntrinsicInst *II) {
78 Type *Ty = II->getType();
79 if (Ty->isX86_AMXTy())
80 return 1;
81
82 unsigned Num = 0;
83 for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) {
84 Type *STy = Ty->getContainedType(i);
85 if (STy->isX86_AMXTy())
86 Num++;
87 }
88 return Num;
89}
90
91static bool isAMXIntrinsic(Value *I) {
92 auto *II = dyn_cast<IntrinsicInst>(I);
93 if (!II)
94 return false;
95 if (isAMXCast(II))
96 return false;
97 // Check if the return type or a parameter is x86_amx. If it is x86_amx,
98 // the intrinsic must be an x86 AMX intrinsic.
99 if (getNumDefTiles(II) > 0)
100 return true;
101 for (Value *V : II->args()) {
102 if (V->getType()->isX86_AMXTy())
103 return true;
104 }
105
106 return false;
107}
108
109static bool containsAMXCode(Function &F) {
110 for (BasicBlock &BB : F)
111 for (Instruction &I : BB)
112 if (I.getType()->isX86_AMXTy())
113 return true;
114 return false;
115}
116
117static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
118 Type *Ty) {
119 Function &F = *BB->getParent();
120 const DataLayout &DL = F.getDataLayout();
121
122 LLVMContext &Ctx = Builder.getContext();
123 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
124 unsigned AllocaAS = DL.getAllocaAddrSpace();
125 AllocaInst *AllocaRes =
126 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
127 AllocaRes->setAlignment(AllocaAlignment);
128 return AllocaRes;
129}
130
131static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
132 for (Instruction &I : F.getEntryBlock())
133 if (!isa<AllocaInst>(&I))
134 return &I;
135 llvm_unreachable("No terminator in the entry block!");
136}
137
138class ShapeCalculator {
139private:
140 TargetMachine *TM = nullptr;
141
142 // In AMX intrinsics we let Shape = {Row, Col}, but the
143 // RealCol = Col / ElementSize. We may use the RealCol
144 // as a new Row for other newly created AMX intrinsics.
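 // For example (assuming 4-byte elements): Col = 64 bytes with
 // ElementSize = 4 gives RealCol = 16, which can then serve as the Row of a
 // transposed tile.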
145 std::map<Value *, Value *> Col2Row, Row2Col;
146
147public:
148 ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {}
149 std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
150 std::pair<Value *, Value *> getShape(PHINode *Phi);
151 Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
152 Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
153};
154
155Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
156 unsigned Granularity) {
157 if (auto It = Col2Row.find(V); It != Col2Row.end())
158 return It->second;
159 IRBuilder<> Builder(II);
160 Value *RealRow = nullptr;
161 if (isa<ConstantInt>(V))
162 RealRow =
163 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
164 else if (isa<Instruction>(V)) {
165 // When it is not a const value and it is not a function argument, we
166 // create Row after the definition of V instead of
167 // before II. For example, II is %118 and we try to getShape for %117:
168 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
169 // i32> %115).
170 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
171 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
172 // %117).
173 // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
174 // definition is after its user (the new tileload for %117).
175 // So, the best choice is to create %row right after the definition of
176 // %106.
177 Builder.SetInsertPoint(cast<Instruction>(V));
178 RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
179 cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
180 } else {
181 // When it is not a const value and it is a function argument, we create
182 // Row at the entry bb.
183 IRBuilder<> NewBuilder(
184 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
185 RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
186 }
187 Col2Row[V] = RealRow;
188 return RealRow;
189}
190
191Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V,
192 unsigned Granularity) {
193 if (auto It = Row2Col.find(V); It != Row2Col.end())
194 return It->second;
195 IRBuilder<> Builder(II);
196 Value *RealCol = nullptr;
197 if (isa<ConstantInt>(V))
198 RealCol =
199 Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
200 else if (isa<Instruction>(V)) {
201 Builder.SetInsertPoint(cast<Instruction>(V));
202 RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
203 cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
204 } else {
205 // When it is not a const value and it is a function argument, we create
206 // Col at the entry bb.
207 IRBuilder<> NewBuilder(
208 getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
209 RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity));
210 }
211 Row2Col[V] = RealCol;
212 return RealCol;
213}
214
215// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
216std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
217 unsigned OpNo) {
218 (void)TM;
219 IRBuilder<> Builder(II);
220 Value *Row = nullptr, *Col = nullptr;
221 switch (II->getIntrinsicID()) {
222 default:
223 llvm_unreachable("Expect amx intrinsics");
224 case Intrinsic::x86_t2rpntlvwz0_internal:
225 case Intrinsic::x86_t2rpntlvwz0t1_internal:
226 case Intrinsic::x86_t2rpntlvwz1_internal:
227 case Intrinsic::x86_t2rpntlvwz1t1_internal:
228 case Intrinsic::x86_tileloadd64_internal:
229 case Intrinsic::x86_tileloaddt164_internal:
230 case Intrinsic::x86_tilestored64_internal:
231 case Intrinsic::x86_t2rpntlvwz0rs_internal:
232 case Intrinsic::x86_t2rpntlvwz0rst1_internal:
233 case Intrinsic::x86_t2rpntlvwz1rs_internal:
234 case Intrinsic::x86_t2rpntlvwz1rst1_internal:
235 case Intrinsic::x86_tileloaddrs64_internal:
236 case Intrinsic::x86_tileloaddrst164_internal: {
237 Row = II->getArgOperand(0);
238 Col = II->getArgOperand(1);
239 break;
240 }
241 // a * b + c
242 // The shape depends on which operand.
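 // For example (a sketch of the OpNo mapping implemented below), for
 // tdpbssd.internal(m, n, k, c, a, b): operand 3 (c) has shape m x n,
 // operand 4 (a) has shape m x k, and operand 5 (b) has shape k/4 x n,
 // where n and k are measured in bytes.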
243 case Intrinsic::x86_tcmmimfp16ps_internal:
244 case Intrinsic::x86_tcmmrlfp16ps_internal:
245 case Intrinsic::x86_tdpbssd_internal:
246 case Intrinsic::x86_tdpbsud_internal:
247 case Intrinsic::x86_tdpbusd_internal:
248 case Intrinsic::x86_tdpbuud_internal:
249 case Intrinsic::x86_tdpbf16ps_internal:
250 case Intrinsic::x86_tdpfp16ps_internal:
251 case Intrinsic::x86_tmmultf32ps_internal:
252 case Intrinsic::x86_tdpbf8ps_internal:
253 case Intrinsic::x86_tdpbhf8ps_internal:
254 case Intrinsic::x86_tdphbf8ps_internal:
255 case Intrinsic::x86_tdphf8ps_internal: {
256 switch (OpNo) {
257 case 3:
258 Row = II->getArgOperand(0);
259 Col = II->getArgOperand(1);
260 break;
261 case 4:
262 Row = II->getArgOperand(0);
263 Col = II->getArgOperand(2);
264 break;
265 case 5:
266 Row = getRowFromCol(II, II->getArgOperand(2), 4);
267 Col = II->getArgOperand(1);
268 break;
269 }
270 break;
271 }
272 case Intrinsic::x86_ttransposed_internal:
273 case Intrinsic::x86_tconjtfp16_internal: {
274 assert((OpNo == 2) && "Illegal Operand Number.");
275 Row = getRowFromCol(II, II->getArgOperand(1), 4);
276 Col = getColFromRow(II, II->getArgOperand(0), 4);
277 break;
278 }
279 case Intrinsic::x86_tcvtrowd2ps_internal:
280 case Intrinsic::x86_tcvtrowps2bf16h_internal:
281 case Intrinsic::x86_tcvtrowps2bf16l_internal:
282 case Intrinsic::x86_tcvtrowps2phh_internal:
283 case Intrinsic::x86_tcvtrowps2phl_internal:
284 case Intrinsic::x86_tilemovrow_internal: {
285 assert(OpNo == 2 && "Illegal Operand Number.");
286 Row = II->getArgOperand(0);
287 Col = II->getArgOperand(1);
288 break;
289 }
290 case Intrinsic::x86_ttdpbf16ps_internal:
291 case Intrinsic::x86_ttdpfp16ps_internal:
292 case Intrinsic::x86_ttcmmimfp16ps_internal:
293 case Intrinsic::x86_ttcmmrlfp16ps_internal:
294 case Intrinsic::x86_tconjtcmmimfp16ps_internal:
295 case Intrinsic::x86_ttmmultf32ps_internal: {
296 switch (OpNo) {
297 case 3:
298 Row = II->getArgOperand(0);
299 Col = II->getArgOperand(1);
300 break;
301 case 4:
302 Row = getRowFromCol(II, II->getArgOperand(2), 4);
303 Col = getColFromRow(II, II->getArgOperand(0), 4);
304 break;
305 case 5:
306 Row = getRowFromCol(II, II->getArgOperand(2), 4);
307 Col = II->getArgOperand(1);
308 break;
309 }
310 break;
311 }
312 }
313
314 return std::make_pair(Row, Col);
315}
316
317std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {
318 Use &U = *(Phi->use_begin());
319 unsigned OpNo = U.getOperandNo();
320 User *V = U.getUser();
321 // TODO: We don't traverse all users. To make the algorithm simple, here we
322 // just traverse the first user. If we can find the shape, return it;
323 // otherwise return nullptr and the optimization for undef/zero will be
324 // abandoned.
325 while (V) {
326 if (isAMXCast(dyn_cast<Instruction>(V))) {
327 if (V->use_empty())
328 break;
329 Use &U = *(V->use_begin());
330 OpNo = U.getOperandNo();
331 V = U.getUser();
332 } else if (isAMXIntrinsic(V)) {
333 return getShape(cast<IntrinsicInst>(V), OpNo);
334 } else if (isa<PHINode>(V)) {
335 if (V->use_empty())
336 break;
337 Use &U = *(V->use_begin());
338 V = U.getUser();
339 } else {
340 break;
341 }
342 }
343
344 return std::make_pair(nullptr, nullptr);
345}
346
347namespace {
348class X86LowerAMXType {
349 Function &Func;
350 ShapeCalculator *SC;
351
352 // In AMX intrinsics we let Shape = {Row, Col}, but the
353 // RealCol = Col / ElementSize. We may use the RealCol
354 // as a new Row for other newly created AMX intrinsics.
355 std::map<Value *, Value *> Col2Row, Row2Col;
356
357public:
358 X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {}
359 bool visit();
360 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
361 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
362 bool transformBitcast(BitCastInst *Bitcast);
363};
364
365// %src = load <256 x i32>, <256 x i32>* %addr, align 64
366// %2 = bitcast <256 x i32> %src to x86_amx
367// -->
368// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
369// i8* %addr, i64 %stride64)
370void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
371 Value *Row = nullptr, *Col = nullptr;
372 Use &U = *(Bitcast->use_begin());
373 unsigned OpNo = U.getOperandNo();
374 auto *II = cast<IntrinsicInst>(U.getUser());
375 std::tie(Row, Col) = SC->getShape(II, OpNo);
376 IRBuilder<> Builder(Bitcast);
377 // Use the maximum column as stride.
378 Value *Stride = Builder.getInt64(64);
379 Value *I8Ptr = LD->getOperand(0);
380 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
381
382 Value *NewInst =
383 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
384 Bitcast->replaceAllUsesWith(NewInst);
385}
386
387// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
388// %stride);
389// %13 = bitcast x86_amx %src to <256 x i32>
390// store <256 x i32> %13, <256 x i32>* %addr, align 64
391// -->
392// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
393// %stride64, %13)
394void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
395
396 Value *Tile = Bitcast->getOperand(0);
397 auto *II = cast<IntrinsicInst>(Tile);
398 // Tile is output from AMX intrinsic. The first operand of the
399 // intrinsic is row, the second operand of the intrinsic is column.
400 Value *Row = II->getOperand(0);
401 Value *Col = II->getOperand(1);
402 IRBuilder<> Builder(ST);
403 // Use the maximum column as stride. It must be the same as the load
404 // stride.
405 Value *Stride = Builder.getInt64(64);
406 Value *I8Ptr = ST->getOperand(1);
407 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
408 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
409 if (Bitcast->hasOneUse())
410 return;
411 // %13 = bitcast x86_amx %src to <256 x i32>
412 // store <256 x i32> %13, <256 x i32>* %addr, align 64
413 // %add = <256 x i32> %13, <256 x i32> %src2
414 // -->
415 // %13 = bitcast x86_amx %src to <256 x i32>
416 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
417 // %stride64, %13)
418 // %14 = load <256 x i32>, %addr
419 // %add = <256 x i32> %14, <256 x i32> %src2
420 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
421 Bitcast->replaceAllUsesWith(Vec);
422}
423
424// Transform the bitcast to <store, load> instructions.
425bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
426 IRBuilder<> Builder(Bitcast);
427 AllocaInst *AllocaAddr;
428 Value *I8Ptr, *Stride;
429 auto *Src = Bitcast->getOperand(0);
430
431 auto Prepare = [&](Type *MemTy) {
432 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
433 I8Ptr = AllocaAddr;
434 Stride = Builder.getInt64(64);
435 };
436
437 if (Bitcast->getType()->isX86_AMXTy()) {
438 // %2 = bitcast <256 x i32> %src to x86_amx
439 // -->
440 // %addr = alloca <256 x i32>, align 64
441 // store <256 x i32> %src, <256 x i32>* %addr, align 64
442 // %addr2 = bitcast <256 x i32>* to i8*
443 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
444 // i8* %addr2,
445 // i64 64)
446 Use &U = *(Bitcast->use_begin());
447 unsigned OpNo = U.getOperandNo();
448 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
449 if (!II)
450 return false; // May be bitcast from x86amx to <256 x i32>.
451 Prepare(Bitcast->getOperand(0)->getType());
452 Builder.CreateStore(Src, AllocaAddr);
453 // TODO: we can pick a constant operand for the shape.
454 Value *Row = nullptr, *Col = nullptr;
455 std::tie(Row, Col) = SC->getShape(II, OpNo);
456 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
457 Value *NewInst =
458 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
459 Bitcast->replaceAllUsesWith(NewInst);
460 } else {
461 // %2 = bitcast x86_amx %src to <256 x i32>
462 // -->
463 // %addr = alloca <256 x i32>, align 64
464 // %addr2 = bitcast <256 x i32>* to i8*
465 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
466 // i8* %addr2, i64 %stride)
467 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
468 auto *II = dyn_cast<IntrinsicInst>(Src);
469 if (!II)
470 return false; // May be bitcast from <256 x i32> to x86amx.
471 Prepare(Bitcast->getType());
472 Value *Row = II->getOperand(0);
473 Value *Col = II->getOperand(1);
474 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
475 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
476 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
477 Bitcast->replaceAllUsesWith(NewInst);
478 }
479
480 return true;
481}
482
483bool X86LowerAMXType::visit() {
484 SmallVector<Instruction *, 8> DeadInsts;
485 Col2Row.clear();
486
487 for (BasicBlock *BB : post_order(&Func)) {
488 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
489 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
490 if (!Bitcast)
491 continue;
492
493 Value *Src = Bitcast->getOperand(0);
494 if (Bitcast->getType()->isX86_AMXTy()) {
495 if (Bitcast->user_empty()) {
496 DeadInsts.push_back(Bitcast);
497 continue;
498 }
499 LoadInst *LD = dyn_cast<LoadInst>(Src);
500 if (!LD) {
501 if (transformBitcast(Bitcast))
502 DeadInsts.push_back(Bitcast);
503 continue;
504 }
505 // If the load has multiple users, create an AMX load and keep the vector load.
506 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
507 // %2 = bitcast <256 x i32> %src to x86_amx
508 // %add = add <256 x i32> %src, <256 x i32> %src2
509 // -->
510 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
511 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
512 // i8* %addr, i64 %stride64)
513 // %add = add <256 x i32> %src, <256 x i32> %src2
514
515 // If the load has a single user, it will be eliminated in DAG ISel.
516 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
517 // %2 = bitcast <256 x i32> %src to x86_amx
518 // -->
519 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
520 // i8* %addr, i64 %stride64)
521 combineLoadBitcast(LD, Bitcast);
522 DeadInsts.push_back(Bitcast);
523 if (LD->hasOneUse())
524 DeadInsts.push_back(LD);
525 } else if (Src->getType()->isX86_AMXTy()) {
526 if (Bitcast->user_empty()) {
527 DeadInsts.push_back(Bitcast);
528 continue;
529 }
530 StoreInst *ST = nullptr;
531 for (Use &U : Bitcast->uses()) {
532 ST = dyn_cast<StoreInst>(U.getUser());
533 if (ST)
534 break;
535 }
536 if (!ST) {
537 if (transformBitcast(Bitcast))
538 DeadInsts.push_back(Bitcast);
539 continue;
540 }
541 // If the bitcast (%13) has one use, combine the bitcast and store into an AMX store.
542 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
543 // %stride);
544 // %13 = bitcast x86_amx %src to <256 x i32>
545 // store <256 x i32> %13, <256 x i32>* %addr, align 64
546 // -->
547 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
548 // %stride64, %13)
549 //
550 // If the bitcast (%13) has multiple uses, transform as below.
551 // %13 = bitcast x86_amx %src to <256 x i32>
552 // store <256 x i32> %13, <256 x i32>* %addr, align 64
553 // %add = <256 x i32> %13, <256 x i32> %src2
554 // -->
555 // %13 = bitcast x86_amx %src to <256 x i32>
556 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
557 // %stride64, %13)
558 // %14 = load <256 x i32>, %addr
559 // %add = <256 x i32> %14, <256 x i32> %src2
560 //
561 combineBitcastStore(Bitcast, ST);
562 // Delete user first.
563 DeadInsts.push_back(ST);
564 DeadInsts.push_back(Bitcast);
565 }
566 }
567 }
568
569 bool C = !DeadInsts.empty();
570
571 for (auto *Inst : DeadInsts)
572 Inst->eraseFromParent();
573
574 return C;
575}
576} // anonymous namespace
577
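// Create a <256 x i32> stack slot in the entry block and return a pointer to
// it; volatile tile data is stored to and reloaded from this memory.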
578static Value *getAllocaPos(BasicBlock *BB) {
579 Function *F = BB->getParent();
580 IRBuilder<> Builder(&F->getEntryBlock().front());
581 const DataLayout &DL = F->getDataLayout();
582 unsigned AllocaAS = DL.getAllocaAddrSpace();
583 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
584 AllocaInst *AllocaRes =
585 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
586 BasicBlock::iterator Iter = AllocaRes->getIterator();
587 ++Iter;
588 Builder.SetInsertPoint(&*Iter);
589 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
590 return I8Ptr;
591}
592
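// Store the tile defined by TileDef to Ptr right after its definition, taking
// the row/col shape from the defining intrinsic and using the maximum stride
// of 64 bytes.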
593static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
594 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
595 auto *II = dyn_cast<IntrinsicInst>(TileDef);
596 unsigned Idx = 0;
597 // Extract tile from multiple tiles' def.
598 if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) {
599 assert(Extr->hasIndices() && "Tile extract miss index!");
600 Idx = Extr->getIndices()[0];
601 II = cast<IntrinsicInst>(Extr->getOperand(0));
602 }
603
604 assert(II && "Not tile intrinsic!");
605 Value *Row = II->getOperand(Idx);
606 Value *Col = II->getOperand(Idx + 1);
607
608 BasicBlock *BB = TileDef->getParent();
609 BasicBlock::iterator Iter = TileDef->getIterator();
610 IRBuilder<> Builder(BB, ++Iter);
611 Value *Stride = Builder.getInt64(64);
612 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
613
614 Instruction *TileStore =
615 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
616 return TileStore;
617}
618
619static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
620 Value *V = U.get();
621 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
622
623 // Get tile shape.
624 IntrinsicInst *II = nullptr;
625 unsigned Idx = 0;
626 if (IsPHI) {
627 Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
628 II = cast<IntrinsicInst>(PhiOp);
629 } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) {
630 // Extract tile from multiple tiles' def.
631 assert(Extr->hasIndices() && "Tile extract miss index!");
632 Idx = Extr->getIndices()[0];
633 II = cast<IntrinsicInst>(Extr->getOperand(0));
634 } else {
635 II = cast<IntrinsicInst>(V);
636 }
637 Value *Row = II->getOperand(Idx);
638 Value *Col = II->getOperand(Idx + 1);
639
640 Instruction *UserI = cast<Instruction>(U.getUser());
641 IRBuilder<> Builder(UserI);
642 Value *Stride = Builder.getInt64(64);
643 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
644
645 Value *TileLoad =
646 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
647 UserI->replaceUsesOfWith(V, TileLoad);
648}
649
650static bool isIncomingOfPHI(Instruction *I) {
651 for (Use &U : I->uses()) {
652 User *V = U.getUser();
653 if (isa<PHINode>(V))
654 return true;
655 }
656 return false;
657}
658
659// Let all AMX tile data become volatile data, shortening the live range
660// of each tile register before fast register allocation.
661namespace {
662class X86VolatileTileData {
663 Function &F;
664
665public:
666 X86VolatileTileData(Function &Func) : F(Func) {}
667 Value *updatePhiIncomings(BasicBlock *BB,
668 SmallVector<Instruction *, 2> &Incomings);
669 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
670 bool volatileTileData();
671 void volatileTilePHI(PHINode *PHI);
672 void volatileTileNonPHI(Instruction *I);
673};
674
675Value *X86VolatileTileData::updatePhiIncomings(
676 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
677 Value *I8Ptr = getAllocaPos(BB);
678
679 for (auto *I : Incomings) {
680 User *Store = createTileStore(I, I8Ptr);
681
682 // All its uses (except phi) should load from stored mem.
683 for (Use &U : I->uses()) {
684 User *V = U.getUser();
685 if (isa<PHINode>(V) || V == Store)
686 continue;
687 replaceWithTileLoad(U, I8Ptr);
688 }
689 }
690 return I8Ptr;
691}
692
693void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
694 Value *StorePtr) {
695 for (Use &U : PHI->uses())
696 replaceWithTileLoad(U, StorePtr, true);
697 PHI->eraseFromParent();
698}
699
700// Similar to volatileTileNonPHI, this function only handles PHI nodes
701// and their related AMX intrinsics.
702// 1) The PHI def should be changed to a tileload.
703// 2) Each PHI incoming value should be tilestored right after its def.
704// 3) The memory of these tileloads and tilestores should be the same.
705// e.g.
706// ------------------------------------------------------
707// bb_dom:
708// ...
709// br i1 %bool.cond, label %if.else, label %if.then
710//
711// if.then:
712// def %t0 = ...
713// ...
714// use %t0
715// ...
716// br label %if.end
717//
718// if.else:
719// def %t1 = ...
720// br label %if.end
721//
722// if.end:
723// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
724// ...
725// use %td
726// ------------------------------------------------------
727// -->
728// ------------------------------------------------------
729// bb_entry:
730// %mem = alloca <256 x i32>, align 1024 *
731// ...
732// bb_dom:
733// ...
734// br i1 %bool.cond, label %if.else, label %if.then
735//
736// if.then:
737// def %t0 = ...
738// call void @llvm.x86.tilestored64.internal(mem, %t0) *
739// ...
740// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
741// use %t0` *
742// ...
743// br label %if.end
744//
745// if.else:
746// def %t1 = ...
747// call void @llvm.x86.tilestored64.internal(mem, %t1) *
748// br label %if.end
749//
750// if.end:
751// ...
752// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
753// use %td
754// ------------------------------------------------------
755void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
756 BasicBlock *BB = PHI->getParent();
757 SmallVector<Instruction *, 2> Incomings;
758
759 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
760 Value *Op = PHI->getIncomingValue(I);
761 Instruction *Inst = dyn_cast<Instruction>(Op);
762 assert(Inst && "We shouldn't fold AMX instruction!");
763 Incomings.push_back(Inst);
764 }
765
766 Value *StorePtr = updatePhiIncomings(BB, Incomings);
767 replacePhiDefWithLoad(PHI, StorePtr);
768}
769
770// Store the defined tile and load it before use.
771// None of its users is a PHI node.
772// e.g.
773// ------------------------------------------------------
774// def %td = ...
775// ...
776// "use %td"
777// ------------------------------------------------------
778// -->
779// ------------------------------------------------------
780// def %td = ...
781// call void @llvm.x86.tilestored64.internal(mem, %td)
782// ...
783// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
784// "use %td2"
785// ------------------------------------------------------
786void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
787 BasicBlock *BB = I->getParent();
788 Value *I8Ptr = getAllocaPos(BB);
789 User *Store = createTileStore(I, I8Ptr);
790
791 // All its uses should load from stored mem.
792 for (Use &U : I->uses()) {
793 User *V = U.getUser();
794 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
795 if (V != Store)
796 replaceWithTileLoad(U, I8Ptr);
797 }
798}
799
800// Volatile Tile Model:
801// 1) All uses of tile data come from a tileload right before the use.
802// 2) All defs of tile data are tilestored into memory immediately.
803// For example:
804// --------------------------------------------------------------------------
805// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
806// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
807// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
808// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
809// call void @llvm.x86.tilestored64.internal(... td) area
810// --------------------------------------------------------------------------
811// 3) No terminator, call or other amx instructions in the key amx area.
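// (In other words, every tile value is stored immediately after its def and
// reloaded right before each use, so no tile stays live across a call,
// terminator, or another AMX instruction.)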
812bool X86VolatileTileData::volatileTileData() {
813 bool Changed = false;
814 for (BasicBlock &BB : F) {
815 SmallVector<Instruction *, 2> PHIInsts;
816 SmallVector<Instruction *, 8> AMXDefInsts;
817
818 for (Instruction &I : BB) {
819 if (!I.getType()->isX86_AMXTy())
820 continue;
821 if (isa<PHINode>(&I))
822 PHIInsts.push_back(&I);
823 else
824 AMXDefInsts.push_back(&I);
825 }
826
827 // First we "volatile" the non-phi related amx intrinsics.
828 for (Instruction *I : AMXDefInsts) {
829 if (isIncomingOfPHI(I))
830 continue;
831 volatileTileNonPHI(I);
832 Changed = true;
833 }
834
835 for (Instruction *I : PHIInsts) {
836 volatileTilePHI(dyn_cast<PHINode>(I));
837 Changed = true;
838 }
839 }
840 return Changed;
841}
842
843} // anonymous namespace
844
845namespace {
846
847class X86LowerAMXCast {
848 Function &Func;
849 ShapeCalculator *SC;
850 std::unique_ptr<DominatorTree> DT;
851
852public:
853 X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC)
854 : Func(F), SC(ShapeC), DT(nullptr) {}
855 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
856 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
857 bool combineTilezero(IntrinsicInst *Cast);
858 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
859 bool combineAMXcast(TargetLibraryInfo *TLI);
860 bool transformAMXCast(IntrinsicInst *AMXCast);
861 bool transformAllAMXCast();
862 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
863 SmallSetVector<Instruction *, 16> &DeadInst);
864};
865
866static bool DCEInstruction(Instruction *I,
867 SmallSetVector<Instruction *, 16> &WorkList,
868 const TargetLibraryInfo *TLI) {
869 if (isInstructionTriviallyDead(I, TLI)) {
870 salvageDebugInfo(*I);
871 salvageKnowledge(I);
872
873 // Null out all of the instruction's operands to see if any operand becomes
874 // dead as we go.
875 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
876 Value *OpV = I->getOperand(i);
877 I->setOperand(i, nullptr);
878
879 if (!OpV->use_empty() || I == OpV)
880 continue;
881
882 // If the operand is an instruction that became dead as we nulled out the
883 // operand, and if it is 'trivially' dead, delete it in a future loop
884 // iteration.
885 if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
886 if (isInstructionTriviallyDead(OpI, TLI)) {
887 WorkList.insert(OpI);
888 }
889 }
890 }
891 I->eraseFromParent();
892 return true;
893 }
894 return false;
895}
896
897/// This function handles the following case:
898///
899/// A -> B amxcast
900/// PHI
901/// B -> A amxcast
902///
903/// All the related PHI nodes can be replaced by new PHI nodes with type A.
904/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
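/// A minimal sketch of the pattern (hypothetical IR, with A = x86_amx and
/// B = <256 x i32>):
///   %vec = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %tile)
///   %phi = phi <256 x i32> [ %vec, %bb0 ], ...
///   %amx = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %phi)
/// After the transform, a new phi of type x86_amx is fed by %tile directly and
/// the uses of %amx are redirected to it.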
905bool X86LowerAMXCast::optimizeAMXCastFromPhi(
906 IntrinsicInst *CI, PHINode *PN,
907 SmallSetVector<Instruction *, 16> &DeadInst) {
908 IRBuilder<> Builder(CI);
909 Value *Src = CI->getOperand(0);
910 Type *SrcTy = Src->getType(); // Type B
911 Type *DestTy = CI->getType(); // Type A
912
913 SmallVector<PHINode *, 4> PhiWorklist;
914 SmallSetVector<PHINode *, 4> OldPhiNodes;
915
916 // Find all of the A->B casts and PHI nodes.
917 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
919 // OldPhiNodes is used to track all known PHI nodes; before adding a new
919 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
920 PhiWorklist.push_back(PN);
921 OldPhiNodes.insert(PN);
922 while (!PhiWorklist.empty()) {
923 auto *OldPN = PhiWorklist.pop_back_val();
924 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
925 Value *IncValue = OldPN->getIncomingValue(I);
926 // TODO: Currently we ignore cases where it is a const. In the future, we
927 // might support consts.
928 if (isa<Constant>(IncValue)) {
929 auto *IncConst = dyn_cast<Constant>(IncValue);
930 if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
931 return false;
932 Value *Row = nullptr, *Col = nullptr;
933 std::tie(Row, Col) = SC->getShape(OldPN);
934 // TODO: If it is not constant, the Row and Col must dominate the tilezero
935 // that we are going to create.
936 if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
937 return false;
938 // Create tilezero at the end of incoming block.
939 auto *Block = OldPN->getIncomingBlock(I);
940 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
941 Instruction *NewInst = Builder.CreateIntrinsic(
942 Intrinsic::x86_tilezero_internal, {}, {Row, Col});
943 NewInst->moveBefore(Iter);
944 NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
945 {IncValue->getType()}, {NewInst});
946 NewInst->moveBefore(Iter);
947 // Replace IncValue with the new value.
948 OldPN->setIncomingValue(I, NewInst);
949 IncValue = NewInst;
950 }
951
952 if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
953 if (OldPhiNodes.insert(PNode))
954 PhiWorklist.push_back(PNode);
955 continue;
956 }
957 Instruction *ACI = dyn_cast<Instruction>(IncValue);
958 if (ACI && isAMXCast(ACI)) {
959 // Verify it's an A->B cast.
960 Type *TyA = ACI->getOperand(0)->getType();
961 Type *TyB = ACI->getType();
962 if (TyA != DestTy || TyB != SrcTy)
963 return false;
964 continue;
965 }
966 return false;
967 }
968 }
969
970 // Check that each user of each old PHI node is something that we can
971 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
972 for (auto *OldPN : OldPhiNodes) {
973 for (User *V : OldPN->users()) {
974 Instruction *ACI = dyn_cast<Instruction>(V);
975 if (ACI && isAMXCast(ACI)) {
976 // Verify it's a B->A cast.
977 Type *TyB = ACI->getOperand(0)->getType();
978 Type *TyA = ACI->getType();
979 if (TyA != DestTy || TyB != SrcTy)
980 return false;
981 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
982 // As long as the user is another old PHI node, even if we don't
983 // rewrite it, the PHI web we're considering won't have any users
984 // outside itself, so it'll be dead.
985 // example:
986 // bb.0:
987 // %0 = amxcast ...
988 // bb.1:
989 // %1 = amxcast ...
990 // bb.2:
991 // %goodphi = phi %0, %1
992 // %3 = amxcast %goodphi
993 // bb.3:
994 // %goodphi2 = phi %0, %goodphi
995 // %4 = amxcast %goodphi2
996 // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
997 // outside the phi-web, so the combination stops. When
998 // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
999 // will be done.
1000 if (OldPhiNodes.count(PHI) == 0)
1001 return false;
1002 } else
1003 return false;
1004 }
1005 }
1006
1007 // For each old PHI node, create a corresponding new PHI node with type A.
1008 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
1009 for (auto *OldPN : OldPhiNodes) {
1010 Builder.SetInsertPoint(OldPN);
1011 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
1012 NewPNodes[OldPN] = NewPN;
1013 }
1014
1015 // Fill in the operands of new PHI nodes.
1016 for (auto *OldPN : OldPhiNodes) {
1017 PHINode *NewPN = NewPNodes[OldPN];
1018 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
1019 Value *V = OldPN->getOperand(j);
1020 Value *NewV = nullptr;
1021 Instruction *ACI = dyn_cast<Instruction>(V);
1022 // There should not be an AMXCast from a const.
1023 if (ACI && isAMXCast(ACI))
1024 NewV = ACI->getOperand(0);
1025 else if (auto *PrevPN = dyn_cast<PHINode>(V))
1026 NewV = NewPNodes[PrevPN];
1027 assert(NewV);
1028 NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
1029 }
1030 }
1031
1032 // Traverse all accumulated PHI nodes and process their users,
1033 // which are Stores and BitCasts. Without this processing,
1034 // NewPHI nodes could be replicated and could lead to extra
1035 // moves generated after DeSSA.
1036 // If there is a store with type B, change it to type A.
1037
1038 // Replace users of BitCast B->A with NewPHI. These will help
1039 // later to get rid of a closure formed by OldPHI nodes.
1040 for (auto *OldPN : OldPhiNodes) {
1041 PHINode *NewPN = NewPNodes[OldPN];
1042 for (User *V : make_early_inc_range(OldPN->users())) {
1043 Instruction *ACI = dyn_cast<Instruction>(V);
1044 if (ACI && isAMXCast(ACI)) {
1045 Type *TyB = ACI->getOperand(0)->getType();
1046 Type *TyA = ACI->getType();
1047 assert(TyA == DestTy && TyB == SrcTy);
1048 (void)TyA;
1049 (void)TyB;
1050 ACI->replaceAllUsesWith(NewPN);
1051 DeadInst.insert(ACI);
1052 } else if (auto *PHI = dyn_cast<PHINode>(V)) {
1053 // We don't need to push PHINode into DeadInst since they are operands
1054 // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
1055 assert(OldPhiNodes.contains(PHI));
1056 (void)PHI;
1057 } else
1058 llvm_unreachable("all uses should be handled");
1059 }
1060 }
1061 return true;
1062}
1063
1064static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx,
1065 bool IsRow) {
1066 if (!isAMXIntrinsic(Inst))
1067 return nullptr;
1068
1069 auto *II = cast<IntrinsicInst>(Inst);
1070 if (IsRow)
1071 return II->getOperand(0);
1072
1073 assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!");
1074 return II->getOperand(ShapeIdx + 1);
1075}
1076
1077// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
1078// store <256 x i32> %43, <256 x i32>* %p, align 64
1079// -->
1080// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1081// i64 64, x86_amx %42)
1082bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
1083 Value *Tile = Cast->getOperand(0);
1084
1085 assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
1086
1087 // TODO: Specially handle the multi-use case.
1088 if (!Tile->hasOneUse())
1089 return false;
1090
1091 // We don't fetch the shape from the tilestore; we only get it from the tile
1092 // def, so we can use the max tile shape for the tilestore in special cases.
1093 IRBuilder<> Builder(ST);
1094 Value *Row = nullptr;
1095 Value *Col = nullptr;
1096
1097 if (isAMXIntrinsic(Tile)) {
1098 auto *II = cast<IntrinsicInst>(Tile);
1099 // Tile is output from AMX intrinsic. The first operand of the
1100 // intrinsic is row, the second operand of the intrinsic is column.
1101 Row = II->getOperand(0);
1102 Col = II->getOperand(1);
1103 } else {
1104 // Now we support multi-tile values in a structure, so we may get a tile
1105 // by extracting it from a multi-tile structure.
1106 // For example:
1107 // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1,
1108 // i16 %2, i16 %3, i8* %4, i64 %5)
1109 // %7 = extractvalue { x86_amx, x86_amx } %6, 0
1110 // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7)
1111 // store <256 x i32> %8, <256 x i32>* %0, align 1024
1112 //
1113 // TODO: Currently we only handle the extractvalue case; enhance me for
1114 // other cases if possible.
1115 auto *II = cast<ExtractValueInst>(Tile);
1116 assert(II && "Unhandled source in fetching tile value!");
1117 unsigned ShapeIdx = II->getIndices()[0];
1118 Value *Tiles = II->getOperand(0);
1119 Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true);
1120 Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false);
1121 }
1122 assert(Row && Col && "Failed to get shape!");
1123
1124 // Stride should be equal to col (measured in bytes).
1125 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1126 Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
1127 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
1128 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
1129 return true;
1130}
1131
1132// %65 = load <256 x i32>, <256 x i32>* %p, align 64
1133// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1134// -->
1135// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1136// i8* %p, i64 64)
1137bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
1138 bool EraseLoad = true;
1139 Value *Row = nullptr, *Col = nullptr;
1140 Use &U = *(Cast->use_begin());
1141 unsigned OpNo = U.getOperandNo();
1142 auto *II = cast<IntrinsicInst>(U.getUser());
1143 // TODO: If it is a cast intrinsic or phi node, we can propagate the
1144 // shape information through the def-use chain.
1145 if (!isAMXIntrinsic(II))
1146 return false;
1147 std::tie(Row, Col) = SC->getShape(II, OpNo);
1148 IRBuilder<> Builder(LD);
1149 // Stride should be equal to col (measured in bytes).
1150 Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1151 Value *I8Ptr;
1152
1153 // To save compile time, we create the dominator tree only when it is
1154 // really needed.
1155 if (!DT)
1156 DT.reset(new DominatorTree(Func));
1157 if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
1158 // Store the value to the stack and reload it from the stack before the cast.
1159 auto *AllocaAddr =
1160 createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
1161 Builder.SetInsertPoint(&*std::next(LD->getIterator()));
1162 Builder.CreateStore(LD, AllocaAddr);
1163
1164 Builder.SetInsertPoint(Cast);
1165 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1166 EraseLoad = false;
1167 } else {
1168 I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
1169 }
1170 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1171
1172 Value *NewInst =
1173 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1174 Cast->replaceAllUsesWith(NewInst);
1175
1176 return EraseLoad;
1177}
1178
1179// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1180// -->
1181// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1182bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
1183 Value *Row = nullptr, *Col = nullptr;
1184 Use &U = *(Cast->use_begin());
1185 unsigned OpNo = U.getOperandNo();
1186 auto *II = cast<IntrinsicInst>(U.getUser());
1187 if (!isAMXIntrinsic(II))
1188 return false;
1189
1190 std::tie(Row, Col) = SC->getShape(II, OpNo);
1191
1192 IRBuilder<> Builder(Cast);
1193 Value *NewInst =
1194 Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col});
1195 Cast->replaceAllUsesWith(NewInst);
1196 return true;
1197}
1198
1199bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1200 bool Change = false;
1201 for (auto *Cast : Casts) {
1202 auto *II = cast<IntrinsicInst>(Cast);
1203 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1204 // store <256 x i32> %43, <256 x i32>* %p, align 64
1205 // -->
1206 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1207 // i64 64, x86_amx %42)
1208 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1209 SmallVector<Instruction *, 2> DeadStores;
1210 for (User *U : Cast->users()) {
1211 StoreInst *Store = dyn_cast<StoreInst>(U);
1212 if (!Store)
1213 continue;
1214 if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1215 DeadStores.push_back(Store);
1216 Change = true;
1217 }
1218 }
1219 for (auto *Store : DeadStores)
1220 Store->eraseFromParent();
1221 } else { // x86_cast_vector_to_tile
1222 // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1223 // -->
1224 // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1225 if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
1226 Change |= combineTilezero(cast<IntrinsicInst>(Cast));
1227 continue;
1228 }
1229
1230 auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1231 if (!Load || !Load->hasOneUse())
1232 continue;
1233 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1234 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1235 // -->
1236 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1237 // i8* %p, i64 64)
1238 if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1239 // Set the operand to null so that the load instruction can be erased.
1240 Cast->setOperand(0, nullptr);
1241 Load->eraseFromParent();
1242 Change = true;
1243 }
1244 }
1245 }
1246 return Change;
1247}
1248
1249bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1250 bool Change = false;
1251 // Collect tile cast instructions.
1252 SmallVector<Instruction *, 8> Vec2TileInsts;
1253 SmallVector<Instruction *, 8> Tile2VecInsts;
1254 SmallVector<Instruction *, 8> PhiCastWorkList;
1255 SmallSetVector<Instruction *, 16> DeadInst;
1256 for (BasicBlock &BB : Func) {
1257 for (Instruction &I : BB) {
1258 Value *Vec;
1259 if (match(&I,
1260 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1261 Vec2TileInsts.push_back(&I);
1262 else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1263 m_Value(Vec))))
1264 Tile2VecInsts.push_back(&I);
1265 }
1266 }
1267
1268 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1269 for (auto *Inst : Insts) {
1270 for (User *U : Inst->users()) {
1271 IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1272 if (!II || II->getIntrinsicID() != IID)
1273 continue;
1274 // T1 = vec2tile V0
1275 // V2 = tile2vec T1
1276 // V3 = OP V2
1277 // -->
1278 // T1 = vec2tile V0
1279 // V2 = tile2vec T1
1280 // V3 = OP V0
1281 II->replaceAllUsesWith(Inst->getOperand(0));
1282 Change = true;
1283 }
1284 }
1285 };
1286
1287 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1288 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1289
1290 SmallVector<Instruction *, 8> LiveCasts;
1291 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1292 for (auto *Inst : Insts) {
1293 if (Inst->use_empty()) {
1294 Inst->eraseFromParent();
1295 Change = true;
1296 } else {
1297 LiveCasts.push_back(Inst);
1298 }
1299 }
1300 };
1301
1302 EraseInst(Vec2TileInsts);
1303 EraseInst(Tile2VecInsts);
1304 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1305 "Vec2Tile and Tile2Vec:\n";
1306 Func.dump());
1307 Change |= combineLdSt(LiveCasts);
1308 EraseInst(LiveCasts);
1309 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1310 "AMXCast and load/store:\n";
1311 Func.dump());
1312
1313 // Handle the A->B->A cast when there is an intervening PHI node.
1314 for (BasicBlock &BB : Func) {
1315 for (Instruction &I : BB) {
1316 if (isAMXCast(&I)) {
1317 if (isa<PHINode>(I.getOperand(0)))
1318 PhiCastWorkList.push_back(&I);
1319 }
1320 }
1321 }
1322 for (auto *I : PhiCastWorkList) {
1323 // We skip dead AMXCasts.
1324 if (DeadInst.contains(I))
1325 continue;
1326 PHINode *PN = cast<PHINode>(I->getOperand(0));
1327 if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1328 DeadInst.insert(PN);
1329 Change = true;
1330 }
1331 }
1332
1333 // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1334 // might have no uses. We do some dead code elimination for them.
1335 while (!DeadInst.empty()) {
1336 Instruction *I = DeadInst.pop_back_val();
1337 Change |= DCEInstruction(I, DeadInst, TLI);
1338 }
1339 LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1340 "optimizeAMXCastFromPhi:\n";
1341 Func.dump());
1342 return Change;
1343}
1344
1345// There might be remaining AMXCasts after combineAMXcast, and they should be
1346// handled elegantly.
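// The fallback below lowers such a cast through a stack slot: a vector-to-tile
// cast becomes a vector store followed by an AMX tileload, and a tile-to-vector
// cast becomes an AMX tilestore followed by a vector load.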
1347bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1348 IRBuilder<> Builder(AMXCast);
1349 AllocaInst *AllocaAddr;
1350 Value *I8Ptr, *Stride;
1351 auto *Src = AMXCast->getOperand(0);
1352
1353 auto Prepare = [&](Type *MemTy) {
1354 AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1355 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1356 Stride = Builder.getInt64(64);
1357 };
1358
1359 if (AMXCast->getType()->isX86_AMXTy()) {
1360 // %2 = amxcast <225 x i32> %src to x86_amx
1361 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1362 // i8* %addr3, i64 60, x86_amx %2)
1363 // -->
1364 // %addr = alloca <225 x i32>, align 64
1365 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1366 // %addr2 = bitcast <225 x i32>* %addr to i8*
1367 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1368 // i8* %addr2,
1369 // i64 60)
1370 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1371 // i8* %addr3, i64 60, x86_amx %2)
1372 if (AMXCast->use_empty()) {
1373 AMXCast->eraseFromParent();
1374 return true;
1375 }
1376 Use &U = *(AMXCast->use_begin());
1377 unsigned OpNo = U.getOperandNo();
1378 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1379 if (!II)
1380 return false; // May be bitcast from x86amx to <256 x i32>.
1381 Prepare(AMXCast->getOperand(0)->getType());
1382 Builder.CreateStore(Src, AllocaAddr);
1383 // TODO: we can pick a constant operand for the shape.
1384 Value *Row = nullptr, *Col = nullptr;
1385 std::tie(Row, Col) = SC->getShape(II, OpNo);
1386 std::array<Value *, 4> Args = {
1387 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1388 Value *NewInst =
1389 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
1390 AMXCast->replaceAllUsesWith(NewInst);
1391 AMXCast->eraseFromParent();
1392 } else {
1393 // %2 = amxcast x86_amx %src to <225 x i32>
1394 // -->
1395 // %addr = alloca <225 x i32>, align 64
1396 // %addr2 = bitcast <225 x i32>* to i8*
1397 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1398 // i8* %addr2, i64 %stride)
1399 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1400 auto *II = dyn_cast<IntrinsicInst>(Src);
1401 if (!II)
1402 return false; // May be bitcast from <256 x i32> to x86amx.
1403 Prepare(AMXCast->getType());
1404 Value *Row = II->getOperand(0);
1405 Value *Col = II->getOperand(1);
1406 std::array<Value *, 5> Args = {
1407 Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1408 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
1409 Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1410 AMXCast->replaceAllUsesWith(NewInst);
1411 AMXCast->eraseFromParent();
1412 }
1413
1414 return true;
1415}
1416
1417bool X86LowerAMXCast::transformAllAMXCast() {
1418 bool Change = false;
1419 // Collect tile cast instructions.
1420 SmallVector<Instruction *, 8> WorkLists;
1421 for (BasicBlock &BB : Func) {
1422 for (Instruction &I : BB) {
1423 if (isAMXCast(&I))
1424 WorkLists.push_back(&I);
1425 }
1426 }
1427
1428 for (auto *Inst : WorkLists) {
1429 Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1430 }
1431
1432 return Change;
1433}
1434
1435} // anonymous namespace
1436
1437namespace {
1438
1439class X86LowerAMXTypeLegacyPass : public FunctionPass {
1440public:
1441 static char ID;
1442
1443 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1444
1445 bool runOnFunction(Function &F) override {
1446 // Performance optimization: most code doesn't use AMX, so return early if
1447 // there are no instructions that produce AMX values. This is sufficient, as
1448 // AMX arguments and constants are not allowed -- so any producer of an AMX
1449 // value must be an instruction.
1450 // TODO: find a cheaper way for this, without looking at all instructions.
1451 if (!containsAMXCode(F))
1452 return false;
1453
1454 bool C = false;
1455 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1456 TargetLibraryInfo *TLI =
1457 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1458
1459 ShapeCalculator SC(TM);
1460 X86LowerAMXCast LAC(F, &SC);
1461 C |= LAC.combineAMXcast(TLI);
1462 // There might be remaining AMXcast after combineAMXcast and they should be
1463 // handled elegantly.
1464 C |= LAC.transformAllAMXCast();
1465
1466 X86LowerAMXType LAT(F, &SC);
1467 C |= LAT.visit();
1468
1469 // Prepare for fast register allocation at O0.
1470 // TODO: It may be better to check the volatile model of AMX code, not just
1471 // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1472 if (TM->getOptLevel() == CodeGenOptLevel::None) {
1473 // If the front end does not use O0 but the mid/back end does (e.g.
1474 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1475 // sure the AMX data is volatile; that is necessary for AMX fast
1476 // register allocation.
1477 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1478 X86VolatileTileData VTD(F);
1479 C = VTD.volatileTileData() || C;
1480 }
1481 }
1482
1483 return C;
1484 }
1485
1486 void getAnalysisUsage(AnalysisUsage &AU) const override {
1487 AU.setPreservesCFG();
1488 AU.addRequired<TargetPassConfig>();
1489 AU.addRequired<TargetLibraryInfoWrapperPass>();
1490 }
1491};
1492
1493} // anonymous namespace
1494
1495static const char PassName[] = "Lower AMX type for load/store";
1496char X86LowerAMXTypeLegacyPass::ID = 0;
1497INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1498 false)
1499INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1500INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1501INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1502 false)
1503
1504FunctionPass *llvm::createX86LowerAMXTypePass() {
1505 return new X86LowerAMXTypeLegacyPass();
1506}