43#define DEBUG_TYPE "scalarize-masked-mem-intrin"
47class ScalarizeMaskedMemIntrinLegacyPass :
public FunctionPass {
59 return "Scalarize Masked Memory Intrinsics";
78char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
81 "Scalarize unsupported masked memory intrinsics",
false,
90 return new ScalarizeMaskedMemIntrinLegacyPass();
98 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
99 for (
unsigned i = 0; i != NumElts; ++i) {
100 Constant *CElt =
C->getAggregateElement(i);
101 if (!CElt || !isa<ConstantInt>(CElt))
110 return DL.isBigEndian() ? VectorWidth - 1 -
Idx :
Idx;
153 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
156 Type *EltTy = VecType->getElementType();
166 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
176 const Align AdjustedAlignVal =
178 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
181 Value *VResult = Src0;
184 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
185 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
201 Mask->getName() +
".first");
207 CondBlock->
setName(
"cond.load");
211 Load->copyMetadata(*CI);
216 Phi->addIncoming(Load, CondBlock);
217 Phi->addIncoming(Src0, IfBlock);
228 Value *SclrMask =
nullptr;
229 if (VectorWidth != 1 && !HasBranchDivergence) {
231 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
234 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
245 if (SclrMask !=
nullptr) {
249 Builder.
getIntN(VectorWidth, 0));
265 CondBlock->
setName(
"cond.load");
276 IfBlock = NewIfBlock;
281 Phi->addIncoming(NewVResult, CondBlock);
282 Phi->addIncoming(VResult, PrevIfBlock);
326 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
327 auto *VecType = cast<VectorType>(Src->getType());
329 Type *EltTy = VecType->getElementType();
337 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
340 Store->copyMetadata(*CI);
346 const Align AdjustedAlignVal =
348 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
351 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
352 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
367 Mask->getName() +
".first");
372 CondBlock->
setName(
"cond.store");
377 Store->copyMetadata(*CI);
387 Value *SclrMask =
nullptr;
388 if (VectorWidth != 1 && !HasBranchDivergence) {
390 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
393 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
404 if (SclrMask !=
nullptr) {
408 Builder.
getIntN(VectorWidth, 0));
424 CondBlock->
setName(
"cond.store");
472 bool HasBranchDivergence,
CallInst *CI,
479 auto *VecType = cast<FixedVectorType>(CI->
getType());
480 Type *EltTy = VecType->getElementType();
486 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
491 Value *VResult = Src0;
492 unsigned VectorWidth = VecType->getNumElements();
496 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
497 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
513 Value *SclrMask =
nullptr;
514 if (VectorWidth != 1 && !HasBranchDivergence) {
516 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
519 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
531 if (SclrMask !=
nullptr) {
535 Builder.
getIntN(VectorWidth, 0));
551 CondBlock->
setName(
"cond.load");
564 IfBlock = NewIfBlock;
569 Phi->addIncoming(NewVResult, CondBlock);
570 Phi->addIncoming(VResult, PrevIfBlock);
607 bool HasBranchDivergence,
CallInst *CI,
614 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
617 isa<VectorType>(Ptrs->
getType()) &&
618 isa<PointerType>(cast<VectorType>(Ptrs->
getType())->getElementType()) &&
619 "Vector of pointers is expected in masked scatter intrinsic");
626 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
627 unsigned VectorWidth = SrcFVTy->getNumElements();
631 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
632 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
645 Value *SclrMask =
nullptr;
646 if (VectorWidth != 1 && !HasBranchDivergence) {
648 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
651 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
662 if (SclrMask !=
nullptr) {
666 Builder.
getIntN(VectorWidth, 0));
682 CondBlock->
setName(
"cond.store");
701 bool HasBranchDivergence,
CallInst *CI,
708 auto *VecType = cast<FixedVectorType>(CI->
getType());
710 Type *EltTy = VecType->getElementType();
719 unsigned VectorWidth = VecType->getNumElements();
722 Value *VResult = PassThru;
725 const Align AdjustedAlignment =
732 unsigned MemIndex = 0;
735 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
737 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue()) {
739 ShuffleMask[
Idx] =
Idx + VectorWidth;
760 Value *SclrMask =
nullptr;
761 if (VectorWidth != 1 && !HasBranchDivergence) {
763 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
766 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
778 if (SclrMask !=
nullptr) {
782 Builder.
getIntN(VectorWidth, 0));
798 CondBlock->
setName(
"cond.load");
806 if ((
Idx + 1) != VectorWidth)
813 IfBlock = NewIfBlock;
823 if ((
Idx + 1) != VectorWidth) {
838 bool HasBranchDivergence,
CallInst *CI,
846 auto *VecType = cast<FixedVectorType>(Src->getType());
855 Type *EltTy = VecType->getElementType();
858 const Align AdjustedAlignment =
861 unsigned VectorWidth = VecType->getNumElements();
865 unsigned MemIndex = 0;
866 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
867 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
882 Value *SclrMask =
nullptr;
883 if (VectorWidth != 1 && !HasBranchDivergence) {
885 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
888 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
898 if (SclrMask !=
nullptr) {
902 Builder.
getIntN(VectorWidth, 0));
918 CondBlock->
setName(
"cond.store");
926 if ((
Idx + 1) != VectorWidth)
933 IfBlock = NewIfBlock;
938 if ((
Idx + 1) != VectorWidth) {
960 auto *AddrType = cast<FixedVectorType>(Ptrs->
getType());
970 unsigned VectorWidth = AddrType->getNumElements();
975 case Intrinsic::experimental_vector_histogram_add:
978 case Intrinsic::experimental_vector_histogram_uadd_sat:
982 case Intrinsic::experimental_vector_histogram_umin:
983 UpdateOp = Builder.
CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
985 case Intrinsic::experimental_vector_histogram_umax:
986 UpdateOp = Builder.
CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
997 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
998 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
1003 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1010 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
1019 CondBlock->
setName(
"cond.histogram.update");
1025 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1040 std::optional<DomTreeUpdater> DTU;
1042 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1044 bool EverMadeChange =
false;
1045 bool MadeChange =
true;
1046 auto &
DL =
F.getDataLayout();
1048 while (MadeChange) {
1051 bool ModifiedDTOnIteration =
false;
1053 HasBranchDivergence, DTU ? &*DTU :
nullptr);
1056 if (ModifiedDTOnIteration)
1060 EverMadeChange |= MadeChange;
1062 return EverMadeChange;
1065bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(
Function &
F) {
1066 auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
1068 if (
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1069 DT = &DTWP->getDomTree();
1088 bool MadeChange =
false;
1091 while (CurInstIterator != BB.
end()) {
1092 if (
CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1109 if (isa<ScalableVectorType>(
II->getType()) ||
1111 [](
Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1113 switch (
II->getIntrinsicID()) {
1116 case Intrinsic::experimental_vector_histogram_add:
1117 case Intrinsic::experimental_vector_histogram_uadd_sat:
1118 case Intrinsic::experimental_vector_histogram_umin:
1119 case Intrinsic::experimental_vector_histogram_umax:
1125 case Intrinsic::masked_load:
1131 ->getAddressSpace()))
1135 case Intrinsic::masked_store:
1140 ->getAddressSpace()))
1144 case Intrinsic::masked_gather: {
1146 cast<ConstantInt>(CI->
getArgOperand(1))->getMaybeAlignValue();
1148 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
1156 case Intrinsic::masked_scatter: {
1158 cast<ConstantInt>(CI->
getArgOperand(2))->getMaybeAlignValue();
1160 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
1169 case Intrinsic::masked_expandload:
1176 case Intrinsic::masked_compressstore:
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
static bool runImpl(Function &F, const TargetLowering &TLI)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static void scalarizeMaskedExpandLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU)
static void scalarizeMaskedScatter(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, unsigned Idx)
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU)
static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedCompressStore(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedGather(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT)
static bool isConstantIntVector(Value *Mask)
Scalarize unsupported masked memory intrinsics
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM_ABI AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM_ABI MaybeAlign getAlignment() const
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Value * getArgOperand(unsigned i) const
LLVM_ABI Intrinsic::ID getIntrinsicID() const
Returns the intrinsic ID of the intrinsic called or Intrinsic::not_intrinsic if the called function i...
AttributeList getAttributes() const
Return the attributes for this call.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
A parsed version of the target data layout string, and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
IntegerType * getIntNTy(unsigned N)
Fetch the type representing an N-bit integer.
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Value * CreateConstInBoundsGEP1_32(Type *Ty, Value *Ptr, unsigned Idx0, const Twine &Name="")
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
ConstantInt * getInt(const APInt &AI)
Get a constant integer value.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
LLVM_ABI void copyMetadata(const Instruction &SrcInst, ArrayRef< unsigned > WL=ArrayRef< unsigned >())
Copy metadata from SrcInst to this instruction.
A wrapper class for inspecting calls to intrinsic functions.
An instruction for reading from memory.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetTransformInfo.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
const ParentTy * getParent() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
LLVM_ABI FunctionPass * createScalarizeMaskedMemIntrinLegacyPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI bool isSplatValue(const Value *V, int Index=-1, unsigned Depth=0)
Return true if each element of the vector value V is poisoned or equal to every other non-poisoned el...
LLVM_ABI void initializeScalarizeMaskedMemIntrinLegacyPassPass(PassRegistry &)
constexpr int PoisonMaskElem
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
LLVM_ABI Instruction * SplitBlockAndInsertIfThen(Value *Cond, BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights=nullptr, DomTreeUpdater *DTU=nullptr, LoopInfo *LI=nullptr, BasicBlock *ThenBlock=nullptr)
Split the containing block at the specified instruction - everything before SplitBefore stays in the ...
This struct is a compact representation of a valid (non-zero power of two) alignment.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)