ScalarizeMaskedMemIntrin.cpp
1//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2// intrinsics
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass replaces masked memory intrinsics - when unsupported by the target
11// - with a chain of basic blocks that handles the elements one by one if the
12// appropriate mask bit is set.
13//
14//===----------------------------------------------------------------------===//
15
16#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17#include "llvm/ADT/Twine.h"
18#include "llvm/Analysis/DomTreeUpdater.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/Analysis/VectorUtils.h"
21#include "llvm/IR/BasicBlock.h"
22#include "llvm/IR/Constant.h"
23#include "llvm/IR/Constants.h"
24#include "llvm/IR/DerivedTypes.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instruction.h"
29#include "llvm/IR/Instructions.h"
30#include "llvm/IR/IntrinsicInst.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
33#include "llvm/InitializePasses.h"
34#include "llvm/Pass.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Transforms/Utils/BasicBlockUtils.h"
37#include "llvm/Transforms/Utils/Local.h"
38#include <cassert>
39#include <optional>
40
41using namespace llvm;
42
43#define DEBUG_TYPE "scalarize-masked-mem-intrin"
44
45namespace {
46
47class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
48public:
49 static char ID; // Pass identification, replacement for typeid
50
51 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
52 initializeScalarizeMaskedMemIntrinLegacyPassPass(
53 *PassRegistry::getPassRegistry());
54 }
55
56 bool runOnFunction(Function &F) override;
57
58 StringRef getPassName() const override {
59 return "Scalarize Masked Memory Intrinsics";
60 }
61
62 void getAnalysisUsage(AnalysisUsage &AU) const override {
63 AU.addRequired<TargetTransformInfoWrapperPass>();
64 AU.addPreserved<DominatorTreeWrapperPass>();
65 }
66};
67
68} // end anonymous namespace
69
70static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
71 const TargetTransformInfo &TTI, const DataLayout &DL,
72 bool HasBranchDivergence, DomTreeUpdater *DTU);
73static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
74 const TargetTransformInfo &TTI,
75 const DataLayout &DL, bool HasBranchDivergence,
76 DomTreeUpdater *DTU);
77
78char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79
80INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
81 "Scalarize unsupported masked memory intrinsics", false,
82 false)
83INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
84INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
85INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
86 "Scalarize unsupported masked memory intrinsics", false,
87 false)
88
89FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
90 return new ScalarizeMaskedMemIntrinLegacyPass();
91}
92
93static bool isConstantIntVector(Value *Mask) {
94 Constant *C = dyn_cast<Constant>(Mask);
95 if (!C)
96 return false;
97
98 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
99 for (unsigned i = 0; i != NumElts; ++i) {
100 Constant *CElt = C->getAggregateElement(i);
101 if (!CElt || !isa<ConstantInt>(CElt))
102 return false;
103 }
104
105 return true;
106}
107
108static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
109 unsigned Idx) {
110 return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
111}
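// For illustration (not part of the pass logic): after the <N x i1> mask is
// bitcast to an iN integer, the lane-to-bit mapping depends on endianness. With
// VectorWidth = 16, element 0 tests bit 0 (mask 0x0001) on a little-endian
// target but bit 15 (mask 0x8000) on a big-endian one:
//   adjustForEndian(DL, 16, 0) == 0 (LE)  or 15 (BE)
//   adjustForEndian(DL, 16, 5) == 5 (LE)  or 10 (BE)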
112
113// Translate a masked load intrinsic like
114// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
115// <16 x i1> %mask, <16 x i32> %passthru)
116// to a chain of basic blocks, loading the elements one by one if
117// the appropriate mask bit is set
118//
119// %1 = bitcast i8* %addr to i32*
120// %2 = extractelement <16 x i1> %mask, i32 0
121// br i1 %2, label %cond.load, label %else
122//
123// cond.load: ; preds = %0
124// %3 = getelementptr i32* %1, i32 0
125// %4 = load i32* %3
126// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
127// br label %else
128//
129// else: ; preds = %0, %cond.load
130// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
131// %6 = extractelement <16 x i1> %mask, i32 1
132// br i1 %6, label %cond.load1, label %else2
133//
134// cond.load1: ; preds = %else
135// %7 = getelementptr i32* %1, i32 1
136// %8 = load i32* %7
137// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
138// br label %else2
139//
140// else2: ; preds = %else, %cond.load1
141// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
142// %10 = extractelement <16 x i1> %mask, i32 2
143// br i1 %10, label %cond.load4, label %else5
144//
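// When the mask is neither a known constant nor a splat and the target has no
// branch divergence, the code below tests bits of the mask bitcast to an
// integer rather than using extractelement; the guard for lane 0 then looks
// roughly like
//
//  %scalar_mask = bitcast <16 x i1> %mask to i16
//  %mask_0 = and i16 %scalar_mask, 1
//  %cond = icmp ne i16 %mask_0, 0
//  br i1 %cond, label %cond.load, label %else
//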
145static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
146 CallInst *CI, DomTreeUpdater *DTU,
147 bool &ModifiedDT) {
148 Value *Ptr = CI->getArgOperand(0);
149 Value *Alignment = CI->getArgOperand(1);
150 Value *Mask = CI->getArgOperand(2);
151 Value *Src0 = CI->getArgOperand(3);
152
153 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
154 VectorType *VecType = cast<FixedVectorType>(CI->getType());
155
156 Type *EltTy = VecType->getElementType();
157
158 IRBuilder<> Builder(CI->getContext());
159 Instruction *InsertPt = CI;
160 BasicBlock *IfBlock = CI->getParent();
161
162 Builder.SetInsertPoint(InsertPt);
163 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
164
165 // Short-cut if the mask is all-true.
166 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
167 LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
168 NewI->copyMetadata(*CI);
169 NewI->takeName(CI);
170 CI->replaceAllUsesWith(NewI);
171 CI->eraseFromParent();
172 return;
173 }
174
175 // Adjust alignment for the scalar instruction.
176 const Align AdjustedAlignVal =
177 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
178 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
179
180 // The result vector
181 Value *VResult = Src0;
182
183 if (isConstantIntVector(Mask)) {
184 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
185 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
186 continue;
187 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
188 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
189 VResult = Builder.CreateInsertElement(VResult, Load, Idx);
190 }
191 CI->replaceAllUsesWith(VResult);
192 CI->eraseFromParent();
193 return;
194 }
195
196 // Optimize the case where the "masked load" is a predicated load - that is,
197 // where the mask is the splat of a non-constant scalar boolean. In that case,
198// use that splatted value as the guard on a conditional vector load.
199 if (isSplatValue(Mask, /*Index=*/0)) {
200 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
201 Mask->getName() + ".first");
202 Instruction *ThenTerm =
203 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
204 /*BranchWeights=*/nullptr, DTU);
205
206 BasicBlock *CondBlock = ThenTerm->getParent();
207 CondBlock->setName("cond.load");
208 Builder.SetInsertPoint(CondBlock->getTerminator());
209 LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
210 CI->getName() + ".cond.load");
211 Load->copyMetadata(*CI);
212
213 BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
214 Builder.SetInsertPoint(PostLoad, PostLoad->begin());
215 PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
216 Phi->addIncoming(Load, CondBlock);
217 Phi->addIncoming(Src0, IfBlock);
218 Phi->takeName(CI);
219
220 CI->replaceAllUsesWith(Phi);
221 CI->eraseFromParent();
222 ModifiedDT = true;
223 return;
224 }
225 // If the mask is not v1i1, use scalar bit test operations. This generates
226 // better results on X86 at least. However, don't do this on GPUs and other
227 // machines with divergence, as there each i1 needs a vector register.
228 Value *SclrMask = nullptr;
229 if (VectorWidth != 1 && !HasBranchDivergence) {
230 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
231 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
232 }
233
234 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
235 // Fill the "else" block, created in the previous iteration
236 //
237 //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
238 //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
239 //  %cond = icmp ne i16 %mask_1, 0;  br i1 %cond, label %cond.load, label %else
240 //
241 // On GPUs, use
242 // %cond = extractelement %mask, Idx
243 // instead
244 Value *Predicate;
245 if (SclrMask != nullptr) {
246 Value *Mask = Builder.getInt(APInt::getOneBitSet(
247 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
248 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
249 Builder.getIntN(VectorWidth, 0));
250 } else {
251 Predicate = Builder.CreateExtractElement(Mask, Idx);
252 }
253
254 // Create "cond" block
255 //
256 // %EltAddr = getelementptr i32* %1, i32 0
257 // %Elt = load i32* %EltAddr
258 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
259 //
260 Instruction *ThenTerm =
261 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
262 /*BranchWeights=*/nullptr, DTU);
263
264 BasicBlock *CondBlock = ThenTerm->getParent();
265 CondBlock->setName("cond.load");
266
267 Builder.SetInsertPoint(CondBlock->getTerminator());
268 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
269 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
270 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
271
272 // Create "else" block, fill it in the next iteration
273 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
274 NewIfBlock->setName("else");
275 BasicBlock *PrevIfBlock = IfBlock;
276 IfBlock = NewIfBlock;
277
278 // Create the phi to join the new and previous value.
279 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
280 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
281 Phi->addIncoming(NewVResult, CondBlock);
282 Phi->addIncoming(VResult, PrevIfBlock);
283 VResult = Phi;
284 }
285
286 CI->replaceAllUsesWith(VResult);
287 CI->eraseFromParent();
288
289 ModifiedDT = true;
290}
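// Illustrative sketch (names invented, not taken from a test): for a constant
// mask such as
//   %r = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4,
//            <4 x i1> <i1 1, i1 0, i1 1, i1 0>, <4 x i32> %passthru)
// the constant-mask fast path above emits code only for the two active lanes,
// roughly:
//   %g0 = getelementptr inbounds i32, ptr %p, i32 0
//   %l0 = load i32, ptr %g0, align 4
//   %v0 = insertelement <4 x i32> %passthru, i32 %l0, i64 0
//   %g2 = getelementptr inbounds i32, ptr %p, i32 2
//   %l2 = load i32, ptr %g2, align 4
//   %r  = insertelement <4 x i32> %v0, i32 %l2, i64 2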
291
292// Translate a masked store intrinsic, like
293// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
294// <16 x i1> %mask)
295// to a chain of basic blocks that stores the elements one by one if
296// the appropriate mask bit is set
297//
298// %1 = bitcast i8* %addr to i32*
299// %2 = extractelement <16 x i1> %mask, i32 0
300// br i1 %2, label %cond.store, label %else
301//
302// cond.store: ; preds = %0
303// %3 = extractelement <16 x i32> %val, i32 0
304// %4 = getelementptr i32* %1, i32 0
305// store i32 %3, i32* %4
306// br label %else
307//
308// else: ; preds = %0, %cond.store
309// %5 = extractelement <16 x i1> %mask, i32 1
310// br i1 %5, label %cond.store1, label %else2
311//
312// cond.store1: ; preds = %else
313// %6 = extractelement <16 x i32> %val, i32 1
314// %7 = getelementptr i32* %1, i32 1
315// store i32 %6, i32* %7
316// br label %else2
317// . . .
318static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
319 CallInst *CI, DomTreeUpdater *DTU,
320 bool &ModifiedDT) {
321 Value *Src = CI->getArgOperand(0);
322 Value *Ptr = CI->getArgOperand(1);
323 Value *Alignment = CI->getArgOperand(2);
324 Value *Mask = CI->getArgOperand(3);
325
326 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
327 auto *VecType = cast<VectorType>(Src->getType());
328
329 Type *EltTy = VecType->getElementType();
330
331 IRBuilder<> Builder(CI->getContext());
332 Instruction *InsertPt = CI;
333 Builder.SetInsertPoint(InsertPt);
334 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
335
336 // Short-cut if the mask is all-true.
337 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
338 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
339 Store->takeName(CI);
340 Store->copyMetadata(*CI);
341 CI->eraseFromParent();
342 return;
343 }
344
345 // Adjust alignment for the scalar instruction.
346 const Align AdjustedAlignVal =
347 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
348 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
349
350 if (isConstantIntVector(Mask)) {
351 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
352 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
353 continue;
354 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
355 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
356 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
357 }
358 CI->eraseFromParent();
359 return;
360 }
361
362 // Optimize the case where the "masked store" is a predicated store - that is,
363 // when the mask is the splat of a non-constant scalar boolean. In that case,
364 // optimize to a conditional store.
365 if (isSplatValue(Mask, /*Index=*/0)) {
366 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
367 Mask->getName() + ".first");
368 Instruction *ThenTerm =
369 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
370 /*BranchWeights=*/nullptr, DTU);
371 BasicBlock *CondBlock = ThenTerm->getParent();
372 CondBlock->setName("cond.store");
373 Builder.SetInsertPoint(CondBlock->getTerminator());
374
375 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
376 Store->takeName(CI);
377 Store->copyMetadata(*CI);
378
379 CI->eraseFromParent();
380 ModifiedDT = true;
381 return;
382 }
383
384 // If the mask is not v1i1, use scalar bit test operations. This generates
385 // better results on X86 at least. However, don't do this on GPUs or other
386 // machines with branch divergence, as there each i1 takes up a register.
387 Value *SclrMask = nullptr;
388 if (VectorWidth != 1 && !HasBranchDivergence) {
389 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
390 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
391 }
392
393 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
394 // Fill the "else" block, created in the previous iteration
395 //
396 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
397 // %cond = icmp ne i16 %mask_1, 0
398 // br i1 %mask_1, label %cond.store, label %else
399 //
400 // On GPUs, use
401 // %cond = extractelement %mask, Idx
402 // instead
403 Value *Predicate;
404 if (SclrMask != nullptr) {
405 Value *Mask = Builder.getInt(APInt::getOneBitSet(
406 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
407 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
408 Builder.getIntN(VectorWidth, 0));
409 } else {
410 Predicate = Builder.CreateExtractElement(Mask, Idx);
411 }
412
413 // Create "cond" block
414 //
415 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
416 // %EltAddr = getelementptr i32* %1, i32 0
417 // %store i32 %OneElt, i32* %EltAddr
418 //
419 Instruction *ThenTerm =
420 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
421 /*BranchWeights=*/nullptr, DTU);
422
423 BasicBlock *CondBlock = ThenTerm->getParent();
424 CondBlock->setName("cond.store");
425
426 Builder.SetInsertPoint(CondBlock->getTerminator());
427 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
428 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
429 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
430
431 // Create "else" block, fill it in the next iteration
432 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
433 NewIfBlock->setName("else");
434
435 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
436 }
437 CI->eraseFromParent();
438
439 ModifiedDT = true;
440}
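// Illustrative sketch: when the mask is a splat of a single scalar i1 (the
// predicated-store case above), the vector store is kept whole and merely
// guarded by one branch, roughly:
//   %mask.first = extractelement <16 x i1> %mask, i64 0
//   br i1 %mask.first, label %cond.store, label %post
// cond.store:
//   store <16 x i32> %val, ptr %addr, align 4
//   br label %post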
441
442// Translate a masked gather intrinsic like
443// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
444// <16 x i1> %Mask, <16 x i32> %Src)
445// to a chain of basic blocks, loading the elements one by one if
446// the appropriate mask bit is set
447//
448// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
449// %Mask0 = extractelement <16 x i1> %Mask, i32 0
450// br i1 %Mask0, label %cond.load, label %else
451//
452// cond.load:
453// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
454// %Load0 = load i32, i32* %Ptr0, align 4
455// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
456// br label %else
457//
458// else:
459// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
460// %Mask1 = extractelement <16 x i1> %Mask, i32 1
461// br i1 %Mask1, label %cond.load1, label %else2
462//
463// cond.load1:
464// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
465// %Load1 = load i32, i32* %Ptr1, align 4
466// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
467// br label %else2
468// . . .
469// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
470// ret <16 x i32> %Result
471static void scalarizeMaskedGather(const DataLayout &DL,
472 bool HasBranchDivergence, CallInst *CI,
473 DomTreeUpdater *DTU, bool &ModifiedDT) {
474 Value *Ptrs = CI->getArgOperand(0);
475 Value *Alignment = CI->getArgOperand(1);
476 Value *Mask = CI->getArgOperand(2);
477 Value *Src0 = CI->getArgOperand(3);
478
479 auto *VecType = cast<FixedVectorType>(CI->getType());
480 Type *EltTy = VecType->getElementType();
481
482 IRBuilder<> Builder(CI->getContext());
483 Instruction *InsertPt = CI;
484 BasicBlock *IfBlock = CI->getParent();
485 Builder.SetInsertPoint(InsertPt);
486 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
487
488 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
489
490 // The result vector
491 Value *VResult = Src0;
492 unsigned VectorWidth = VecType->getNumElements();
493
494 // Take the fast path if the mask is a vector of constants.
495 if (isConstantIntVector(Mask)) {
496 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
497 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
498 continue;
499 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
500 LoadInst *Load =
501 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
502 VResult =
503 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
504 }
505 CI->replaceAllUsesWith(VResult);
506 CI->eraseFromParent();
507 return;
508 }
509
510 // If the mask is not v1i1, use scalar bit test operations. This generates
511 // better results on X86 at least. However, don't do this on GPUs or other
512 // machines with branch divergence, as there, each i1 takes up a register.
513 Value *SclrMask = nullptr;
514 if (VectorWidth != 1 && !HasBranchDivergence) {
515 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
516 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
517 }
518
519 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
520 // Fill the "else" block, created in the previous iteration
521 //
522 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
523 // %cond = icmp ne i16 %mask_1, 0
524 // br i1 %Mask1, label %cond.load, label %else
525 //
526 // On GPUs, use
527 // %cond = extractelement %mask, Idx
528 // instead
529
530 Value *Predicate;
531 if (SclrMask != nullptr) {
532 Value *Mask = Builder.getInt(APInt::getOneBitSet(
533 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
534 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
535 Builder.getIntN(VectorWidth, 0));
536 } else {
537 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
538 }
539
540 // Create "cond" block
541 //
542 // %EltAddr = getelementptr i32* %1, i32 0
543 // %Elt = load i32* %EltAddr
544 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
545 //
546 Instruction *ThenTerm =
547 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
548 /*BranchWeights=*/nullptr, DTU);
549
550 BasicBlock *CondBlock = ThenTerm->getParent();
551 CondBlock->setName("cond.load");
552
553 Builder.SetInsertPoint(CondBlock->getTerminator());
554 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
555 LoadInst *Load =
556 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
557 Value *NewVResult =
558 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
559
560 // Create "else" block, fill it in the next iteration
561 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
562 NewIfBlock->setName("else");
563 BasicBlock *PrevIfBlock = IfBlock;
564 IfBlock = NewIfBlock;
565
566 // Create the phi to join the new and previous value.
567 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
568 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
569 Phi->addIncoming(NewVResult, CondBlock);
570 Phi->addIncoming(VResult, PrevIfBlock);
571 VResult = Phi;
572 }
573
574 CI->replaceAllUsesWith(VResult);
575 CI->eraseFromParent();
576
577 ModifiedDT = true;
578}
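// Note: unlike the final select shown in the comment above, the expansion here
// needs no select against %Src; the pass-through vector seeds VResult, so
// inactive lanes already carry %Src through the "res.phi.else" PHI chain.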
579
580// Translate a masked scatter intrinsic, like
581// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
582// <16 x i1> %Mask)
583// to a chain of basic blocks that stores the elements one by one if
584// the appropriate mask bit is set.
585//
586// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
587// %Mask0 = extractelement <16 x i1> %Mask, i32 0
588// br i1 %Mask0, label %cond.store, label %else
589//
590// cond.store:
591// %Elt0 = extractelement <16 x i32> %Src, i32 0
592// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
593// store i32 %Elt0, i32* %Ptr0, align 4
594// br label %else
595//
596// else:
597// %Mask1 = extractelement <16 x i1> %Mask, i32 1
598// br i1 %Mask1, label %cond.store1, label %else2
599//
600// cond.store1:
601// %Elt1 = extractelement <16 x i32> %Src, i32 1
602// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
603// store i32 %Elt1, i32* %Ptr1, align 4
604// br label %else2
605// . . .
606static void scalarizeMaskedScatter(const DataLayout &DL,
607 bool HasBranchDivergence, CallInst *CI,
608 DomTreeUpdater *DTU, bool &ModifiedDT) {
609 Value *Src = CI->getArgOperand(0);
610 Value *Ptrs = CI->getArgOperand(1);
611 Value *Alignment = CI->getArgOperand(2);
612 Value *Mask = CI->getArgOperand(3);
613
614 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
615
616 assert(
617 isa<VectorType>(Ptrs->getType()) &&
618 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
619 "Vector of pointers is expected in masked scatter intrinsic");
620
621 IRBuilder<> Builder(CI->getContext());
622 Instruction *InsertPt = CI;
623 Builder.SetInsertPoint(InsertPt);
624 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
625
626 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
627 unsigned VectorWidth = SrcFVTy->getNumElements();
628
629 // Take the fast path if the mask is a vector of constants.
630 if (isConstantIntVector(Mask)) {
631 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
632 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
633 continue;
634 Value *OneElt =
635 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
636 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
637 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
638 }
639 CI->eraseFromParent();
640 return;
641 }
642
643 // If the mask is not v1i1, use scalar bit test operations. This generates
644 // better results on X86 at least.
645 Value *SclrMask = nullptr;
646 if (VectorWidth != 1 && !HasBranchDivergence) {
647 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
648 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
649 }
650
651 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
652 // Fill the "else" block, created in the previous iteration
653 //
654 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
655 // %cond = icmp ne i16 %mask_1, 0
656 // br i1 %Mask1, label %cond.store, label %else
657 //
658 // On GPUs, use
659 // %cond = extractelement %mask, Idx
660 // instead
661 Value *Predicate;
662 if (SclrMask != nullptr) {
663 Value *Mask = Builder.getInt(APInt::getOneBitSet(
664 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
665 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
666 Builder.getIntN(VectorWidth, 0));
667 } else {
668 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
669 }
670
671 // Create "cond" block
672 //
673 // %Elt1 = extractelement <16 x i32> %Src, i32 1
674 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
675 // %store i32 %Elt1, i32* %Ptr1
676 //
677 Instruction *ThenTerm =
678 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
679 /*BranchWeights=*/nullptr, DTU);
680
681 BasicBlock *CondBlock = ThenTerm->getParent();
682 CondBlock->setName("cond.store");
683
684 Builder.SetInsertPoint(CondBlock->getTerminator());
685 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
686 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
687 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
688
689 // Create "else" block, fill it in the next iteration
690 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
691 NewIfBlock->setName("else");
692
693 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
694 }
695 CI->eraseFromParent();
696
697 ModifiedDT = true;
698}
699
700static void scalarizeMaskedExpandLoad(const DataLayout &DL,
701 bool HasBranchDivergence, CallInst *CI,
702 DomTreeUpdater *DTU, bool &ModifiedDT) {
703 Value *Ptr = CI->getArgOperand(0);
704 Value *Mask = CI->getArgOperand(1);
705 Value *PassThru = CI->getArgOperand(2);
706 Align Alignment = CI->getParamAlign(0).valueOrOne();
707
708 auto *VecType = cast<FixedVectorType>(CI->getType());
709
710 Type *EltTy = VecType->getElementType();
711
712 IRBuilder<> Builder(CI->getContext());
713 Instruction *InsertPt = CI;
714 BasicBlock *IfBlock = CI->getParent();
715
716 Builder.SetInsertPoint(InsertPt);
717 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
718
719 unsigned VectorWidth = VecType->getNumElements();
720
721 // The result vector
722 Value *VResult = PassThru;
723
724 // Adjust alignment for the scalar instruction.
725 const Align AdjustedAlignment =
726 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
727
728 // Take the fast path if the mask is a vector of constants.
729 // Create a build_vector pattern, with loads/poisons as necessary, and then
730 // shuffle-blend it with the pass-through value.
731 if (isConstantIntVector(Mask)) {
732 unsigned MemIndex = 0;
733 VResult = PoisonValue::get(VecType);
734 SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
735 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
736 Value *InsertElt;
737 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
738 InsertElt = PoisonValue::get(EltTy);
739 ShuffleMask[Idx] = Idx + VectorWidth;
740 } else {
741 Value *NewPtr =
742 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
743 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
744 "Load" + Twine(Idx));
745 ShuffleMask[Idx] = Idx;
746 ++MemIndex;
747 }
748 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
749 "Res" + Twine(Idx));
750 }
751 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
752 CI->replaceAllUsesWith(VResult);
753 CI->eraseFromParent();
754 return;
755 }
756
757 // If the mask is not v1i1, use scalar bit test operations. This generates
758 // better results on X86 at least. However, don't do this on GPUs or other
759 // machines with branch divergence, as there, each i1 takes up a register.
760 Value *SclrMask = nullptr;
761 if (VectorWidth != 1 && !HasBranchDivergence) {
762 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
763 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
764 }
765
766 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
767 // Fill the "else" block, created in the previous iteration
768 //
769 //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
770 //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
771 //  br i1 %mask_1, label %cond.load, label %else
772 //
773 // On GPUs, use
774 // %cond = extractelement %mask, Idx
775 // instead
776
777 Value *Predicate;
778 if (SclrMask != nullptr) {
779 Value *Mask = Builder.getInt(APInt::getOneBitSet(
780 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
781 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
782 Builder.getIntN(VectorWidth, 0));
783 } else {
784 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
785 }
786
787 // Create "cond" block
788 //
789 // %EltAddr = getelementptr i32* %1, i32 0
790 // %Elt = load i32* %EltAddr
791 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
792 //
793 Instruction *ThenTerm =
794 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
795 /*BranchWeights=*/nullptr, DTU);
796
797 BasicBlock *CondBlock = ThenTerm->getParent();
798 CondBlock->setName("cond.load");
799
800 Builder.SetInsertPoint(CondBlock->getTerminator());
801 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
802 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
803
804 // Move the pointer if there are more blocks to come.
805 Value *NewPtr;
806 if ((Idx + 1) != VectorWidth)
807 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
808
809 // Create "else" block, fill it in the next iteration
810 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
811 NewIfBlock->setName("else");
812 BasicBlock *PrevIfBlock = IfBlock;
813 IfBlock = NewIfBlock;
814
815 // Create the phi to join the new and previous value.
816 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
817 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
818 ResultPhi->addIncoming(NewVResult, CondBlock);
819 ResultPhi->addIncoming(VResult, PrevIfBlock);
820 VResult = ResultPhi;
821
822 // Add a PHI for the pointer if this isn't the last iteration.
823 if ((Idx + 1) != VectorWidth) {
824 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
825 PtrPhi->addIncoming(NewPtr, CondBlock);
826 PtrPhi->addIncoming(Ptr, PrevIfBlock);
827 Ptr = PtrPhi;
828 }
829 }
830
831 CI->replaceAllUsesWith(VResult);
832 CI->eraseFromParent();
833
834 ModifiedDT = true;
835}
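// Note: expandload consumes memory only for active lanes, which is why the
// element pointer is advanced by one ("ptr.phi.else") each time a lane is
// actually loaded, rather than being indexed by the lane number as in
// scalarizeMaskedLoad above.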
836
837static void scalarizeMaskedCompressStore(const DataLayout &DL,
838 bool HasBranchDivergence, CallInst *CI,
839 DomTreeUpdater *DTU,
840 bool &ModifiedDT) {
841 Value *Src = CI->getArgOperand(0);
842 Value *Ptr = CI->getArgOperand(1);
843 Value *Mask = CI->getArgOperand(2);
844 Align Alignment = CI->getParamAlign(1).valueOrOne();
845
846 auto *VecType = cast<FixedVectorType>(Src->getType());
847
848 IRBuilder<> Builder(CI->getContext());
849 Instruction *InsertPt = CI;
850 BasicBlock *IfBlock = CI->getParent();
851
852 Builder.SetInsertPoint(InsertPt);
853 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
854
855 Type *EltTy = VecType->getElementType();
856
857 // Adjust alignment for the scalar instruction.
858 const Align AdjustedAlignment =
859 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
860
861 unsigned VectorWidth = VecType->getNumElements();
862
863 // Take the fast path if the mask is a vector of constants.
864 if (isConstantIntVector(Mask)) {
865 unsigned MemIndex = 0;
866 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
867 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
868 continue;
869 Value *OneElt =
870 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
871 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
872 Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
873 ++MemIndex;
874 }
875 CI->eraseFromParent();
876 return;
877 }
878
879 // If the mask is not v1i1, use scalar bit test operations. This generates
880 // better results on X86 at least. However, don't do this on GPUs or other
881 // machines with branch divergence, as there, each i1 takes up a register.
882 Value *SclrMask = nullptr;
883 if (VectorWidth != 1 && !HasBranchDivergence) {
884 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
885 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
886 }
887
888 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
889 // Fill the "else" block, created in the previous iteration
890 //
891 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
892 // br i1 %mask_1, label %cond.store, label %else
893 //
894 // On GPUs, use
895 // %cond = extractelement %mask, Idx
896 // instead
897 Value *Predicate;
898 if (SclrMask != nullptr) {
899 Value *Mask = Builder.getInt(APInt::getOneBitSet(
900 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
901 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
902 Builder.getIntN(VectorWidth, 0));
903 } else {
904 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
905 }
906
907 // Create "cond" block
908 //
909 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
910 // %EltAddr = getelementptr i32* %1, i32 0
911 // %store i32 %OneElt, i32* %EltAddr
912 //
913 Instruction *ThenTerm =
914 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
915 /*BranchWeights=*/nullptr, DTU);
916
917 BasicBlock *CondBlock = ThenTerm->getParent();
918 CondBlock->setName("cond.store");
919
920 Builder.SetInsertPoint(CondBlock->getTerminator());
921 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
922 Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);
923
924 // Move the pointer if there are more blocks to come.
925 Value *NewPtr;
926 if ((Idx + 1) != VectorWidth)
927 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
928
929 // Create "else" block, fill it in the next iteration
930 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
931 NewIfBlock->setName("else");
932 BasicBlock *PrevIfBlock = IfBlock;
933 IfBlock = NewIfBlock;
934
935 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
936
937 // Add a PHI for the pointer if this isn't the last iteration.
938 if ((Idx + 1) != VectorWidth) {
939 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
940 PtrPhi->addIncoming(NewPtr, CondBlock);
941 PtrPhi->addIncoming(Ptr, PrevIfBlock);
942 Ptr = PtrPhi;
943 }
944 }
945 CI->eraseFromParent();
946
947 ModifiedDT = true;
948}
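// Note: compressstore is the store-side counterpart of expandload: active lanes
// are written to consecutive memory slots, so the destination pointer is bumped
// by one element per executed store using the same "ptr.phi.else" pattern.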
949
950static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
951 DomTreeUpdater *DTU,
952 bool &ModifiedDT) {
953 // If we extend histogram to return a result someday (like the updated vector)
954 // then we'll need to support it here.
955 assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
956 Value *Ptrs = CI->getArgOperand(0);
957 Value *Inc = CI->getArgOperand(1);
958 Value *Mask = CI->getArgOperand(2);
959
960 auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
961 Type *EltTy = Inc->getType();
962
963 IRBuilder<> Builder(CI->getContext());
964 Instruction *InsertPt = CI;
965 Builder.SetInsertPoint(InsertPt);
966
967 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
968
969 // FIXME: Do we need to add an alignment parameter to the intrinsic?
970 unsigned VectorWidth = AddrType->getNumElements();
971 auto CreateHistogramUpdateValue = [&](IntrinsicInst *CI, Value *Load,
972 Value *Inc) -> Value * {
973 Value *UpdateOp;
974 switch (CI->getIntrinsicID()) {
975 case Intrinsic::experimental_vector_histogram_add:
976 UpdateOp = Builder.CreateAdd(Load, Inc);
977 break;
978 case Intrinsic::experimental_vector_histogram_uadd_sat:
979 UpdateOp =
980 Builder.CreateIntrinsic(Intrinsic::uadd_sat, {EltTy}, {Load, Inc});
981 break;
982 case Intrinsic::experimental_vector_histogram_umin:
983 UpdateOp = Builder.CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
984 break;
985 case Intrinsic::experimental_vector_histogram_umax:
986 UpdateOp = Builder.CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
987 break;
988
989 default:
990 llvm_unreachable("Unexpected histogram intrinsic");
991 }
992 return UpdateOp;
993 };
994
995 // Take the fast path if the mask is a vector of constants.
996 if (isConstantIntVector(Mask)) {
997 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
998 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
999 continue;
1000 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
1001 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
1002 Value *Update =
1003 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1004 Builder.CreateStore(Update, Ptr);
1005 }
1006 CI->eraseFromParent();
1007 return;
1008 }
1009
1010 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
1011 Value *Predicate =
1012 Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
1013
1014 Instruction *ThenTerm =
1015 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
1016 /*BranchWeights=*/nullptr, DTU);
1017
1018 BasicBlock *CondBlock = ThenTerm->getParent();
1019 CondBlock->setName("cond.histogram.update");
1020
1021 Builder.SetInsertPoint(CondBlock->getTerminator());
1022 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
1023 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
1024 Value *UpdateOp =
1025 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1026 Builder.CreateStore(UpdateOp, Ptr);
1027
1028 // Create "else" block, fill it in the next iteration
1029 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
1030 NewIfBlock->setName("else");
1031 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
1032 }
1033
1034 CI->eraseFromParent();
1035 ModifiedDT = true;
1036}
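// Illustrative sketch (type-mangling suffix of the intrinsic name omitted): for
//   call void @llvm.experimental.vector.histogram.add(<2 x ptr> %ptrs, i32 1,
//                                                     <2 x i1> splat (i1 true))
// each active lane becomes a load/update/store triple executed in lane order,
// roughly:
//   %p0 = extractelement <2 x ptr> %ptrs, i64 0
//   %l0 = load i32, ptr %p0
//   %u0 = add i32 %l0, 1
//   store i32 %u0, ptr %p0
//   ...same for lane 1...
// Because lanes are processed sequentially, repeated pointers within one vector
// each observe the previous lane's update.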
1037
1038static bool runImpl(Function &F, const TargetTransformInfo &TTI,
1039 DominatorTree *DT) {
1040 std::optional<DomTreeUpdater> DTU;
1041 if (DT)
1042 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1043
1044 bool EverMadeChange = false;
1045 bool MadeChange = true;
1046 auto &DL = F.getDataLayout();
1047 bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
1048 while (MadeChange) {
1049 MadeChange = false;
1050 for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
1051 bool ModifiedDTOnIteration = false;
1052 MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
1053 HasBranchDivergence, DTU ? &*DTU : nullptr);
1054
1055 // Restart BB iteration if the dominator tree of the Function was changed
1056 if (ModifiedDTOnIteration)
1057 break;
1058 }
1059
1060 EverMadeChange |= MadeChange;
1061 }
1062 return EverMadeChange;
1063}
1064
1065bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
1066 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1067 DominatorTree *DT = nullptr;
1068 if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1069 DT = &DTWP->getDomTree();
1070 return runImpl(F, TTI, DT);
1071}
1072
1073PreservedAnalyses
1074ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
1075 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1076 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
1077 if (!runImpl(F, TTI, DT))
1078 return PreservedAnalyses::all();
1079
1080 PreservedAnalyses PA;
1081 PA.preserve<DominatorTreeAnalysis>();
1082 return PA;
1083}
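// Minimal new-pass-manager usage sketch (FPM is an assumed FunctionPassManager,
// not defined in this file):
//   FunctionPassManager FPM;
//   FPM.addPass(ScalarizeMaskedMemIntrinPass());
// Legacy pass-manager pipelines can still create the wrapper via
// createScalarizeMaskedMemIntrinLegacyPass() above.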
1084
1085static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
1086 const TargetTransformInfo &TTI, const DataLayout &DL,
1087 bool HasBranchDivergence, DomTreeUpdater *DTU) {
1088 bool MadeChange = false;
1089
1090 BasicBlock::iterator CurInstIterator = BB.begin();
1091 while (CurInstIterator != BB.end()) {
1092 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1093 MadeChange |=
1094 optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
1095 if (ModifiedDT)
1096 return true;
1097 }
1098
1099 return MadeChange;
1100}
1101
1102static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
1103 const TargetTransformInfo &TTI,
1104 const DataLayout &DL, bool HasBranchDivergence,
1105 DomTreeUpdater *DTU) {
1106 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
1107 if (II) {
1108 // The scalarization code below does not work for scalable vectors.
1109 if (isa<ScalableVectorType>(II->getType()) ||
1110 any_of(II->args(),
1111 [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1112 return false;
1113 switch (II->getIntrinsicID()) {
1114 default:
1115 break;
1116 case Intrinsic::experimental_vector_histogram_add:
1117 case Intrinsic::experimental_vector_histogram_uadd_sat:
1118 case Intrinsic::experimental_vector_histogram_umin:
1119 case Intrinsic::experimental_vector_histogram_umax:
1120 if (!TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
1121 CI->getArgOperand(1)->getType()))
1122 return false;
1123 scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
1124 return true;
1125 case Intrinsic::masked_load:
1126 // Scalarize unsupported vector masked load
1127 if (TTI.isLegalMaskedLoad(
1128 CI->getType(),
1129 cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(),
1130 cast<PointerType>(CI->getArgOperand(0)->getType())
1131 ->getAddressSpace()))
1132 return false;
1133 scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1134 return true;
1135 case Intrinsic::masked_store:
1136 if (TTI.isLegalMaskedStore(
1137 CI->getArgOperand(0)->getType(),
1138 cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
1139 cast<PointerType>(CI->getArgOperand(1)->getType())
1140 ->getAddressSpace()))
1141 return false;
1142 scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1143 return true;
1144 case Intrinsic::masked_gather: {
1145 MaybeAlign MA =
1146 cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
1147 Type *LoadTy = CI->getType();
1148 Align Alignment = DL.getValueOrABITypeAlignment(MA,
1149 LoadTy->getScalarType());
1150 if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
1151 !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
1152 return false;
1153 scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1154 return true;
1155 }
1156 case Intrinsic::masked_scatter: {
1157 MaybeAlign MA =
1158 cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
1159 Type *StoreTy = CI->getArgOperand(0)->getType();
1160 Align Alignment = DL.getValueOrABITypeAlignment(MA,
1161 StoreTy->getScalarType());
1162 if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
1163 !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
1164 Alignment))
1165 return false;
1166 scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1167 return true;
1168 }
1169 case Intrinsic::masked_expandload:
1170 if (TTI.isLegalMaskedExpandLoad(
1171 CI->getType(),
1172 CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
1173 return false;
1174 scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1175 return true;
1176 case Intrinsic::masked_compressstore:
1177 if (TTI.isLegalMaskedCompressStore(
1178 CI->getArgOperand(0)->getType(),
1179 CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
1180 return false;
1181 scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
1182 ModifiedDT);
1183 return true;
1184 }
1185 }
1186
1187 return false;
1188}