80using namespace PatternMatch;
82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
143 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
150template <
typename T,
typename IterT>
151std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
153 if (Common !=
A.end())
154 return std::make_optional(*Common);
158class ComplexDeinterleavingLegacyPass :
public FunctionPass {
162 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
169 return "Complex Deinterleaving Pass";
182class ComplexDeinterleavingGraph;
183struct ComplexDeinterleavingCompositeNode {
188 Vals.push_back({
R,
I});
196 friend class ComplexDeinterleavingGraph;
197 using CompositeNode = ComplexDeinterleavingCompositeNode;
198 bool OperandsValid =
true;
207 std::optional<FastMathFlags>
Flags;
210 ComplexDeinterleavingRotation::Rotation_0;
212 Value *ReplacementNode =
nullptr;
216 OperandsValid =
false;
222 auto PrintValue = [&](
Value *
V) {
230 auto PrintNodeRef = [&](CompositeNode *
Ptr) {
237 OS <<
"- CompositeNode: " <<
this <<
"\n";
238 for (
unsigned I = 0;
I < Vals.
size();
I++) {
239 OS <<
" Real(" <<
I <<
") : ";
240 PrintValue(Vals[
I].Real);
241 OS <<
" Imag(" <<
I <<
") : ";
242 PrintValue(Vals[
I].Imag);
244 OS <<
" ReplacementNode: ";
245 PrintValue(ReplacementNode);
247 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
248 OS <<
" Operands: \n";
/// Accessor: returns the current value of the OperandsValid member.
255 bool areOperandsValid() {
return OperandsValid; }
258class ComplexDeinterleavingGraph {
266 using Addend = std::pair<Value *, bool>;
268 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
272 struct PartialMulCandidate {
283 : TL(TL), TLI(TLI), Factor(Factor) {}
336 bool PHIsFound =
false;
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
351 "Reduction related nodes must have Real and Imaginary parts");
353 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
359 for (
auto &V : Vals) {
361 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
362 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
363 (
V.Real &&
V.Imag)) &&
364 "Reduction related nodes must have Real and Imaginary parts");
368 ComplexDeinterleavingCompositeNode(
Operation, Vals);
371 CompositeNode *submitCompositeNode(CompositeNode *
Node) {
373 if (
Node->Vals[0].Real)
396 std::pair<Value *, Value *> &CommonOperandI);
406 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
407 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
408 CompositeNode *identifyDotProduct(
Value *Inst);
415 return identifyNode(Vals);
422 CompositeNode *identifyAdditions(AddendList &RealAddends,
423 AddendList &ImagAddends,
424 std::optional<FastMathFlags> Flags,
428 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
429 AddendList &ImagAddends);
484 void processReductionOperation(
Value *OperationReplacement,
485 CompositeNode *
Node);
486 void processReductionSingle(
Value *OperationReplacement, CompositeNode *
Node);
491 for (
const auto &
Node : CompositeNodes)
504 void identifyReductionNodes();
514class ComplexDeinterleaving {
517 : TL(tl), TLI(tli) {}
521 bool evaluateBasicBlock(
BasicBlock *
B,
unsigned Factor);
529char ComplexDeinterleavingLegacyPass::ID = 0;
532 "Complex Deinterleaving",
false,
false)
538 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
540 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(
F))
549 return new ComplexDeinterleavingLegacyPass(TM);
552bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
553 const auto *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
554 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
555 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
558bool ComplexDeinterleaving::runOnFunction(
Function &
F) {
561 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
567 dbgs() <<
"Complex deinterleaving has been disabled, target does "
568 "not support lowering of complex number operations.\n");
572 bool Changed =
false;
574 Changed |= evaluateBasicBlock(&
B, 2);
579 Changed |= evaluateBasicBlock(&
B, 4);
589 if ((Mask.size() & 1))
592 int HalfNumElements = Mask.size() / 2;
593 for (
int Idx = 0;
Idx < HalfNumElements; ++
Idx) {
594 int MaskIdx =
Idx * 2;
595 if (Mask[MaskIdx] !=
Idx || Mask[MaskIdx + 1] != (
Idx + HalfNumElements))
604 int HalfNumElements = Mask.size() / 2;
606 for (
int Idx = 1;
Idx < HalfNumElements; ++
Idx) {
620 auto *
I = cast<Instruction>(V);
621 if (
I->getOpcode() == Instruction::FNeg)
622 return I->getOperand(0);
624 return I->getOperand(1);
627bool ComplexDeinterleaving::evaluateBasicBlock(
BasicBlock *
B,
unsigned Factor) {
628 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
629 if (Graph.collectPotentialReductions(
B))
630 Graph.identifyReductionNodes();
633 Graph.identifyNodes(&
I);
635 if (Graph.checkNodes()) {
636 Graph.replaceNodes();
643ComplexDeinterleavingGraph::CompositeNode *
644ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
646 std::pair<Value *, Value *> &PartialMatch) {
647 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
655 if ((Real->
getOpcode() != Instruction::FMul &&
656 Real->
getOpcode() != Instruction::Mul) ||
657 (Imag->
getOpcode() != Instruction::FMul &&
658 Imag->
getOpcode() != Instruction::Mul)) {
660 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
693 Value *CommonOperand;
694 Value *UncommonRealOp;
695 Value *UncommonImagOp;
697 if (R0 == I0 || R0 == I1) {
700 }
else if (R1 == I0 || R1 == I1) {
708 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
709 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
710 Rotation == ComplexDeinterleavingRotation::Rotation_270)
711 std::swap(UncommonRealOp, UncommonImagOp);
715 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
716 Rotation == ComplexDeinterleavingRotation::Rotation_180)
717 PartialMatch.first = CommonOperand;
719 PartialMatch.second = CommonOperand;
721 if (!PartialMatch.first || !PartialMatch.second) {
726 CompositeNode *CommonNode =
727 identifyNode(PartialMatch.first, PartialMatch.second);
733 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
739 CompositeNode *
Node = prepareCompositeNode(
740 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
741 Node->Rotation = Rotation;
742 Node->addOperand(CommonNode);
743 Node->addOperand(UncommonNode);
744 return submitCompositeNode(
Node);
747ComplexDeinterleavingGraph::CompositeNode *
748ComplexDeinterleavingGraph::identifyPartialMul(
Instruction *Real,
750 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
754 auto IsAdd = [](
unsigned Op) {
755 return Op == Instruction::FAdd ||
Op == Instruction::Add;
757 auto IsSub = [](
unsigned Op) {
758 return Op == Instruction::FSub ||
Op == Instruction::Sub;
762 Rotation = ComplexDeinterleavingRotation::Rotation_0;
764 Rotation = ComplexDeinterleavingRotation::Rotation_90;
766 Rotation = ComplexDeinterleavingRotation::Rotation_180;
768 Rotation = ComplexDeinterleavingRotation::Rotation_270;
774 if (isa<FPMathOperator>(Real) &&
777 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
800 Value *CommonOperand;
801 Value *UncommonRealOp;
802 Value *UncommonImagOp;
804 if (R0 == I0 || R0 == I1) {
807 }
else if (R1 == I0 || R1 == I1) {
815 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
816 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_270)
818 std::swap(UncommonRealOp, UncommonImagOp);
820 std::pair<Value *, Value *> PartialMatch(
821 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
822 Rotation == ComplexDeinterleavingRotation::Rotation_180)
825 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
826 Rotation == ComplexDeinterleavingRotation::Rotation_270)
830 auto *CRInst = dyn_cast<Instruction>(CR);
831 auto *CIInst = dyn_cast<Instruction>(CI);
833 if (!CRInst || !CIInst) {
834 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
838 CompositeNode *CNode =
839 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
845 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
851 assert(PartialMatch.first && PartialMatch.second);
852 CompositeNode *CommonRes =
853 identifyNode(PartialMatch.first, PartialMatch.second);
859 CompositeNode *
Node = prepareCompositeNode(
860 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
861 Node->Rotation = Rotation;
862 Node->addOperand(CommonRes);
863 Node->addOperand(UncommonRes);
864 Node->addOperand(CNode);
865 return submitCompositeNode(
Node);
868ComplexDeinterleavingGraph::CompositeNode *
870 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
874 if ((Real->
getOpcode() == Instruction::FSub &&
875 Imag->
getOpcode() == Instruction::FAdd) ||
876 (Real->
getOpcode() == Instruction::Sub &&
878 Rotation = ComplexDeinterleavingRotation::Rotation_90;
879 else if ((Real->
getOpcode() == Instruction::FAdd &&
880 Imag->
getOpcode() == Instruction::FSub) ||
881 (Real->
getOpcode() == Instruction::Add &&
883 Rotation = ComplexDeinterleavingRotation::Rotation_270;
885 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
889 auto *AR = dyn_cast<Instruction>(Real->
getOperand(0));
890 auto *BI = dyn_cast<Instruction>(Real->
getOperand(1));
891 auto *AI = dyn_cast<Instruction>(Imag->
getOperand(0));
894 if (!AR || !AI || !BR || !BI) {
899 CompositeNode *ResA = identifyNode(AR, AI);
901 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
904 CompositeNode *ResB = identifyNode(BR, BI);
906 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
910 CompositeNode *
Node =
911 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
912 Node->Rotation = Rotation;
913 Node->addOperand(ResA);
914 Node->addOperand(ResB);
915 return submitCompositeNode(
Node);
919 unsigned OpcA =
A->getOpcode();
920 unsigned OpcB =
B->getOpcode();
922 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
923 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
924 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
925 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
936 switch (
I->getOpcode()) {
937 case Instruction::FAdd:
938 case Instruction::FSub:
939 case Instruction::FMul:
940 case Instruction::FNeg:
941 case Instruction::Add:
942 case Instruction::Sub:
943 case Instruction::Mul:
950ComplexDeinterleavingGraph::CompositeNode *
951ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
952 auto *FirstReal = cast<Instruction>(Vals[0].Real);
953 unsigned FirstOpc = FirstReal->getOpcode();
954 for (
auto &V : Vals) {
955 auto *Real = cast<Instruction>(
V.Real);
956 auto *Imag = cast<Instruction>(
V.Imag);
964 if (isa<FPMathOperator>(FirstReal))
971 for (
auto &V : Vals) {
972 auto *R0 = cast<Instruction>(
V.Real)->getOperand(0);
973 auto *I0 = cast<Instruction>(
V.Imag)->getOperand(0);
977 CompositeNode *Op0 = identifyNode(OpVals);
978 CompositeNode *Op1 =
nullptr;
982 if (FirstReal->isBinaryOp()) {
984 for (
auto &V : Vals) {
985 auto *R1 = cast<Instruction>(
V.Real)->getOperand(1);
986 auto *
I1 = cast<Instruction>(
V.Imag)->getOperand(1);
989 Op1 = identifyNode(OpVals);
995 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
996 Node->Opcode = FirstReal->getOpcode();
997 if (isa<FPMathOperator>(FirstReal))
998 Node->Flags = FirstReal->getFastMathFlags();
1000 Node->addOperand(Op0);
1001 if (FirstReal->isBinaryOp())
1002 Node->addOperand(Op1);
1004 return submitCompositeNode(
Node);
1007ComplexDeinterleavingGraph::CompositeNode *
1008ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
1010 ComplexDeinterleavingOperation::CDot,
V->getType())) {
1011 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
1012 "operation CDot with the type "
1013 << *
V->getType() <<
"\n");
1017 auto *Inst = cast<Instruction>(V);
1018 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1021 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1023 CompositeNode *ANode =
nullptr;
1026 Intrinsic::experimental_vector_partial_reduce_add;
1028 Value *AReal =
nullptr;
1029 Value *AImag =
nullptr;
1030 Value *BReal =
nullptr;
1031 Value *BImag =
nullptr;
1035 if (
auto *CI = dyn_cast<CastInst>(V))
1036 return CI->getOperand(0);
1040 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1041 m_Intrinsic<PartialReduceInt>(
m_Value(Phi),
1045 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1046 m_Intrinsic<PartialReduceInt>(
1050 if (
match(Inst, PatternRot0)) {
1051 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1052 }
else if (
match(Inst, PatternRot270)) {
1053 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1059 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1060 m_Intrinsic<PartialReduceInt>(
m_Value(Phi),
1064 if (!
match(Inst, PatternRot90Rot180))
1067 A0 = UnwrapCast(A0);
1068 A1 = UnwrapCast(A1);
1071 ANode = identifyNode(A0, A1);
1074 ANode = identifyNode(A1, A0);
1078 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1084 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1088 AReal = UnwrapCast(AReal);
1089 AImag = UnwrapCast(AImag);
1090 BReal = UnwrapCast(BReal);
1091 BImag = UnwrapCast(BImag);
1094 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1095 if (AReal->
getType() != ExpectedOperandTy)
1097 if (AImag->
getType() != ExpectedOperandTy)
1099 if (BReal->
getType() != ExpectedOperandTy)
1101 if (BImag->
getType() != ExpectedOperandTy)
1104 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1107 CompositeNode *
Node = identifyNode(AReal, AImag);
1112 if (ANode &&
Node != ANode) {
1115 <<
"Identified node is different from previously identified node. "
1116 "Unable to confidently generate a complex operation node\n");
1120 CN->addOperand(
Node);
1121 CN->addOperand(identifyNode(BReal, BImag));
1122 CN->addOperand(identifyNode(Phi, RealUser));
1124 return submitCompositeNode(CN);
1127ComplexDeinterleavingGraph::CompositeNode *
1128ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1130 if (!isa<VectorType>(
R->getType()) || !isa<VectorType>(
I->getType()))
1133 if (!
R->hasUseList() || !
I->hasUseList())
1137 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1141 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1142 if (!IInst || IInst->getIntrinsicID() !=
1143 Intrinsic::experimental_vector_partial_reduce_add)
1146 if (CompositeNode *CN = identifyDotProduct(IInst))
1152ComplexDeinterleavingGraph::CompositeNode *
1153ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1154 auto It = CachedResult.
find(Vals);
1155 if (It != CachedResult.
end()) {
1160 if (Vals.
size() == 1) {
1161 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1164 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1166 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1167 if (!IsReduction &&
R->getType() !=
I->getType())
1171 if (CompositeNode *CN = identifySplat(Vals))
1174 for (
auto &V : Vals) {
1175 auto *Real = dyn_cast<Instruction>(
V.Real);
1176 auto *Imag = dyn_cast<Instruction>(
V.Imag);
1181 if (CompositeNode *CN = identifyDeinterleave(Vals))
1184 if (Vals.size() == 1) {
1185 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1186 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1187 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1188 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1191 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1194 auto *VTy = cast<VectorType>(Real->
getType());
1195 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1198 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1200 ComplexDeinterleavingOperation::CAdd, NewVTy);
1203 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1208 if (CompositeNode *CN = identifyAdd(Real, Imag))
1212 if (HasCMulSupport && HasCAddSupport) {
1213 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1219 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1223 CachedResult[Vals] =
nullptr;
1227ComplexDeinterleavingGraph::CompositeNode *
1228ComplexDeinterleavingGraph::identifyReassocNodes(
Instruction *Real,
1230 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1231 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1232 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1233 Opcode == Instruction::Sub;
1236 if (!IsOperationSupported(Real->
getOpcode()) ||
1237 !IsOperationSupported(Imag->
getOpcode()))
1240 std::optional<FastMathFlags>
Flags;
1241 if (isa<FPMathOperator>(Real)) {
1243 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1249 if (!
Flags->allowReassoc()) {
1252 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1261 AddendList &Addends) ->
bool {
1264 while (!Worklist.
empty()) {
1266 if (!Visited.
insert(V).second)
1271 Addends.emplace_back(V, IsPositive);
1281 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1282 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1283 Addends.emplace_back(
I, IsPositive);
1286 switch (
I->getOpcode()) {
1287 case Instruction::FAdd:
1288 case Instruction::Add:
1292 case Instruction::FSub:
1296 case Instruction::Sub:
1304 case Instruction::FMul:
1305 case Instruction::Mul: {
1307 if (
isNeg(
I->getOperand(0))) {
1309 IsPositive = !IsPositive;
1311 A =
I->getOperand(0);
1314 if (
isNeg(
I->getOperand(1))) {
1316 IsPositive = !IsPositive;
1318 B =
I->getOperand(1);
1320 Muls.push_back(Product{
A,
B, IsPositive});
1323 case Instruction::FNeg:
1327 Addends.emplace_back(
I, IsPositive);
1331 if (Flags &&
I->getFastMathFlags() != *Flags) {
1333 "inconsistent with the root instructions' flags: "
1342 AddendList RealAddends, ImagAddends;
1343 if (!Collect(Real, RealMuls, RealAddends) ||
1344 !Collect(Imag, ImagMuls, ImagAddends))
1347 if (RealAddends.size() != ImagAddends.size())
1350 CompositeNode *FinalNode =
nullptr;
1351 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1354 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1355 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1361 if (!RealAddends.empty() || !ImagAddends.empty()) {
1362 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1366 assert(FinalNode &&
"FinalNode can not be nullptr here");
1367 assert(FinalNode->Vals.size() == 1);
1369 FinalNode->Vals[0].Real = Real;
1370 FinalNode->Vals[0].Imag = Imag;
1371 submitCompositeNode(FinalNode);
1375bool ComplexDeinterleavingGraph::collectPartialMuls(
1379 auto FindCommonInstruction = [](
const Product &Real,
1380 const Product &Imag) ->
Value * {
1381 if (Real.Multiplicand == Imag.Multiplicand ||
1382 Real.Multiplicand == Imag.Multiplier)
1383 return Real.Multiplicand;
1385 if (Real.Multiplier == Imag.Multiplicand ||
1386 Real.Multiplier == Imag.Multiplier)
1387 return Real.Multiplier;
1396 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1397 bool FoundCommon =
false;
1398 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1399 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1403 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1404 : RealMuls[i].Multiplicand;
1405 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1406 : ImagMuls[
j].Multiplicand;
1408 auto Node = identifyNode(
A,
B);
1414 Node = identifyNode(
B,
A);
1426ComplexDeinterleavingGraph::CompositeNode *
1427ComplexDeinterleavingGraph::identifyMultiplications(
1430 if (RealMuls.
size() != ImagMuls.
size())
1434 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1440 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1444 PartialMulCandidate &InfoA =
Info[
I];
1445 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1449 PartialMulCandidate &InfoB =
Info[J];
1450 auto *InfoReal = &InfoA;
1451 auto *InfoImag = &InfoB;
1453 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1454 if (!NodeFromCommon) {
1456 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1458 if (!NodeFromCommon)
1461 CommonToNode[InfoReal->Common] = NodeFromCommon;
1462 CommonToNode[InfoImag->Common] = NodeFromCommon;
1463 Processed[
I] =
true;
1464 Processed[J] =
true;
1471 for (
auto &PMI : Info) {
1472 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1475 auto It = CommonToNode.
find(PMI.Common);
1478 if (It == CommonToNode.
end()) {
1480 dbgs() <<
"Unprocessed independent partial multiplication:\n";
// Dump both products of the unmatched candidate: the real one (indexed by
// PMI.RealIdx) followed by the imaginary one (indexed by PMI.ImagIdx).
1481 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
1483 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1488 auto &RealMul = RealMuls[PMI.RealIdx];
1489 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1491 auto NodeA = It->second;
1492 auto NodeB = PMI.Node;
1493 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1508 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1509 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1514 if (IsMultiplicandReal) {
1516 if (RealMul.IsPositive && ImagMul.IsPositive)
1518 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1525 if (!RealMul.IsPositive && ImagMul.IsPositive)
1527 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1534 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1535 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1536 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1537 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1538 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1539 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1542 CompositeNode *NodeMul = prepareCompositeNode(
1543 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1544 NodeMul->Rotation = Rotation;
1545 NodeMul->addOperand(NodeA);
1546 NodeMul->addOperand(NodeB);
1548 NodeMul->addOperand(Result);
1549 submitCompositeNode(NodeMul);
1551 ProcessedReal[PMI.RealIdx] =
true;
1552 ProcessedImag[PMI.ImagIdx] =
true;
1556 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1557 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1562 dbgs() <<
"Unprocessed products (Real):\n";
1563 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1564 if (!ProcessedReal[i])
1565 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1566 << *RealMuls[i].Multiplier <<
" multiplied by "
1567 << *RealMuls[i].Multiplicand <<
"\n";
1569 dbgs() <<
"Unprocessed products (Imag):\n";
1570 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1571 if (!ProcessedImag[i])
1572 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1573 << *ImagMuls[i].Multiplier <<
" multiplied by "
1574 << *ImagMuls[i].Multiplicand <<
"\n";
1583ComplexDeinterleavingGraph::CompositeNode *
1584ComplexDeinterleavingGraph::identifyAdditions(
1585 AddendList &RealAddends, AddendList &ImagAddends,
1586 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1587 if (RealAddends.size() != ImagAddends.size())
1590 CompositeNode *
Result =
nullptr;
1596 Result = extractPositiveAddend(RealAddends, ImagAddends);
1601 while (!RealAddends.empty()) {
1602 auto ItR = RealAddends.begin();
1603 auto [
R, IsPositiveR] = *ItR;
1605 bool FoundImag =
false;
1606 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1607 auto [
I, IsPositiveI] = *ItI;
1609 if (IsPositiveR && IsPositiveI)
1610 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1611 else if (!IsPositiveR && IsPositiveI)
1612 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1613 else if (!IsPositiveR && !IsPositiveI)
1614 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1616 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1618 CompositeNode *AddNode =
nullptr;
1619 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1620 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1621 AddNode = identifyNode(R,
I);
1623 AddNode = identifyNode(
I, R);
1627 dbgs() <<
"Identified addition:\n";
1630 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1633 CompositeNode *TmpNode =
nullptr;
1635 TmpNode = prepareCompositeNode(
1636 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1638 TmpNode->Opcode = Instruction::FAdd;
1639 TmpNode->Flags = *
Flags;
1641 TmpNode->Opcode = Instruction::Add;
1643 }
else if (Rotation ==
1645 TmpNode = prepareCompositeNode(
1646 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1648 TmpNode->Opcode = Instruction::FSub;
1649 TmpNode->Flags = *
Flags;
1651 TmpNode->Opcode = Instruction::Sub;
1654 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1656 TmpNode->Rotation = Rotation;
1659 TmpNode->addOperand(Result);
1660 TmpNode->addOperand(AddNode);
1661 submitCompositeNode(TmpNode);
1663 RealAddends.erase(ItR);
1664 ImagAddends.erase(ItI);
1675ComplexDeinterleavingGraph::CompositeNode *
1676ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1677 AddendList &ImagAddends) {
1678 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1679 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1680 auto [
R, IsPositiveR] = *ItR;
1681 auto [
I, IsPositiveI] = *ItI;
1682 if (IsPositiveR && IsPositiveI) {
1683 auto Result = identifyNode(R,
I);
1685 RealAddends.erase(ItR);
1686 ImagAddends.erase(ItI);
1695bool ComplexDeinterleavingGraph::identifyNodes(
Instruction *RootI) {
1700 auto It = RootToNode.
find(RootI);
1701 if (It != RootToNode.
end()) {
1702 auto RootNode = It->second;
1703 assert(RootNode->Operation ==
1704 ComplexDeinterleavingOperation::ReductionOperation ||
1705 RootNode->Operation ==
1706 ComplexDeinterleavingOperation::ReductionSingle);
1707 assert(RootNode->Vals.size() == 1 &&
1708 "Cannot handle reductions involving multiple complex values");
1711 auto *
R = cast<Instruction>(RootNode->Vals[0].Real);
1712 auto *
I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1717 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1719 ReplacementAnchor =
R;
1721 if (ReplacementAnchor != RootI)
1727 auto RootNode = identifyRoot(RootI);
1734 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1735 <<
"::" <<
B->getName() <<
".\n";
1739 RootToNode[RootI] = RootNode;
1744bool ComplexDeinterleavingGraph::collectPotentialReductions(
BasicBlock *
B) {
1745 bool FoundPotentialReduction =
false;
1749 auto *Br = dyn_cast<BranchInst>(
B->getTerminator());
1750 if (!Br || Br->getNumSuccessors() != 2)
1754 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1757 for (
auto &
PHI :
B->phis()) {
1758 if (
PHI.getNumIncomingValues() != 2)
1761 if (!
PHI.getType()->isVectorTy())
1764 auto *ReductionOp = dyn_cast<Instruction>(
PHI.getIncomingValueForBlock(
B));
1771 for (
auto *U : ReductionOp->users()) {
1775 FinalReduction = dyn_cast<Instruction>(U);
1778 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1779 isa<PHINode>(FinalReduction))
1782 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1784 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1785 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1787 FoundPotentialReduction =
true;
1792 dyn_cast<Instruction>(
PHI.getIncomingValueForBlock(
Incoming)))
1793 FinalInstructions.
insert(InitPHI);
1795 return FoundPotentialReduction;
1798void ComplexDeinterleavingGraph::identifyReductionNodes() {
1799 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1803 for (
auto &
P : ReductionInfo)
1808 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1811 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1814 auto *Real = OperationInstruction[i];
1815 auto *Imag = OperationInstruction[
j];
1816 if (Real->getType() != Imag->
getType())
1819 RealPHI = ReductionInfo[Real].first;
1820 ImagPHI = ReductionInfo[Imag].first;
1822 auto Node = identifyNode(Real, Imag);
1826 Node = identifyNode(Real, Imag);
1832 if (
Node && PHIsFound) {
1833 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1834 << *Real <<
" / " << *Imag <<
"\n");
1835 Processed[i] =
true;
1836 Processed[
j] =
true;
1837 auto RootNode = prepareCompositeNode(
1838 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1839 RootNode->addOperand(
Node);
1840 RootToNode[Real] = RootNode;
1841 RootToNode[Imag] = RootNode;
1842 submitCompositeNode(RootNode);
1847 auto *Real = OperationInstruction[i];
1850 if (Processed[i] || Real->getNumOperands() < 2)
1854 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1857 RealPHI = ReductionInfo[Real].first;
1860 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1861 if (
Node && PHIsFound) {
1863 dbgs() <<
"Identified single reduction starting from instruction: "
1864 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1873 if (ReductionInfo[Real].second->getType()->isVectorTy())
1876 Processed[i] =
true;
1877 auto RootNode = prepareCompositeNode(
1878 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1879 RootNode->addOperand(
Node);
1880 RootToNode[Real] = RootNode;
1881 submitCompositeNode(RootNode);
1889bool ComplexDeinterleavingGraph::checkNodes() {
1890 bool FoundDeinterleaveNode =
false;
1891 for (CompositeNode *
N : CompositeNodes) {
1892 if (!
N->areOperandsValid())
1895 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1896 FoundDeinterleaveNode =
true;
1901 if (!FoundDeinterleaveNode) {
1903 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1904 "guarantee safety during graph transformation.\n");
1911 for (
auto &Pair : RootToNode)
1916 while (!Worklist.
empty()) {
1919 if (!AllInstructions.
insert(
I).second)
1923 if (
auto *OpI = dyn_cast<Instruction>(
Op)) {
1924 if (!FinalInstructions.
count(
I))
1931 for (
auto *
I : AllInstructions) {
1933 if (RootToNode.count(
I))
1936 for (
User *U :
I->users()) {
1937 if (AllInstructions.count(cast<Instruction>(U)))
1949 while (!Worklist.
empty()) {
1951 if (!Visited.
insert(
I).second)
1956 if (RootToNode.count(
I)) {
1958 <<
" could be deinterleaved but its chain of complex "
1959 "operations have an outside user\n");
1960 RootToNode.erase(
I);
1963 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1966 for (
User *U :
I->users())
1970 if (
auto *OpI = dyn_cast<Instruction>(
Op))
1974 return !RootToNode.empty();
1977ComplexDeinterleavingGraph::CompositeNode *
1978ComplexDeinterleavingGraph::identifyRoot(
Instruction *RootI) {
1979 if (
auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1985 for (
unsigned I = 0;
I < Factor;
I += 2) {
1986 auto *Real = dyn_cast<Instruction>(
Intrinsic->getOperand(
I));
1987 auto *Imag = dyn_cast<Instruction>(
Intrinsic->getOperand(
I + 1));
1993 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2007 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
2021 return identifyNode(Real, Imag);
2024ComplexDeinterleavingGraph::CompositeNode *
2025ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2029 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2031 auto *EVI = dyn_cast<ExtractValueInst>(V);
2032 if (!EVI || EVI->getNumIndices() != 1 ||
2033 EVI->getIndices()[0] != ExpectedIdx ||
2034 !isa<Instruction>(EVI->getAggregateOperand()) ||
2035 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2042 if (RealEVI &&
Idx == 0)
2044 if (!RealEVI || !CheckExtract(Vals[
Idx].Imag, (
Idx * 2) + 1,
II)) {
2050 if (
auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(
II)) {
2051 if (IntrinsicII->getIntrinsicID() !=
2056 CompositeNode *PlaceholderNode = prepareCompositeNode(
2058 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2059 for (
auto &V : Vals) {
2060 FinalInstructions.
insert(cast<Instruction>(
V.Real));
2061 FinalInstructions.
insert(cast<Instruction>(
V.Imag));
2063 return submitCompositeNode(PlaceholderNode);
2066 if (Vals.size() != 1)
2069 Value *Real = Vals[0].Real;
2070 Value *Imag = Vals[0].Imag;
2071 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2072 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2073 if (!RealShuffle || !ImagShuffle) {
2074 if (RealShuffle || ImagShuffle)
2075 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2079 Value *RealOp1 = RealShuffle->getOperand(1);
2080 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
2084 Value *ImagOp1 = ImagShuffle->getOperand(1);
2085 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
2090 Value *RealOp0 = RealShuffle->getOperand(0);
2091 Value *ImagOp0 = ImagShuffle->getOperand(0);
2093 if (RealOp0 != ImagOp0) {
2105 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2106 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2113 Value *
Op = Shuffle->getOperand(0);
2114 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2115 auto *OpTy = cast<FixedVectorType>(
Op->getType());
2117 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2119 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2132 Value *
Op = Shuffle->getOperand(0);
2133 auto *OpTy = cast<FixedVectorType>(
Op->getType());
2134 int NumElements = OpTy->getNumElements();
2138 return Last < NumElements;
2141 if (RealShuffle->getType() != ImagShuffle->getType()) {
2145 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2149 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2154 CompositeNode *PlaceholderNode =
2156 RealShuffle, ImagShuffle);
2157 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2158 FinalInstructions.
insert(RealShuffle);
2159 FinalInstructions.
insert(ImagShuffle);
2160 return submitCompositeNode(PlaceholderNode);
2163ComplexDeinterleavingGraph::CompositeNode *
2164ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2165 auto IsSplat = [](
Value *
V) ->
bool {
2167 if (isa<ConstantDataVector>(V))
2170 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2171 return isa<VectorType>(
V->getType());
2177 if (
auto *Const = dyn_cast<ConstantExpr>(V)) {
2178 if (
Const->getOpcode() != Instruction::ShuffleVector)
2180 VTy = cast<VectorType>(
Const->getType());
2182 }
else if (
auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2183 VTy = Shuf->getType();
2184 Mask = Shuf->getShuffleMask();
2192 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2201 if (
auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2203 for (
auto &V : Vals) {
2204 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2207 auto *Real = dyn_cast<Instruction>(
V.Real);
2208 auto *Imag = dyn_cast<Instruction>(
V.Imag);
2209 if (!Real || !Imag || Real->getParent() != FirstBB ||
2210 Imag->getParent() != FirstBB)
2214 for (
auto &V : Vals) {
2215 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag) || isa<Instruction>(
V.Real) ||
2216 isa<Instruction>(
V.Imag))
2221 for (
auto &V : Vals) {
2222 auto *Real = dyn_cast<Instruction>(
V.Real);
2223 auto *Imag = dyn_cast<Instruction>(
V.Imag);
2225 FinalInstructions.
insert(Real);
2226 FinalInstructions.
insert(Imag);
2229 CompositeNode *PlaceholderNode =
2230 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2231 return submitCompositeNode(PlaceholderNode);
2234ComplexDeinterleavingGraph::CompositeNode *
2235ComplexDeinterleavingGraph::identifyPHINode(
Instruction *Real,
2237 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2241 CompositeNode *PlaceholderNode = prepareCompositeNode(
2242 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2243 return submitCompositeNode(PlaceholderNode);
2246ComplexDeinterleavingGraph::CompositeNode *
2247ComplexDeinterleavingGraph::identifySelectNode(
Instruction *Real,
2249 auto *SelectReal = dyn_cast<SelectInst>(Real);
2250 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2251 if (!SelectReal || !SelectImag)
2268 auto NodeA = identifyNode(AR, AI);
2272 auto NodeB = identifyNode(
RA, BI);
2276 CompositeNode *PlaceholderNode = prepareCompositeNode(
2277 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2278 PlaceholderNode->addOperand(NodeA);
2279 PlaceholderNode->addOperand(NodeB);
2280 FinalInstructions.
insert(MaskA);
2281 FinalInstructions.
insert(MaskB);
2282 return submitCompositeNode(PlaceholderNode);
2286 std::optional<FastMathFlags> Flags,
2290 case Instruction::FNeg:
2291 I =
B.CreateFNeg(InputA);
2293 case Instruction::FAdd:
2294 I =
B.CreateFAdd(InputA, InputB);
2296 case Instruction::Add:
2297 I =
B.CreateAdd(InputA, InputB);
2299 case Instruction::FSub:
2300 I =
B.CreateFSub(InputA, InputB);
2302 case Instruction::Sub:
2303 I =
B.CreateSub(InputA, InputB);
2305 case Instruction::FMul:
2306 I =
B.CreateFMul(InputA, InputB);
2308 case Instruction::Mul:
2309 I =
B.CreateMul(InputA, InputB);
2315 cast<Instruction>(
I)->setFastMathFlags(*Flags);
2320 CompositeNode *
Node) {
2321 if (
Node->ReplacementNode)
2322 return Node->ReplacementNode;
2324 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2326 return Node->Operands.size() >
Idx
2327 ? replaceNode(Builder,
Node->Operands[
Idx])
2331 Value *ReplacementNode =
nullptr;
2332 switch (
Node->Operation) {
2333 case ComplexDeinterleavingOperation::CDot: {
2334 Value *Input0 = ReplaceOperandIfExist(
Node, 0);
2335 Value *Input1 = ReplaceOperandIfExist(
Node, 1);
2338 "Node inputs need to be of the same type"));
2343 case ComplexDeinterleavingOperation::CAdd:
2344 case ComplexDeinterleavingOperation::CMulPartial:
2345 case ComplexDeinterleavingOperation::Symmetric: {
2346 Value *Input0 = ReplaceOperandIfExist(
Node, 0);
2347 Value *Input1 = ReplaceOperandIfExist(
Node, 1);
2350 "Node inputs need to be of the same type"));
2353 "Accumulator and input need to be of the same type"));
2354 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2359 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2363 case ComplexDeinterleavingOperation::Deinterleave:
2366 case ComplexDeinterleavingOperation::Splat: {
2368 for (
auto &V :
Node->Vals) {
2372 auto *
R = dyn_cast<Instruction>(
Node->Vals[0].Real);
2373 auto *
I = dyn_cast<Instruction>(
Node->Vals[0].Imag);
2377 for (
auto V :
Node->Vals) {
2378 if (InsertPoint->
comesBefore(cast<Instruction>(
V.Real)))
2379 InsertPoint = cast<Instruction>(
V.Real);
2380 if (InsertPoint->
comesBefore(cast<Instruction>(
V.Imag)))
2381 InsertPoint = cast<Instruction>(
V.Imag);
2385 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2391 case ComplexDeinterleavingOperation::ReductionPHI: {
2394 auto *OldPHI = cast<PHINode>(
Node->Vals[0].Real);
2395 auto *VTy = cast<VectorType>(
Node->Vals[0].Real->getType());
2396 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2398 OldToNewPHI[OldPHI] = NewPHI;
2399 ReplacementNode = NewPHI;
2402 case ComplexDeinterleavingOperation::ReductionSingle:
2403 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2404 processReductionSingle(ReplacementNode,
Node);
2406 case ComplexDeinterleavingOperation::ReductionOperation:
2407 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2408 processReductionOperation(ReplacementNode,
Node);
2410 case ComplexDeinterleavingOperation::ReductionSelect: {
2411 auto *MaskReal = cast<Instruction>(
Node->Vals[0].Real)->getOperand(0);
2412 auto *MaskImag = cast<Instruction>(
Node->Vals[0].Imag)->getOperand(0);
2413 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2414 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2421 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2422 NumComplexTransformations += 1;
2423 Node->ReplacementNode = ReplacementNode;
2424 return ReplacementNode;
2427void ComplexDeinterleavingGraph::processReductionSingle(
2428 Value *OperationReplacement, CompositeNode *
Node) {
2429 auto *Real = cast<Instruction>(
Node->Vals[0].Real);
2430 auto *OldPHI = ReductionInfo[Real].first;
2431 auto *NewPHI = OldToNewPHI[OldPHI];
2432 auto *VTy = cast<VectorType>(Real->
getType());
2433 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2439 Value *NewInit =
nullptr;
2440 if (
auto *
C = dyn_cast<Constant>(
Init)) {
2441 if (
C->isZeroValue())
2449 NewPHI->addIncoming(NewInit,
Incoming);
2450 NewPHI->addIncoming(OperationReplacement, BackEdge);
2452 auto *FinalReduction = ReductionInfo[Real].second;
2459void ComplexDeinterleavingGraph::processReductionOperation(
2460 Value *OperationReplacement, CompositeNode *
Node) {
2461 auto *Real = cast<Instruction>(
Node->Vals[0].Real);
2462 auto *Imag = cast<Instruction>(
Node->Vals[0].Imag);
2463 auto *OldPHIReal = ReductionInfo[Real].first;
2464 auto *OldPHIImag = ReductionInfo[Imag].first;
2465 auto *NewPHI = OldToNewPHI[OldPHIReal];
2468 Value *InitReal = OldPHIReal->getIncomingValueForBlock(
Incoming);
2469 Value *InitImag = OldPHIImag->getIncomingValueForBlock(
Incoming);
2474 NewPHI->addIncoming(NewInit,
Incoming);
2475 NewPHI->addIncoming(OperationReplacement, BackEdge);
2479 auto *FinalReductionReal = ReductionInfo[Real].second;
2480 auto *FinalReductionImag = ReductionInfo[Imag].second;
2483 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2485 OperationReplacement->
getType(),
2486 OperationReplacement);
2489 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2493 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2496void ComplexDeinterleavingGraph::replaceNodes() {
2498 for (
auto *RootInstruction : OrderedRoots) {
2501 if (!RootToNode.count(RootInstruction))
2505 auto RootNode = RootToNode[RootInstruction];
2506 Value *
R = replaceNode(Builder, RootNode);
2508 if (RootNode->Operation ==
2509 ComplexDeinterleavingOperation::ReductionOperation) {
2510 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2511 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2512 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2513 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2516 }
else if (RootNode->Operation ==
2517 ComplexDeinterleavingOperation::ReductionSingle) {
2518 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2519 auto &
Info = ReductionInfo[RootInst];
2520 Info.first->removeIncomingValue(BackEdge);
2523 assert(R &&
"Unable to find replacement for RootInstruction");
2524 DeadInstrRoots.
push_back(RootInstruction);
2525 RootInstruction->replaceAllUsesWith(R);
2529 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
std::optional< std::vector< StOtherPiece > > Other
static bool runOnFunction(Function &F, bool PostInlining)
mir Rename Register Operands
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
DEMANGLE_DUMP_METHOD void dump() const
A linked-list with a custom, local allocator.
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
LLVM Basic Block Representation.
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Function * getParent() const
Return the enclosing method, or null if none.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
This class implements a map that also provides access to all stored values in a deterministic order.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
A BumpPtrAllocator that allows only elements of a specific type to be allocated.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
This class implements an extremely fast bulk output stream that can only output to a stream.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
@ BR
Control flow instructions. These all have token chains.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ABI void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
ComplexDeinterleavingRotation
DWARFExpression::Operation Op
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer list are equal or the list is empty.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static ComplexValue getEmptyKey()
static unsigned getHashValue(const ComplexValue &Val)
static ComplexValue getTombstoneKey()
An information struct used to provide DenseMap with the various necessary components for a given valu...
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...