LLVM 22.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by the vector.interleaveN and
// vector.deinterleaveN intrinsics (e.g. vector.interleave2 and
// vector.deinterleave2). Scalable vectors can be represented only by these
// intrinsics, whereas fixed-width vectors are recognized both as
// shufflevector instructions and as intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementValue is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementValue fields of that Node are relevant, where the ReplacementValue
// should be pre-populated.
//
//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
129namespace llvm {
130template <> struct DenseMapInfo<ComplexValue> {
131 static inline ComplexValue getEmptyKey() {
134 }
135 static inline ComplexValue getTombstoneKey() {
138 }
139 static unsigned getHashValue(const ComplexValue &Val) {
142 }
143 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
144 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
145 }
146};
147} // end namespace llvm
148
149namespace {
150template <typename T, typename IterT>
151std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
152 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
153 if (Common != A.end())
154 return std::make_optional(*Common);
155 return std::nullopt;
156}
157
158class ComplexDeinterleavingLegacyPass : public FunctionPass {
159public:
160 static char ID;
161
162 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
163 : FunctionPass(ID), TM(TM) {
166 }
167
168 StringRef getPassName() const override {
169 return "Complex Deinterleaving Pass";
170 }
171
172 bool runOnFunction(Function &F) override;
173 void getAnalysisUsage(AnalysisUsage &AU) const override {
175 AU.setPreservesCFG();
176 }
177
178private:
179 const TargetMachine *TM;
180};
181
182class ComplexDeinterleavingGraph;
183struct ComplexDeinterleavingCompositeNode {
184
185 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
186 Value *R, Value *I)
187 : Operation(Op) {
188 Vals.push_back({R, I});
189 }
190
191 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
193 : Operation(Op), Vals(Other) {}
194
195private:
196 friend class ComplexDeinterleavingGraph;
197 using CompositeNode = ComplexDeinterleavingCompositeNode;
198 bool OperandsValid = true;
199
200public:
202 ComplexValues Vals;
203
204 // This two members are required exclusively for generating
205 // ComplexDeinterleavingOperation::Symmetric operations.
206 unsigned Opcode;
207 std::optional<FastMathFlags> Flags;
208
210 ComplexDeinterleavingRotation::Rotation_0;
212 Value *ReplacementNode = nullptr;
213
214 void addOperand(CompositeNode *Node) {
215 if (!Node)
216 OperandsValid = false;
217 Operands.push_back(Node);
218 }
219
220 void dump() { dump(dbgs()); }
221 void dump(raw_ostream &OS) {
222 auto PrintValue = [&](Value *V) {
223 if (V) {
224 OS << "\"";
225 V->print(OS, true);
226 OS << "\"\n";
227 } else
228 OS << "nullptr\n";
229 };
230 auto PrintNodeRef = [&](CompositeNode *Ptr) {
231 if (Ptr)
232 OS << Ptr << "\n";
233 else
234 OS << "nullptr\n";
235 };
236
237 OS << "- CompositeNode: " << this << "\n";
238 for (unsigned I = 0; I < Vals.size(); I++) {
239 OS << " Real(" << I << ") : ";
240 PrintValue(Vals[I].Real);
241 OS << " Imag(" << I << ") : ";
242 PrintValue(Vals[I].Imag);
243 }
244 OS << " ReplacementNode: ";
245 PrintValue(ReplacementNode);
246 OS << " Operation: " << (int)Operation << "\n";
247 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
248 OS << " Operands: \n";
249 for (const auto &Op : Operands) {
250 OS << " - ";
251 PrintNodeRef(Op);
252 }
253 }
254
255 bool areOperandsValid() { return OperandsValid; }
256};
257
258class ComplexDeinterleavingGraph {
259public:
260 struct Product {
261 Value *Multiplier;
262 Value *Multiplicand;
263 bool IsPositive;
264 };
265
266 using Addend = std::pair<Value *, bool>;
267 using AddendList = BumpPtrList<Addend>;
268 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
269
270 // Helper struct for holding info about potential partial multiplication
271 // candidates
272 struct PartialMulCandidate {
273 Value *Common;
274 CompositeNode *Node;
275 unsigned RealIdx;
276 unsigned ImagIdx;
277 bool IsNodeInverted;
278 };
279
280 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
281 const TargetLibraryInfo *TLI,
282 unsigned Factor)
283 : TL(TL), TLI(TLI), Factor(Factor) {}
284
285private:
286 const TargetLowering *TL = nullptr;
287 const TargetLibraryInfo *TLI = nullptr;
288 unsigned Factor;
289 SmallVector<CompositeNode *> CompositeNodes;
292
293 SmallPtrSet<Instruction *, 16> FinalInstructions;
294
295 /// Root instructions are instructions from which complex computation starts
297
298 /// Topologically sorted root instructions
300
301 /// When examining a basic block for complex deinterleaving, if it is a simple
302 /// one-block loop, then the only incoming block is 'Incoming' and the
303 /// 'BackEdge' block is the block itself."
304 BasicBlock *BackEdge = nullptr;
305 BasicBlock *Incoming = nullptr;
306
307 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
308 /// %OutsideUser as it is shown in the IR:
309 ///
310 /// vector.body:
311 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
312 /// [ %ReductionOp, %vector.body ]
313 /// ...
314 /// %ReductionOp = fadd i64 ...
315 /// ...
316 /// br i1 %condition, label %vector.body, %middle.block
317 ///
318 /// middle.block:
319 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
320 ///
321 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
322 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
324
325 /// In the process of detecting a reduction, we consider a pair of
326 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
327 /// traverse the use-tree to detect complex operations. As this is a reduction
328 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
329 /// to the %ReductionOPs that we suspect to be complex.
330 /// RealPHI and ImagPHI are used by the identifyPHINode method.
331 PHINode *RealPHI = nullptr;
332 PHINode *ImagPHI = nullptr;
333
334 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
335 /// detection.
336 bool PHIsFound = false;
337
338 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
339 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
340 /// This mapping is populated during
341 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
342 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
343 /// replacement process.
345
346 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
347 Value *R, Value *I) {
348 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (R && I)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
352 return new (Allocator.Allocate())
353 ComplexDeinterleavingCompositeNode(Operation, R, I);
354 }
355
356 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
357 ComplexValues &Vals) {
358#ifndef NDEBUG
359 for (auto &V : Vals) {
360 assert(
361 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
362 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
363 (V.Real && V.Imag)) &&
364 "Reduction related nodes must have Real and Imaginary parts");
365 }
366#endif
367 return new (Allocator.Allocate())
368 ComplexDeinterleavingCompositeNode(Operation, Vals);
369 }
370
371 CompositeNode *submitCompositeNode(CompositeNode *Node) {
372 CompositeNodes.push_back(Node);
373 if (Node->Vals[0].Real)
374 CachedResult[Node->Vals] = Node;
375 return Node;
376 }
377
378 /// Identifies a complex partial multiply pattern and its rotation, based on
379 /// the following patterns
380 ///
381 /// 0: r: cr + ar * br
382 /// i: ci + ar * bi
383 /// 90: r: cr - ai * bi
384 /// i: ci + ai * br
385 /// 180: r: cr - ar * br
386 /// i: ci - ar * bi
387 /// 270: r: cr + ai * bi
388 /// i: ci - ai * br
389 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
390
391 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
392 /// is partially known from identifyPartialMul, filling in the other half of
393 /// the complex pair.
394 CompositeNode *
395 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
396 std::pair<Value *, Value *> &CommonOperandI);
397
398 /// Identifies a complex add pattern and its rotation, based on the following
399 /// patterns.
400 ///
401 /// 90: r: ar - bi
402 /// i: ai + br
403 /// 270: r: ar + bi
404 /// i: ai - br
405 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
406 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
407 CompositeNode *identifyPartialReduction(Value *R, Value *I);
408 CompositeNode *identifyDotProduct(Value *Inst);
409
410 CompositeNode *identifyNode(ComplexValues &Vals);
411
412 CompositeNode *identifyNode(Value *R, Value *I) {
413 ComplexValues Vals;
414 Vals.push_back({R, I});
415 return identifyNode(Vals);
416 }
417
418 /// Determine if a sum of complex numbers can be formed from \p RealAddends
419 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
420 /// Return nullptr if it is not possible to construct a complex number.
421 /// \p Flags are needed to generate symmetric Add and Sub operations.
422 CompositeNode *identifyAdditions(AddendList &RealAddends,
423 AddendList &ImagAddends,
424 std::optional<FastMathFlags> Flags,
425 CompositeNode *Accumulator);
426
427 /// Extract one addend that have both real and imaginary parts positive.
428 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
429 AddendList &ImagAddends);
430
431 /// Determine if sum of multiplications of complex numbers can be formed from
432 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
433 /// to it. Return nullptr if it is not possible to construct a complex number.
434 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
435 SmallVectorImpl<Product> &ImagMuls,
436 CompositeNode *Accumulator);
437
438 /// Go through pairs of multiplication (one Real and one Imag) and find all
439 /// possible candidates for partial multiplication and put them into \p
440 /// Candidates. Returns true if all Product has pair with common operand
441 bool collectPartialMuls(ArrayRef<Product> RealMuls,
442 ArrayRef<Product> ImagMuls,
444
445 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
446 /// the order of complex computation operations may be significantly altered,
447 /// and the real and imaginary parts may not be executed in parallel. This
448 /// function takes this into consideration and employs a more general approach
449 /// to identify complex computations. Initially, it gathers all the addends
450 /// and multiplicands and then constructs a complex expression from them.
451 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
452
453 CompositeNode *identifyRoot(Instruction *I);
454
455 /// Identifies the Deinterleave operation applied to a vector containing
456 /// complex numbers. There are two ways to represent the Deinterleave
457 /// operation:
458 /// * Using two shufflevectors with even indices for /pReal instruction and
459 /// odd indices for /pImag instructions (only for fixed-width vectors)
460 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
461 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
462 /// 2.
463 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
464
465 /// identifying the operation that represents a complex number repeated in a
466 /// Splat vector. There are two possible types of splats: ConstantExpr with
467 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
468 /// initialization mask with all values set to zero.
469 CompositeNode *identifySplat(ComplexValues &Vals);
470
471 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
472
473 /// Identifies SelectInsts in a loop that has reduction with predication masks
474 /// and/or predicated tail folding
475 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
476
477 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
478
479 /// Complete IR modifications after producing new reduction operation:
480 /// * Populate the PHINode generated for
481 /// ComplexDeinterleavingOperation::ReductionPHI
482 /// * Deinterleave the final value outside of the loop and repurpose original
483 /// reduction users
484 void processReductionOperation(Value *OperationReplacement,
485 CompositeNode *Node);
486 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
487
488public:
489 void dump() { dump(dbgs()); }
490 void dump(raw_ostream &OS) {
491 for (const auto &Node : CompositeNodes)
492 Node->dump(OS);
493 }
494
495 /// Returns false if the deinterleaving operation should be cancelled for the
496 /// current graph.
497 bool identifyNodes(Instruction *RootI);
498
499 /// In case \pB is one-block loop, this function seeks potential reductions
500 /// and populates ReductionInfo. Returns true if any reductions were
501 /// identified.
502 bool collectPotentialReductions(BasicBlock *B);
503
504 void identifyReductionNodes();
505
506 /// Check that every instruction, from the roots to the leaves, has internal
507 /// uses.
508 bool checkNodes();
509
510 /// Perform the actual replacement of the underlying instruction graph.
511 void replaceNodes();
512};
513
514class ComplexDeinterleaving {
515public:
516 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
517 : TL(tl), TLI(tli) {}
518 bool runOnFunction(Function &F);
519
520private:
521 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
522
523 const TargetLowering *TL = nullptr;
524 const TargetLibraryInfo *TLI = nullptr;
525};
526
527} // namespace
528
529char ComplexDeinterleavingLegacyPass::ID = 0;
530
531INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
532 "Complex Deinterleaving", false, false)
533INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
535
538 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
539 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
540 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
541 return PreservedAnalyses::all();
542
545 return PA;
546}
547
549 return new ComplexDeinterleavingLegacyPass(TM);
550}
551
552bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
553 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
554 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
555 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
556}
557
558bool ComplexDeinterleaving::runOnFunction(Function &F) {
561 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
562 return false;
563 }
564
567 dbgs() << "Complex deinterleaving has been disabled, target does "
568 "not support lowering of complex number operations.\n");
569 return false;
570 }
571
572 bool Changed = false;
573 for (auto &B : F)
574 Changed |= evaluateBasicBlock(&B, 2);
575
576 // TODO: Permit changes for both interleave factors in the same function.
577 if (!Changed) {
578 for (auto &B : F)
579 Changed |= evaluateBasicBlock(&B, 4);
580 }
581
582 // TODO: We can also support interleave factors of 6 and 8 if needed.
583
584 return Changed;
585}
586
588 // If the size is not even, it's not an interleaving mask
589 if ((Mask.size() & 1))
590 return false;
591
592 int HalfNumElements = Mask.size() / 2;
593 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
594 int MaskIdx = Idx * 2;
595 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
596 return false;
597 }
598
599 return true;
600}
601
603 int Offset = Mask[0];
604 int HalfNumElements = Mask.size() / 2;
605
606 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
607 if (Mask[Idx] != (Idx * 2) + Offset)
608 return false;
609 }
610
611 return true;
612}
613
614bool isNeg(Value *V) {
615 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
616}
617
619 assert(isNeg(V));
620 auto *I = cast<Instruction>(V);
621 if (I->getOpcode() == Instruction::FNeg)
622 return I->getOperand(0);
623
624 return I->getOperand(1);
625}
626
627bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
628 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
629 if (Graph.collectPotentialReductions(B))
630 Graph.identifyReductionNodes();
631
632 for (auto &I : *B)
633 Graph.identifyNodes(&I);
634
635 if (Graph.checkNodes()) {
636 Graph.replaceNodes();
637 return true;
638 }
639
640 return false;
641}
642
643ComplexDeinterleavingGraph::CompositeNode *
644ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
645 Instruction *Real, Instruction *Imag,
646 std::pair<Value *, Value *> &PartialMatch) {
647 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
648 << "\n");
649
650 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
651 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
652 return nullptr;
653 }
654
655 if ((Real->getOpcode() != Instruction::FMul &&
656 Real->getOpcode() != Instruction::Mul) ||
657 (Imag->getOpcode() != Instruction::FMul &&
658 Imag->getOpcode() != Instruction::Mul)) {
660 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
661 return nullptr;
662 }
663
664 Value *R0 = Real->getOperand(0);
665 Value *R1 = Real->getOperand(1);
666 Value *I0 = Imag->getOperand(0);
667 Value *I1 = Imag->getOperand(1);
668
669 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
670 // rotations and use the operand.
671 unsigned Negs = 0;
672 Value *Op;
673 if (match(R0, m_Neg(m_Value(Op)))) {
674 Negs |= 1;
675 R0 = Op;
676 } else if (match(R1, m_Neg(m_Value(Op)))) {
677 Negs |= 1;
678 R1 = Op;
679 }
680
681 if (isNeg(I0)) {
682 Negs |= 2;
683 Negs ^= 1;
684 I0 = Op;
685 } else if (match(I1, m_Neg(m_Value(Op)))) {
686 Negs |= 2;
687 Negs ^= 1;
688 I1 = Op;
689 }
690
692
693 Value *CommonOperand;
694 Value *UncommonRealOp;
695 Value *UncommonImagOp;
696
697 if (R0 == I0 || R0 == I1) {
698 CommonOperand = R0;
699 UncommonRealOp = R1;
700 } else if (R1 == I0 || R1 == I1) {
701 CommonOperand = R1;
702 UncommonRealOp = R0;
703 } else {
704 LLVM_DEBUG(dbgs() << " - No equal operand\n");
705 return nullptr;
706 }
707
708 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
709 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
710 Rotation == ComplexDeinterleavingRotation::Rotation_270)
711 std::swap(UncommonRealOp, UncommonImagOp);
712
713 // Between identifyPartialMul and here we need to have found a complete valid
714 // pair from the CommonOperand of each part.
715 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
716 Rotation == ComplexDeinterleavingRotation::Rotation_180)
717 PartialMatch.first = CommonOperand;
718 else
719 PartialMatch.second = CommonOperand;
720
721 if (!PartialMatch.first || !PartialMatch.second) {
722 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
723 return nullptr;
724 }
725
726 CompositeNode *CommonNode =
727 identifyNode(PartialMatch.first, PartialMatch.second);
728 if (!CommonNode) {
729 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
730 return nullptr;
731 }
732
733 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
734 if (!UncommonNode) {
735 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
736 return nullptr;
737 }
738
739 CompositeNode *Node = prepareCompositeNode(
740 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
741 Node->Rotation = Rotation;
742 Node->addOperand(CommonNode);
743 Node->addOperand(UncommonNode);
744 return submitCompositeNode(Node);
745}
746
747ComplexDeinterleavingGraph::CompositeNode *
748ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
749 Instruction *Imag) {
750 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
751 << "\n");
752
753 // Determine rotation
754 auto IsAdd = [](unsigned Op) {
755 return Op == Instruction::FAdd || Op == Instruction::Add;
756 };
757 auto IsSub = [](unsigned Op) {
758 return Op == Instruction::FSub || Op == Instruction::Sub;
759 };
761 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
762 Rotation = ComplexDeinterleavingRotation::Rotation_0;
763 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
764 Rotation = ComplexDeinterleavingRotation::Rotation_90;
765 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
766 Rotation = ComplexDeinterleavingRotation::Rotation_180;
767 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
768 Rotation = ComplexDeinterleavingRotation::Rotation_270;
769 else {
770 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
771 return nullptr;
772 }
773
774 if (isa<FPMathOperator>(Real) &&
775 (!Real->getFastMathFlags().allowContract() ||
776 !Imag->getFastMathFlags().allowContract())) {
777 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
778 return nullptr;
779 }
780
781 Value *CR = Real->getOperand(0);
782 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
783 if (!RealMulI)
784 return nullptr;
785 Value *CI = Imag->getOperand(0);
786 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
787 if (!ImagMulI)
788 return nullptr;
789
790 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
791 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
792 return nullptr;
793 }
794
795 Value *R0 = RealMulI->getOperand(0);
796 Value *R1 = RealMulI->getOperand(1);
797 Value *I0 = ImagMulI->getOperand(0);
798 Value *I1 = ImagMulI->getOperand(1);
799
800 Value *CommonOperand;
801 Value *UncommonRealOp;
802 Value *UncommonImagOp;
803
804 if (R0 == I0 || R0 == I1) {
805 CommonOperand = R0;
806 UncommonRealOp = R1;
807 } else if (R1 == I0 || R1 == I1) {
808 CommonOperand = R1;
809 UncommonRealOp = R0;
810 } else {
811 LLVM_DEBUG(dbgs() << " - No equal operand\n");
812 return nullptr;
813 }
814
815 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
816 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_270)
818 std::swap(UncommonRealOp, UncommonImagOp);
819
820 std::pair<Value *, Value *> PartialMatch(
821 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
822 Rotation == ComplexDeinterleavingRotation::Rotation_180)
823 ? CommonOperand
824 : nullptr,
825 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
826 Rotation == ComplexDeinterleavingRotation::Rotation_270)
827 ? CommonOperand
828 : nullptr);
829
830 auto *CRInst = dyn_cast<Instruction>(CR);
831 auto *CIInst = dyn_cast<Instruction>(CI);
832
833 if (!CRInst || !CIInst) {
834 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
835 return nullptr;
836 }
837
838 CompositeNode *CNode =
839 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
840 if (!CNode) {
841 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
842 return nullptr;
843 }
844
845 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
846 if (!UncommonRes) {
847 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
848 return nullptr;
849 }
850
851 assert(PartialMatch.first && PartialMatch.second);
852 CompositeNode *CommonRes =
853 identifyNode(PartialMatch.first, PartialMatch.second);
854 if (!CommonRes) {
855 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
856 return nullptr;
857 }
858
859 CompositeNode *Node = prepareCompositeNode(
860 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
861 Node->Rotation = Rotation;
862 Node->addOperand(CommonRes);
863 Node->addOperand(UncommonRes);
864 Node->addOperand(CNode);
865 return submitCompositeNode(Node);
866}
867
868ComplexDeinterleavingGraph::CompositeNode *
869ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
870 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
871
872 // Determine rotation
874 if ((Real->getOpcode() == Instruction::FSub &&
875 Imag->getOpcode() == Instruction::FAdd) ||
876 (Real->getOpcode() == Instruction::Sub &&
877 Imag->getOpcode() == Instruction::Add))
878 Rotation = ComplexDeinterleavingRotation::Rotation_90;
879 else if ((Real->getOpcode() == Instruction::FAdd &&
880 Imag->getOpcode() == Instruction::FSub) ||
881 (Real->getOpcode() == Instruction::Add &&
882 Imag->getOpcode() == Instruction::Sub))
883 Rotation = ComplexDeinterleavingRotation::Rotation_270;
884 else {
885 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
886 return nullptr;
887 }
888
889 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
890 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
891 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
892 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
893
894 if (!AR || !AI || !BR || !BI) {
895 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
896 return nullptr;
897 }
898
899 CompositeNode *ResA = identifyNode(AR, AI);
900 if (!ResA) {
901 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
902 return nullptr;
903 }
904 CompositeNode *ResB = identifyNode(BR, BI);
905 if (!ResB) {
906 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
907 return nullptr;
908 }
909
910 CompositeNode *Node =
911 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
912 Node->Rotation = Rotation;
913 Node->addOperand(ResA);
914 Node->addOperand(ResB);
915 return submitCompositeNode(Node);
916}
917
919 unsigned OpcA = A->getOpcode();
920 unsigned OpcB = B->getOpcode();
921
922 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
923 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
924 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
925 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
926}
927
929 auto Pattern =
931
932 return match(A, Pattern) && match(B, Pattern);
933}
934
936 switch (I->getOpcode()) {
937 case Instruction::FAdd:
938 case Instruction::FSub:
939 case Instruction::FMul:
940 case Instruction::FNeg:
941 case Instruction::Add:
942 case Instruction::Sub:
943 case Instruction::Mul:
944 return true;
945 default:
946 return false;
947 }
948}
949
// Try to identify Vals as a Symmetric composite node. Every real and imaginary
// value must be an instruction with the same opcode as Vals[0].Real and, when
// that instruction is an FP operation, the same fast-math flags.
// The operand pairs are identified recursively: operand 0 always, and operand
// 1 as well when the instruction is a binary operator. The resulting node
// records the opcode (and FP flags) of Vals[0].Real.
// Returns the submitted node, or nullptr if any check or operand
// identification fails.
// Note: cast<Instruction> asserts that every Real/Imag value is an Instruction.
950ComplexDeinterleavingGraph::CompositeNode *
951ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
952 auto *FirstReal = cast<Instruction>(Vals[0].Real);
953 unsigned FirstOpc = FirstReal->getOpcode();
// Check every (Real, Imag) pair against the first real instruction.
954 for (auto &V : Vals) {
955 auto *Real = cast<Instruction>(V.Real);
956 auto *Imag = cast<Instruction>(V.Imag);
957 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
958 return nullptr;
959
962 return nullptr;
963
964 if (isa<FPMathOperator>(FirstReal))
965 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
966 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
967 return nullptr;
968 }
969
// Identify the pairs formed by operand 0 of each real/imag instruction.
970 ComplexValues OpVals;
971 for (auto &V : Vals) {
972 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
973 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
974 OpVals.push_back({R0, I0});
975 }
976
977 CompositeNode *Op0 = identifyNode(OpVals);
978 CompositeNode *Op1 = nullptr;
979 if (Op0 == nullptr)
980 return nullptr;
981
// Binary operators additionally need their operand-1 pairs identified.
982 if (FirstReal->isBinaryOp()) {
983 OpVals.clear();
984 for (auto &V : Vals) {
985 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
986 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
987 OpVals.push_back({R1, I1});
988 }
989 Op1 = identifyNode(OpVals);
990 if (Op1 == nullptr)
991 return nullptr;
992 }
993
// Record the opcode and, for FP operations, the fast-math flags.
994 auto Node =
995 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
996 Node->Opcode = FirstReal->getOpcode();
997 if (isa<FPMathOperator>(FirstReal))
998 Node->Flags = FirstReal->getFastMathFlags();
999
1000 Node->addOperand(Op0);
1001 if (FirstReal->isBinaryOp())
1002 Node->addOperand(Op1);
1003
1004 return submitCompositeNode(Node);
1005}
1006
// Try to identify V -- matched below as a nested pair of
// experimental_vector_partial_reduce_add calls -- as a complex dot product
// (CDot) node. The target must support CDot for V's type.
// Accepted shapes (m_Mul operand order matters):
//   rotation 0:      reduce(reduce(Phi, BReal*AReal), -(BImag*AImag))
//   rotation 270:    reduce(reduce(Phi, -(BReal*AImag)), BImag*AReal)
//   rotation 90/180: reduce(reduce(Phi, BReal*A0), BImag*A1), disambiguated
//                    by whether (A0, A1) or (A1, A0) identifies as a node.
// One cast around each multiplied operand is peeled (UnwrapCast), and all four
// operands must have type VectorType::getSubdividedVectorType(VTy, 2).
// The CDot node's operands are the A node, the B node and the
// (Phi, first user of V) accumulator pair. Returns nullptr if no pattern
// matches or a type/identity check fails. CN is prepared up front, so early
// returns leave it unsubmitted.
1007ComplexDeinterleavingGraph::CompositeNode *
1008ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
1010 ComplexDeinterleavingOperation::CDot, V->getType())) {
1011 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
1012 "operation CDot with the type "
1013 << *V->getType() << "\n");
1014 return nullptr;
1015 }
1016
1017 auto *Inst = cast<Instruction>(V);
// NOTE(review): dereferences the first user unconditionally; this assumes V
// has at least one user -- confirm callers guarantee it.
1018 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1019
1020 CompositeNode *CN =
1021 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1022
// Set only on the 90/180 path, where identifying A decides the rotation.
1023 CompositeNode *ANode = nullptr;
1024
1025 const Intrinsic::ID PartialReduceInt =
1026 Intrinsic::experimental_vector_partial_reduce_add;
1027
1028 Value *AReal = nullptr;
1029 Value *AImag = nullptr;
1030 Value *BReal = nullptr;
1031 Value *BImag = nullptr;
1032 Value *Phi = nullptr;
1033
// Strip a single cast instruction, if present.
1034 auto UnwrapCast = [](Value *V) -> Value * {
1035 if (auto *CI = dyn_cast<CastInst>(V))
1036 return CI->getOperand(0);
1037 return V;
1038 };
1039
1040 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1041 m_Intrinsic<PartialReduceInt>(m_Value(Phi),
1042 m_Mul(m_Value(BReal), m_Value(AReal))),
1043 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1044
1045 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1046 m_Intrinsic<PartialReduceInt>(
1047 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1048 m_Mul(m_Value(BImag), m_Value(AReal)));
1049
1050 if (match(Inst, PatternRot0)) {
1051 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1052 } else if (match(Inst, PatternRot270)) {
1053 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1054 } else {
1055 Value *A0, *A1;
1056 // The rotations 90 and 180 share the same operation pattern, so inspect the
1057 // order of the operands, identifying where the real and imaginary
1058 // components of A go, to discern between the aforementioned rotations.
1059 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1060 m_Intrinsic<PartialReduceInt>(m_Value(Phi),
1061 m_Mul(m_Value(BReal), m_Value(A0))),
1062 m_Mul(m_Value(BImag), m_Value(A1)));
1063
1064 if (!match(Inst, PatternRot90Rot180))
1065 return nullptr;
1066
1067 A0 = UnwrapCast(A0);
1068 A1 = UnwrapCast(A1);
1069
1070 // Test if A0 is real/A1 is imag
1071 ANode = identifyNode(A0, A1);
1072 if (!ANode) {
1073 // Test if A0 is imag/A1 is real
1074 ANode = identifyNode(A1, A0);
1075 // Unable to identify operand components, thus unable to identify rotation
1076 if (!ANode)
1077 return nullptr;
1078 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1079 AReal = A1;
1080 AImag = A0;
1081 } else {
1082 AReal = A0;
1083 AImag = A1;
1084 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1085 }
1086 }
1087
1088 AReal = UnwrapCast(AReal);
1089 AImag = UnwrapCast(AImag);
1090 BReal = UnwrapCast(BReal);
1091 BImag = UnwrapCast(BImag);
1092
// All multiplied operands must have the subdivided element type of V.
1093 VectorType *VTy = cast<VectorType>(V->getType());
1094 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1095 if (AReal->getType() != ExpectedOperandTy)
1096 return nullptr;
1097 if (AImag->getType() != ExpectedOperandTy)
1098 return nullptr;
1099 if (BReal->getType() != ExpectedOperandTy)
1100 return nullptr;
1101 if (BImag->getType() != ExpectedOperandTy)
1102 return nullptr;
1103
// NOTE(review): this bails out only when *both* Phi and RealUser differ from
// VTy. If both are meant to be required to have type VTy, the condition
// should use `||` -- confirm the intent.
1104 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1105 return nullptr;
1106
1107 CompositeNode *Node = identifyNode(AReal, AImag);
1108
1109 // In the case that a node was identified to figure out the rotation, ensure
1110 // that trying to identify a node with AReal and AImag post-unwrap results in
1111 // the same node
1112 if (ANode && Node != ANode) {
1113 LLVM_DEBUG(
1114 dbgs()
1115 << "Identified node is different from previously identified node. "
1116 "Unable to confidently generate a complex operation node\n");
1117 return nullptr;
1118 }
1119
// The B and accumulator identifyNode results are not null-checked here;
// presumably null operands are rejected later by areOperandsValid() in
// checkNodes -- confirm.
1120 CN->addOperand(Node);
1121 CN->addOperand(identifyNode(BReal, BImag));
1122 CN->addOperand(identifyNode(Phi, RealUser));
1123
1124 return submitCompositeNode(CN);
1125}
1126
1127ComplexDeinterleavingGraph::CompositeNode *
1128ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1129 // Partial reductions don't support non-vector types, so check these first
1130 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1131 return nullptr;
1132
1133 if (!R->hasUseList() || !I->hasUseList())
1134 return nullptr;
1135
1136 auto CommonUser =
1137 findCommonBetweenCollections<Value *>(R->users(), I->users());
1138 if (!CommonUser)
1139 return nullptr;
1140
1141 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1142 if (!IInst || IInst->getIntrinsicID() !=
1143 Intrinsic::experimental_vector_partial_reduce_add)
1144 return nullptr;
1145
1146 if (CompositeNode *CN = identifyDotProduct(IInst))
1147 return CN;
1148
1149 return nullptr;
1150}
1151
// Identify the complex values in Vals as a composite node, or return nullptr.
// Results are memoised in CachedResult:
//   - a cache hit is returned directly;
//   - a set of values that no recogniser accepts is cached as nullptr at the end.
// The early rejections below (mismatched types, non-instruction values)
// return nullptr without caching.
//
// Recognisers are tried in a fixed order and the first success wins, so the
// order is significant:
//   1. identifyPartialReduction (single pair only),
//   2. identifySplat,
//   3. identifyDeinterleave (all values must be instructions from here on),
//   4. single pair only:
//      - identifyPHINode and identifySelectNode;
//      - identifyPartialMul, identifyAdd and identifyReassocNodes, gated on the
//        target supporting CMulPartial/CAdd for the vector type with twice as
//        many elements;
//   5. identifySymmetricOperation.
1152ComplexDeinterleavingGraph::CompositeNode *
1153ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1154 auto It = CachedResult.find(Vals);
1155 if (It != CachedResult.end()) {
1156 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1157 return It->second;
1158 }
1159
1160 if (Vals.size() == 1) {
1161 assert(Factor == 2 && "Can only handle interleave factors of 2");
1162 Value *R = Vals[0].Real;
1163 Value *I = Vals[0].Imag;
1164 if (CompositeNode *CN = identifyPartialReduction(R, I))
1165 return CN;
// Differing types are tolerated only when R is the reduction PHI (and I is
// ImagPHI, if one is set).
1166 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1167 if (!IsReduction && R->getType() != I->getType())
1168 return nullptr;
1169 }
1170
1171 if (CompositeNode *CN = identifySplat(Vals))
1172 return CN;
1173
// Every remaining recogniser requires instructions on both sides.
1174 for (auto &V : Vals) {
1175 auto *Real = dyn_cast<Instruction>(V.Real);
1176 auto *Imag = dyn_cast<Instruction>(V.Imag);
1177 if (!Real || !Imag)
1178 return nullptr;
1179 }
1180
1181 if (CompositeNode *CN = identifyDeinterleave(Vals))
1182 return CN;
1183
1184 if (Vals.size() == 1) {
1185 assert(Factor == 2 && "Can only handle interleave factors of 2");
1186 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1187 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1188 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1189 return CN;
1190
1191 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1192 return CN;
1193
// Target support is queried for the vector type with twice as many elements.
1194 auto *VTy = cast<VectorType>(Real->getType());
1195 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1196
1197 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1198 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1199 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1200 ComplexDeinterleavingOperation::CAdd, NewVTy);
1201
1202 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1203 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1204 return CN;
1205 }
1206
1207 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1208 if (CompositeNode *CN = identifyAdd(Real, Imag))
1209 return CN;
1210 }
1211
1212 if (HasCMulSupport && HasCAddSupport) {
1213 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1214 return CN;
1215 }
1216 }
1217 }
1218
1219 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1220 return CN;
1221
1222 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
// Remember the failure so the same values are not re-analysed.
1223 CachedResult[Vals] = nullptr;
1224 return nullptr;
1225}
1226
// Try to identify (Real, Imag) as the roots of two reassociable expression
// trees built from FAdd/FSub/FNeg/Add/Sub.
//
// Each tree is flattened (Collect) into:
//   - signed products (Muls), where negated multiplication operands fold into
//     the sign;
//   - signed addends: non-instruction values, multi-use subexpressions and
//     any other opcodes.
//
// The products are then paired into partial multiplications via
// identifyMultiplications, using a positive addend pair (if any) as the
// accumulator. The remaining addends are folded in with identifyAdditions.
//
// For FP roots:
//   - both roots must have identical fast-math flags including reassoc;
//   - every add/sub/mul/neg instruction folded into a tree must carry those
//     same flags.
//
// Returns the final node, with Vals[0] set to (Real, Imag), or nullptr.
1227ComplexDeinterleavingGraph::CompositeNode *
1228ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1229 Instruction *Imag) {
1230 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1231 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1232 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1233 Opcode == Instruction::Sub;
1234 };
1235
1236 if (!IsOperationSupported(Real->getOpcode()) ||
1237 !IsOperationSupported(Imag->getOpcode()))
1238 return nullptr;
1239
// Flags stays empty for integer roots, which disables the per-instruction
// flag check inside Collect.
1240 std::optional<FastMathFlags> Flags;
1241 if (isa<FPMathOperator>(Real)) {
1242 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1243 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1244 "not identical\n");
1245 return nullptr;
1246 }
1247
1248 Flags = Real->getFastMathFlags();
1249 if (!Flags->allowReassoc()) {
1250 LLVM_DEBUG(
1251 dbgs()
1252 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1253 return nullptr;
1254 }
1255 }
1256
1257 // Collect multiplications and addend instructions from the given instruction
1258 // while traversing its operands. Additionally, verify that all instructions
1259 // have the same fast math flags.
// Each worklist entry is (value, sign); a value is visited at most once.
1260 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1261 AddendList &Addends) -> bool {
1262 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1264 while (!Worklist.empty()) {
1265 auto [V, IsPositive] = Worklist.pop_back_val();
1266 if (!Visited.insert(V).second)
1267 continue;
1268
1269 Instruction *I = dyn_cast<Instruction>(V);
1270 if (!I) {
1271 Addends.emplace_back(V, IsPositive);
1272 continue;
1273 }
1274
1275 // If an instruction has more than one user, it indicates that it either
1276 // has an external user, which will be later checked by the checkNodes
1277 // function, or it is a subexpression utilized by multiple expressions. In
1278 // the latter case, we will attempt to separately identify the complex
1279 // operation from here in order to create a shared
1280 // ComplexDeinterleavingCompositeNode.
1281 if (I != Insn && I->hasNUsesOrMore(2)) {
1282 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1283 Addends.emplace_back(I, IsPositive);
1284 continue;
1285 }
1286 switch (I->getOpcode()) {
1287 case Instruction::FAdd:
1288 case Instruction::Add:
1289 Worklist.emplace_back(I->getOperand(1), IsPositive);
1290 Worklist.emplace_back(I->getOperand(0), IsPositive);
1291 break;
1292 case Instruction::FSub:
1293 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1294 Worklist.emplace_back(I->getOperand(0), IsPositive);
1295 break;
1296 case Instruction::Sub:
1297 if (isNeg(I)) {
1298 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1299 } else {
1300 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1301 Worklist.emplace_back(I->getOperand(0), IsPositive);
1302 }
1303 break;
1304 case Instruction::FMul:
1305 case Instruction::Mul: {
// A negated multiplication operand is stripped and flips the product's sign.
1306 Value *A, *B;
1307 if (isNeg(I->getOperand(0))) {
1308 A = getNegOperand(I->getOperand(0));
1309 IsPositive = !IsPositive;
1310 } else {
1311 A = I->getOperand(0);
1312 }
1313
1314 if (isNeg(I->getOperand(1))) {
1315 B = getNegOperand(I->getOperand(1));
1316 IsPositive = !IsPositive;
1317 } else {
1318 B = I->getOperand(1);
1319 }
1320 Muls.push_back(Product{A, B, IsPositive});
1321 break;
1322 }
1323 case Instruction::FNeg:
1324 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1325 break;
1326 default:
// Other opcodes become leaves; the flag check below is skipped for them.
1327 Addends.emplace_back(I, IsPositive);
1328 continue;
1329 }
1330
1331 if (Flags && I->getFastMathFlags() != *Flags) {
1332 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1333 "inconsistent with the root instructions' flags: "
1334 << *I << "\n");
1335 return false;
1336 }
1337 }
1338 return true;
1339 };
1340
1341 SmallVector<Product> RealMuls, ImagMuls;
1342 AddendList RealAddends, ImagAddends;
1343 if (!Collect(Real, RealMuls, RealAddends) ||
1344 !Collect(Imag, ImagMuls, ImagAddends))
1345 return nullptr;
1346
1347 if (RealAddends.size() != ImagAddends.size())
1348 return nullptr;
1349
1350 CompositeNode *FinalNode = nullptr;
1351 if (!RealMuls.empty() || !ImagMuls.empty()) {
1352 // If there are multiplicands, extract positive addend and use it as an
1353 // accumulator
1354 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1355 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1356 if (!FinalNode)
1357 return nullptr;
1358 }
1359
1360 // Identify and process remaining additions
1361 if (!RealAddends.empty() || !ImagAddends.empty()) {
1362 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1363 if (!FinalNode)
1364 return nullptr;
1365 }
1366 assert(FinalNode && "FinalNode can not be nullptr here");
1367 assert(FinalNode->Vals.size() == 1);
1368 // Set the Real and Imag fields of the final node and submit it
1369 FinalNode->Vals[0].Real = Real;
1370 FinalNode->Vals[0].Imag = Imag;
1371 submitCompositeNode(FinalNode);
1372 return FinalNode;
1373}
1374
1375bool ComplexDeinterleavingGraph::collectPartialMuls(
1376 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1377 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1378 // Helper function to extract a common operand from two products
1379 auto FindCommonInstruction = [](const Product &Real,
1380 const Product &Imag) -> Value * {
1381 if (Real.Multiplicand == Imag.Multiplicand ||
1382 Real.Multiplicand == Imag.Multiplier)
1383 return Real.Multiplicand;
1384
1385 if (Real.Multiplier == Imag.Multiplicand ||
1386 Real.Multiplier == Imag.Multiplier)
1387 return Real.Multiplier;
1388
1389 return nullptr;
1390 };
1391
1392 // Iterating over real and imaginary multiplications to find common operands
1393 // If a common operand is found, a partial multiplication candidate is created
1394 // and added to the candidates vector The function returns false if no common
1395 // operands are found for any product
1396 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1397 bool FoundCommon = false;
1398 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1399 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1400 if (!Common)
1401 continue;
1402
1403 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1404 : RealMuls[i].Multiplicand;
1405 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1406 : ImagMuls[j].Multiplicand;
1407
1408 auto Node = identifyNode(A, B);
1409 if (Node) {
1410 FoundCommon = true;
1411 PartialMulCandidates.push_back({Common, Node, i, j, false});
1412 }
1413
1414 Node = identifyNode(B, A);
1415 if (Node) {
1416 FoundCommon = true;
1417 PartialMulCandidates.push_back({Common, Node, i, j, true});
1418 }
1419 }
1420 if (!FoundCommon)
1421 return false;
1422 }
1423 return true;
1424}
1425
// Pair the real and imaginary products into a chain of CMulPartial nodes
// that accumulates onto Accumulator. The steps are:
//   1. collectPartialMuls finds candidate (real, imag) product pairs.
//   2. Candidates whose Common operands together identify as a complex node
//      are recorded in CommonToNode.
//   3. Each still-unprocessed (real, imag) pair with a recorded common node
//      becomes a CMulPartial node. Its rotation is derived from the product
//      signs (see the table below), and it takes the previous result as its
//      third operand.
// Returns the last node created (Accumulator itself when there are no
// products). Returns nullptr when:
//   - the product counts differ;
//   - an independent partial multiplication is found;
//   - some product remains unpaired.
1426ComplexDeinterleavingGraph::CompositeNode *
1427ComplexDeinterleavingGraph::identifyMultiplications(
1429 CompositeNode *Accumulator = nullptr) {
1430 if (RealMuls.size() != ImagMuls.size())
1431 return nullptr;
1432
1434 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1435 return nullptr;
1436
1437 // Map to store common instruction to node pointers
1439 SmallVector<bool> Processed(Info.size(), false);
// Match candidates pairwise: two candidates belong together when their Common
// operands identify as a complex node (tried in both orders).
1440 for (unsigned I = 0; I < Info.size(); ++I) {
1441 if (Processed[I])
1442 continue;
1443
1444 PartialMulCandidate &InfoA = Info[I];
1445 for (unsigned J = I + 1; J < Info.size(); ++J) {
1446 if (Processed[J])
1447 continue;
1448
1449 PartialMulCandidate &InfoB = Info[J];
1450 auto *InfoReal = &InfoA;
1451 auto *InfoImag = &InfoB;
1452
1453 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1454 if (!NodeFromCommon) {
1455 std::swap(InfoReal, InfoImag);
1456 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1457 }
1458 if (!NodeFromCommon)
1459 continue;
1460
1461 CommonToNode[InfoReal->Common] = NodeFromCommon;
1462 CommonToNode[InfoImag->Common] = NodeFromCommon;
1463 Processed[I] = true;
1464 Processed[J] = true;
1465 }
1466 }
1467
1468 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1469 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1470 CompositeNode *Result = Accumulator;
1471 for (auto &PMI : Info) {
1472 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1473 continue;
1474
1475 auto It = CommonToNode.find(PMI.Common);
1476 // TODO: Process independent complex multiplications. Cases like this:
1477 // A.real() * B where both A and B are complex numbers.
1478 if (It == CommonToNode.end()) {
// NOTE(review): the debug loop below prints RealMuls[PMI.RealIdx] twice; the
// second entry was presumably meant to be ImagMuls[PMI.ImagIdx]. This is
// debug-only output with no functional effect.
1479 LLVM_DEBUG({
1480 dbgs() << "Unprocessed independent partial multiplication:\n";
1481 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1482 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1483 << " multiplied by " << *Mul->Multiplicand << "\n";
1484 });
1485 return nullptr;
1486 }
1487
1488 auto &RealMul = RealMuls[PMI.RealIdx];
1489 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1490
1491 auto NodeA = It->second;
1492 auto NodeB = PMI.Node;
1493 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1494 // The following table illustrates the relationship between multiplications
1495 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1496 // can see:
1497 //
1498 // Rotation | Real | Imag |
1499 // ---------+--------+--------+
1500 // 0 | x * u | x * v |
1501 // 90 | -y * v | y * u |
1502 // 180 | -x * u | -x * v |
1503 // 270 | y * v | -y * u |
1504 //
1505 // Check if the candidate can indeed be represented by partial
1506 // multiplication
1507 // TODO: Add support for multiplication by complex one
1508 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1509 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1510 continue;
1511
1512 // Determine the rotation based on the multiplications
1514 if (IsMultiplicandReal) {
1515 // Detect 0 and 180 degrees rotation
1516 if (RealMul.IsPositive && ImagMul.IsPositive)
1518 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1520 else
1521 continue;
1522
1523 } else {
1524 // Detect 90 and 270 degrees rotation
1525 if (!RealMul.IsPositive && ImagMul.IsPositive)
1527 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1529 else
1530 continue;
1531 }
1532
1533 LLVM_DEBUG({
1534 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1535 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1536 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1537 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1538 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1539 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1540 });
1541
// Operands: common-operand node, other-operand node, then the running
// accumulator (if any).
1542 CompositeNode *NodeMul = prepareCompositeNode(
1543 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1544 NodeMul->Rotation = Rotation;
1545 NodeMul->addOperand(NodeA);
1546 NodeMul->addOperand(NodeB);
1547 if (Result)
1548 NodeMul->addOperand(Result);
1549 submitCompositeNode(NodeMul);
1550 Result = NodeMul;
1551 ProcessedReal[PMI.RealIdx] = true;
1552 ProcessedImag[PMI.ImagIdx] = true;
1553 }
1554
1555 // Ensure all products have been processed, if not return nullptr.
1556 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1557 !all_of(ProcessedImag, [](bool V) { return V; })) {
1558
1559 // Dump debug information about which partial multiplications are not
1560 // processed.
1561 LLVM_DEBUG({
1562 dbgs() << "Unprocessed products (Real):\n";
1563 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1564 if (!ProcessedReal[i])
1565 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1566 << *RealMuls[i].Multiplier << " multiplied by "
1567 << *RealMuls[i].Multiplicand << "\n";
1568 }
1569 dbgs() << "Unprocessed products (Imag):\n";
1570 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1571 if (!ProcessedImag[i])
1572 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1573 << *ImagMuls[i].Multiplier << " multiplied by "
1574 << *ImagMuls[i].Multiplicand << "\n";
1575 }
1576 });
1577 return nullptr;
1578 }
1579
1580 return Result;
1581}
1582
// Fold the remaining (real, imag) addend pairs into a chain of nodes. The
// chain starts from Accumulator or, if there is none, from a positive addend
// pair extracted by extractPositiveAddend.
//
// The front real addend is repeatedly paired with the first imaginary addend
// for which identifyNode succeeds. The rotation comes from the two signs, and
// (R, I) is tried for rotations 0/180 while (I, R) is tried for 90/270.
//
// Depending on the rotation, the pair is combined with the running result as
// one of:
//   - a Symmetric (F)Add;
//   - a Symmetric (F)Sub;
//   - a CAdd that carries the rotation.
// FP opcodes and Flags are used when Flags is set. Matched addends are erased
// from both lists.
//
// Returns the last node, or nullptr when:
//   - the list sizes differ;
//   - no starting node exists;
//   - some real addend finds no partner.
1583ComplexDeinterleavingGraph::CompositeNode *
1584ComplexDeinterleavingGraph::identifyAdditions(
1585 AddendList &RealAddends, AddendList &ImagAddends,
1586 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1587 if (RealAddends.size() != ImagAddends.size())
1588 return nullptr;
1589
1590 CompositeNode *Result = nullptr;
1591 // If we have accumulator use it as first addend
1592 if (Accumulator)
1594 // Otherwise find an element with both positive real and imaginary parts.
1595 else
1596 Result = extractPositiveAddend(RealAddends, ImagAddends);
1597
1598 if (!Result)
1599 return nullptr;
1600
1601 while (!RealAddends.empty()) {
1602 auto ItR = RealAddends.begin();
1603 auto [R, IsPositiveR] = *ItR;
1604
1605 bool FoundImag = false;
1606 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1607 auto [I, IsPositiveI] = *ItI;
1609 if (IsPositiveR && IsPositiveI)
1610 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1611 else if (!IsPositiveR && IsPositiveI)
1612 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1613 else if (!IsPositiveR && !IsPositiveI)
1614 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1615 else
1616 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1617
// For 90/270 the operands are identified swapped, as (I, R).
1618 CompositeNode *AddNode = nullptr;
1619 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1620 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1621 AddNode = identifyNode(R, I);
1622 } else {
1623 AddNode = identifyNode(I, R);
1624 }
1625 if (AddNode) {
1626 LLVM_DEBUG({
1627 dbgs() << "Identified addition:\n";
1628 dbgs().indent(4) << "X: " << *R << "\n";
1629 dbgs().indent(4) << "Y: " << *I << "\n";
1630 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1631 });
1632
1633 CompositeNode *TmpNode = nullptr;
1635 TmpNode = prepareCompositeNode(
1636 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1637 if (Flags) {
1638 TmpNode->Opcode = Instruction::FAdd;
1639 TmpNode->Flags = *Flags;
1640 } else {
1641 TmpNode->Opcode = Instruction::Add;
1642 }
1643 } else if (Rotation ==
1645 TmpNode = prepareCompositeNode(
1646 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1647 if (Flags) {
1648 TmpNode->Opcode = Instruction::FSub;
1649 TmpNode->Flags = *Flags;
1650 } else {
1651 TmpNode->Opcode = Instruction::Sub;
1652 }
1653 } else {
1654 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1655 nullptr, nullptr);
1656 TmpNode->Rotation = Rotation;
1657 }
1658
1659 TmpNode->addOperand(Result);
1660 TmpNode->addOperand(AddNode);
1661 submitCompositeNode(TmpNode);
1662 Result = TmpNode;
1663 RealAddends.erase(ItR);
1664 ImagAddends.erase(ItI);
1665 FoundImag = true;
1666 break;
1667 }
1668 }
1669 if (!FoundImag)
1670 return nullptr;
1671 }
1672 return Result;
1673}
1674
1675ComplexDeinterleavingGraph::CompositeNode *
1676ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1677 AddendList &ImagAddends) {
1678 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1679 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1680 auto [R, IsPositiveR] = *ItR;
1681 auto [I, IsPositiveI] = *ItI;
1682 if (IsPositiveR && IsPositiveI) {
1683 auto Result = identifyNode(R, I);
1684 if (Result) {
1685 RealAddends.erase(ItR);
1686 ImagAddends.erase(ItI);
1687 return Result;
1688 }
1689 }
1690 }
1691 }
1692 return nullptr;
1693}
1694
1695bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1696 // This potential root instruction might already have been recognized as
1697 // reduction. Because RootToNode maps both Real and Imaginary parts to
1698 // CompositeNode we should choose only one either Real or Imag instruction to
1699 // use as an anchor for generating complex instruction.
1700 auto It = RootToNode.find(RootI);
1701 if (It != RootToNode.end()) {
1702 auto RootNode = It->second;
1703 assert(RootNode->Operation ==
1704 ComplexDeinterleavingOperation::ReductionOperation ||
1705 RootNode->Operation ==
1706 ComplexDeinterleavingOperation::ReductionSingle);
1707 assert(RootNode->Vals.size() == 1 &&
1708 "Cannot handle reductions involving multiple complex values");
1709 // Find out which part, Real or Imag, comes later, and only if we come to
1710 // the latest part, add it to OrderedRoots.
1711 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1712 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1713 : nullptr;
1714
1715 Instruction *ReplacementAnchor;
1716 if (I)
1717 ReplacementAnchor = R->comesBefore(I) ? I : R;
1718 else
1719 ReplacementAnchor = R;
1720
1721 if (ReplacementAnchor != RootI)
1722 return false;
1723 OrderedRoots.push_back(RootI);
1724 return true;
1725 }
1726
1727 auto RootNode = identifyRoot(RootI);
1728 if (!RootNode)
1729 return false;
1730
1731 LLVM_DEBUG({
1732 Function *F = RootI->getFunction();
1733 BasicBlock *B = RootI->getParent();
1734 dbgs() << "Complex deinterleaving graph for " << F->getName()
1735 << "::" << B->getName() << ".\n";
1736 dump(dbgs());
1737 dbgs() << "\n";
1738 });
1739 RootToNode[RootI] = RootNode;
1740 OrderedRoots.push_back(RootI);
1741 return true;
1742}
1743
1744bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1745 bool FoundPotentialReduction = false;
1746 if (Factor != 2)
1747 return false;
1748
1749 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1750 if (!Br || Br->getNumSuccessors() != 2)
1751 return false;
1752
1753 // Identify simple one-block loop
1754 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1755 return false;
1756
1757 for (auto &PHI : B->phis()) {
1758 if (PHI.getNumIncomingValues() != 2)
1759 continue;
1760
1761 if (!PHI.getType()->isVectorTy())
1762 continue;
1763
1764 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1765 if (!ReductionOp)
1766 continue;
1767
1768 // Check if final instruction is reduced outside of current block
1769 Instruction *FinalReduction = nullptr;
1770 auto NumUsers = 0u;
1771 for (auto *U : ReductionOp->users()) {
1772 ++NumUsers;
1773 if (U == &PHI)
1774 continue;
1775 FinalReduction = dyn_cast<Instruction>(U);
1776 }
1777
1778 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1779 isa<PHINode>(FinalReduction))
1780 continue;
1781
1782 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1783 BackEdge = B;
1784 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1785 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1786 Incoming = PHI.getIncomingBlock(IncomingIdx);
1787 FoundPotentialReduction = true;
1788
1789 // If the initial value of PHINode is an Instruction, consider it a leaf
1790 // value of a complex deinterleaving graph.
1791 if (auto *InitPHI =
1792 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1793 FinalInstructions.insert(InitPHI);
1794 }
1795 return FoundPotentialReduction;
1796}
1797
// Turn the reductions recorded in ReductionInfo into graph roots.
//
// 1. Every pair (i, j) of same-typed reduction operations is tried as
//    (Real, Imag), and then swapped. The identified graph must reach the
//    reduction PHIs (PHIsFound). A match produces a ReductionOperation root
//    that is registered in RootToNode for both instructions.
// 2. Any remaining reduction whose final reduction has an integer type is
//    tried as a single reduction, with operands 0 and 1 as the real/imag
//    parts. A match produces a ReductionSingle root.
//
// RealPHI/ImagPHI are member state consulted by identifyNode while these
// candidates are analysed; they are cleared on exit.
1798void ComplexDeinterleavingGraph::identifyReductionNodes() {
1799 assert(Factor == 2 && "Cannot handle multiple complex values");
1800
1801 SmallVector<bool> Processed(ReductionInfo.size(), false);
1802 SmallVector<Instruction *> OperationInstruction;
1803 for (auto &P : ReductionInfo)
1804 OperationInstruction.push_back(P.first);
1805
1806 // Identify a complex computation by evaluating two reduction operations that
1807 // potentially could be involved
1808 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1809 if (Processed[i])
1810 continue;
1811 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1812 if (Processed[j])
1813 continue;
1814 auto *Real = OperationInstruction[i];
1815 auto *Imag = OperationInstruction[j];
1816 if (Real->getType() != Imag->getType())
1817 continue;
1818
// Reset the PHI state for this candidate pair before identification.
1819 RealPHI = ReductionInfo[Real].first;
1820 ImagPHI = ReductionInfo[Imag].first;
1821 PHIsFound = false;
1822 auto Node = identifyNode(Real, Imag);
1823 if (!Node) {
1824 std::swap(Real, Imag);
1825 std::swap(RealPHI, ImagPHI);
1826 Node = identifyNode(Real, Imag);
1827 }
1828
1829 // If a node is identified and reduction PHINode is used in the chain of
1830 // operations, mark its operation instructions as used to prevent
1831 // re-identification and attach the node to the real part
1832 if (Node && PHIsFound) {
1833 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1834 << *Real << " / " << *Imag << "\n");
1835 Processed[i] = true;
1836 Processed[j] = true;
1837 auto RootNode = prepareCompositeNode(
1838 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1839 RootNode->addOperand(Node);
1840 RootToNode[Real] = RootNode;
1841 RootToNode[Imag] = RootNode;
1842 submitCompositeNode(RootNode);
1843 break;
1844 }
1845 }
1846
1847 auto *Real = OperationInstruction[i];
1848 // We want to check that we have 2 operands, but the function attributes
1849 // being counted as operands bloats this value.
1850 if (Processed[i] || Real->getNumOperands() < 2)
1851 continue;
1852
1853 // Can only combine integer reductions at the moment.
1854 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1855 continue;
1856
1857 RealPHI = ReductionInfo[Real].first;
1858 ImagPHI = nullptr;
1859 PHIsFound = false;
1860 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1861 if (Node && PHIsFound) {
1862 LLVM_DEBUG(
1863 dbgs() << "Identified single reduction starting from instruction: "
1864 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1865
1866 // Reducing to a single vector is not supported, only permit reducing down
1867 // to scalar values.
1868 // Doing this here will leave the prior node in the graph,
1869 // however with no uses the node will be unreachable by the replacement
1870 // process. That along with the usage outside the graph should prevent the
1871 // replacement process from kicking off at all for this graph.
1872 // TODO Add support for reducing to a single vector value
// NOTE(review): if Type::isIntegerTy() is false for vector types (as the
// llvm::Type reference documents), the isIntegerTy() filter above already
// excludes vector results and this branch is unreachable -- confirm.
1873 if (ReductionInfo[Real].second->getType()->isVectorTy())
1874 continue;
1875
1876 Processed[i] = true;
1877 auto RootNode = prepareCompositeNode(
1878 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1879 RootNode->addOperand(Node);
1880 RootToNode[Real] = RootNode;
1881 submitCompositeNode(RootNode);
1882 }
1883 }
1884
1885 RealPHI = nullptr;
1886 ImagPHI = nullptr;
1887}
1888
// Validate the identified graph before any replacement happens.
//
// Fails outright if any node reports invalid operands, or if the graph has no
// Deinterleave node. Otherwise:
//   1. Collect every instruction reachable from the roots through operands,
//      without expanding past FinalInstructions.
//   2. Find the non-root instructions that have a user outside that set.
//   3. Flood-fill from them through users and operands, staying within the
//      collected set and stopping at FinalInstructions.
//   4. Erase every root reached from RootToNode.
// Returns true if at least one root survives.
// cast<Instruction>(U) asserts that every user visited is an Instruction.
1889bool ComplexDeinterleavingGraph::checkNodes() {
1890 bool FoundDeinterleaveNode = false;
1891 for (CompositeNode *N : CompositeNodes) {
1892 if (!N->areOperandsValid())
1893 return false;
1894
1895 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1896 FoundDeinterleaveNode = true;
1897 }
1898
1899 // We need a deinterleave node in order to guarantee that we're working with
1900 // complex numbers.
1901 if (!FoundDeinterleaveNode) {
1902 LLVM_DEBUG(
1903 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1904 "guarantee safety during graph transformation.\n");
1905 return false;
1906 }
1907
1908 // Collect all instructions from roots to leaves
1909 SmallPtrSet<Instruction *, 16> AllInstructions;
1911 for (auto &Pair : RootToNode)
1912 Worklist.push_back(Pair.first);
1913
1914 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1915 // chains
1916 while (!Worklist.empty()) {
1917 auto *I = Worklist.pop_back_val();
1918
1919 if (!AllInstructions.insert(I).second)
1920 continue;
1921
// Operands of a FinalInstructions member are graph leaves' inputs and are
// not followed.
1922 for (Value *Op : I->operands()) {
1923 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1924 if (!FinalInstructions.count(I))
1925 Worklist.emplace_back(OpI);
1926 }
1927 }
1928 }
1929
1930 // Find instructions that have users outside of chain
1931 for (auto *I : AllInstructions) {
1932 // Skip root nodes
1933 if (RootToNode.count(I))
1934 continue;
1935
1936 for (User *U : I->users()) {
1937 if (AllInstructions.count(cast<Instruction>(U)))
1938 continue;
1939
1940 // Found an instruction that is not used by XCMLA/XCADD chain
1941 Worklist.emplace_back(I);
1942 break;
1943 }
1944 }
1945
1946 // If any instructions are found to be used outside, find and remove roots
1947 // that somehow connect to those instructions.
1949 while (!Worklist.empty()) {
1950 auto *I = Worklist.pop_back_val();
1951 if (!Visited.insert(I).second)
1952 continue;
1953
1954 // Found an impacted root node. Removing it from the nodes to be
1955 // deinterleaved
1956 if (RootToNode.count(I)) {
1957 LLVM_DEBUG(dbgs() << "Instruction " << *I
1958 << " could be deinterleaved but its chain of complex "
1959 "operations have an outside user\n");
1960 RootToNode.erase(I);
1961 }
1962
1963 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1964 continue;
1965
1966 for (User *U : I->users())
1967 Worklist.emplace_back(cast<Instruction>(U));
1968
1969 for (Value *Op : I->operands()) {
1970 if (auto *OpI = dyn_cast<Instruction>(Op))
1971 Worklist.emplace_back(OpI);
1972 }
1973 }
1974 return !RootToNode.empty();
1975}
1976
1977ComplexDeinterleavingGraph::CompositeNode *
1978ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1979 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1981 Intrinsic->getIntrinsicID())
1982 return nullptr;
1983
1984 ComplexValues Vals;
1985 for (unsigned I = 0; I < Factor; I += 2) {
1986 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1987 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1988 if (!Real || !Imag)
1989 return nullptr;
1990 Vals.push_back({Real, Imag});
1991 }
1992
1993 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1994 if (!Node1)
1995 return nullptr;
1996 return Node1;
1997 }
1998
1999 // TODO: We could also add support for fixed-width interleave factors of 4
2000 // and above, but currently for symmetric operations the interleaves and
2001 // deinterleaves are already removed by VectorCombine. If we extend this to
2002 // permit complex multiplications, reductions, etc. then we should also add
2003 // support for fixed-width here.
2004 if (Factor != 2)
2005 return nullptr;
2006
2007 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
2008 if (!SVI)
2009 return nullptr;
2010
2011 // Look for a shufflevector that takes separate vectors of the real and
2012 // imaginary components and recombines them into a single vector.
2013 if (!isInterleavingMask(SVI->getShuffleMask()))
2014 return nullptr;
2015
2016 Instruction *Real;
2017 Instruction *Imag;
2018 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2019 return nullptr;
2020
2021 return identifyNode(Real, Imag);
2022}
2023
2024ComplexDeinterleavingGraph::CompositeNode *
2025ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2026 Instruction *II = nullptr;
2027
2028 // Must be at least one complex value.
2029 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2030 Instruction *ExpectedInsn) -> ExtractValueInst * {
2031 auto *EVI = dyn_cast<ExtractValueInst>(V);
2032 if (!EVI || EVI->getNumIndices() != 1 ||
2033 EVI->getIndices()[0] != ExpectedIdx ||
2034 !isa<Instruction>(EVI->getAggregateOperand()) ||
2035 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2036 return nullptr;
2037 return EVI;
2038 };
2039
2040 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2041 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2042 if (RealEVI && Idx == 0)
2043 II = cast<Instruction>(RealEVI->getAggregateOperand());
2044 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2045 II = nullptr;
2046 break;
2047 }
2048 }
2049
2050 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2051 if (IntrinsicII->getIntrinsicID() !=
2053 return nullptr;
2054
2055 // The remaining should match too.
2056 CompositeNode *PlaceholderNode = prepareCompositeNode(
2058 PlaceholderNode->ReplacementNode = II->getOperand(0);
2059 for (auto &V : Vals) {
2060 FinalInstructions.insert(cast<Instruction>(V.Real));
2061 FinalInstructions.insert(cast<Instruction>(V.Imag));
2062 }
2063 return submitCompositeNode(PlaceholderNode);
2064 }
2065
2066 if (Vals.size() != 1)
2067 return nullptr;
2068
2069 Value *Real = Vals[0].Real;
2070 Value *Imag = Vals[0].Imag;
2071 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2072 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2073 if (!RealShuffle || !ImagShuffle) {
2074 if (RealShuffle || ImagShuffle)
2075 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2076 return nullptr;
2077 }
2078
2079 Value *RealOp1 = RealShuffle->getOperand(1);
2080 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
2081 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2082 return nullptr;
2083 }
2084 Value *ImagOp1 = ImagShuffle->getOperand(1);
2085 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
2086 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2087 return nullptr;
2088 }
2089
2090 Value *RealOp0 = RealShuffle->getOperand(0);
2091 Value *ImagOp0 = ImagShuffle->getOperand(0);
2092
2093 if (RealOp0 != ImagOp0) {
2094 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2095 return nullptr;
2096 }
2097
2098 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2099 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2100 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2101 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2102 return nullptr;
2103 }
2104
2105 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2106 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2107 return nullptr;
2108 }
2109
2110 // Type checking, the shuffle type should be a vector type of the same
2111 // scalar type, but half the size
2112 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2113 Value *Op = Shuffle->getOperand(0);
2114 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2115 auto *OpTy = cast<FixedVectorType>(Op->getType());
2116
2117 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2118 return false;
2119 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2120 return false;
2121
2122 return true;
2123 };
2124
2125 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2126 if (!CheckType(Shuffle))
2127 return false;
2128
2129 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2130 int Last = *Mask.rbegin();
2131
2132 Value *Op = Shuffle->getOperand(0);
2133 auto *OpTy = cast<FixedVectorType>(Op->getType());
2134 int NumElements = OpTy->getNumElements();
2135
2136 // Ensure that the deinterleaving shuffle only pulls from the first
2137 // shuffle operand.
2138 return Last < NumElements;
2139 };
2140
2141 if (RealShuffle->getType() != ImagShuffle->getType()) {
2142 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2143 return nullptr;
2144 }
2145 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2146 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2147 return nullptr;
2148 }
2149 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2150 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2151 return nullptr;
2152 }
2153
2154 CompositeNode *PlaceholderNode =
2156 RealShuffle, ImagShuffle);
2157 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2158 FinalInstructions.insert(RealShuffle);
2159 FinalInstructions.insert(ImagShuffle);
2160 return submitCompositeNode(PlaceholderNode);
2161}
2162
2163ComplexDeinterleavingGraph::CompositeNode *
2164ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2165 auto IsSplat = [](Value *V) -> bool {
2166 // Fixed-width vector with constants
2167 if (isa<ConstantDataVector>(V))
2168 return true;
2169
2170 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2171 return isa<VectorType>(V->getType());
2172
2173 VectorType *VTy;
2175 // Splats are represented differently depending on whether the repeated
2176 // value is a constant or an Instruction
2177 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2178 if (Const->getOpcode() != Instruction::ShuffleVector)
2179 return false;
2180 VTy = cast<VectorType>(Const->getType());
2181 Mask = Const->getShuffleMask();
2182 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2183 VTy = Shuf->getType();
2184 Mask = Shuf->getShuffleMask();
2185 } else {
2186 return false;
2187 }
2188
2189 // When the data type is <1 x Type>, it's not possible to differentiate
2190 // between the ComplexDeinterleaving::Deinterleave and
2191 // ComplexDeinterleaving::Splat operations.
2192 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2193 return false;
2194
2195 return all_equal(Mask) && Mask[0] == 0;
2196 };
2197
2198 // The splats must meet the following requirements:
2199 // 1. Must either be all instructions or all values.
2200 // 2. Non-constant splats must live in the same block.
2201 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2202 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2203 for (auto &V : Vals) {
2204 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2205 return nullptr;
2206
2207 auto *Real = dyn_cast<Instruction>(V.Real);
2208 auto *Imag = dyn_cast<Instruction>(V.Imag);
2209 if (!Real || !Imag || Real->getParent() != FirstBB ||
2210 Imag->getParent() != FirstBB)
2211 return nullptr;
2212 }
2213 } else {
2214 for (auto &V : Vals) {
2215 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2216 isa<Instruction>(V.Imag))
2217 return nullptr;
2218 }
2219 }
2220
2221 for (auto &V : Vals) {
2222 auto *Real = dyn_cast<Instruction>(V.Real);
2223 auto *Imag = dyn_cast<Instruction>(V.Imag);
2224 if (Real && Imag) {
2225 FinalInstructions.insert(Real);
2226 FinalInstructions.insert(Imag);
2227 }
2228 }
2229 CompositeNode *PlaceholderNode =
2230 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2231 return submitCompositeNode(PlaceholderNode);
2232}
2233
2234ComplexDeinterleavingGraph::CompositeNode *
2235ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2236 Instruction *Imag) {
2237 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2238 return nullptr;
2239
2240 PHIsFound = true;
2241 CompositeNode *PlaceholderNode = prepareCompositeNode(
2242 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2243 return submitCompositeNode(PlaceholderNode);
2244}
2245
2246ComplexDeinterleavingGraph::CompositeNode *
2247ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2248 Instruction *Imag) {
2249 auto *SelectReal = dyn_cast<SelectInst>(Real);
2250 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2251 if (!SelectReal || !SelectImag)
2252 return nullptr;
2253
2254 Instruction *MaskA, *MaskB;
2255 Instruction *AR, *AI, *RA, *BI;
2256 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2257 m_Instruction(RA))) ||
2258 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2259 m_Instruction(BI))))
2260 return nullptr;
2261
2262 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2263 return nullptr;
2264
2265 if (!MaskA->getType()->isVectorTy())
2266 return nullptr;
2267
2268 auto NodeA = identifyNode(AR, AI);
2269 if (!NodeA)
2270 return nullptr;
2271
2272 auto NodeB = identifyNode(RA, BI);
2273 if (!NodeB)
2274 return nullptr;
2275
2276 CompositeNode *PlaceholderNode = prepareCompositeNode(
2277 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2278 PlaceholderNode->addOperand(NodeA);
2279 PlaceholderNode->addOperand(NodeB);
2280 FinalInstructions.insert(MaskA);
2281 FinalInstructions.insert(MaskB);
2282 return submitCompositeNode(PlaceholderNode);
2283}
2284
2285static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2286 std::optional<FastMathFlags> Flags,
2287 Value *InputA, Value *InputB) {
2288 Value *I;
2289 switch (Opcode) {
2290 case Instruction::FNeg:
2291 I = B.CreateFNeg(InputA);
2292 break;
2293 case Instruction::FAdd:
2294 I = B.CreateFAdd(InputA, InputB);
2295 break;
2296 case Instruction::Add:
2297 I = B.CreateAdd(InputA, InputB);
2298 break;
2299 case Instruction::FSub:
2300 I = B.CreateFSub(InputA, InputB);
2301 break;
2302 case Instruction::Sub:
2303 I = B.CreateSub(InputA, InputB);
2304 break;
2305 case Instruction::FMul:
2306 I = B.CreateFMul(InputA, InputB);
2307 break;
2308 case Instruction::Mul:
2309 I = B.CreateMul(InputA, InputB);
2310 break;
2311 default:
2312 llvm_unreachable("Incorrect symmetric opcode");
2313 }
2314 if (Flags)
2315 cast<Instruction>(I)->setFastMathFlags(*Flags);
2316 return I;
2317}
2318
2319Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2320 CompositeNode *Node) {
2321 if (Node->ReplacementNode)
2322 return Node->ReplacementNode;
2323
2324 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2325 unsigned Idx) -> Value * {
2326 return Node->Operands.size() > Idx
2327 ? replaceNode(Builder, Node->Operands[Idx])
2328 : nullptr;
2329 };
2330
2331 Value *ReplacementNode = nullptr;
2332 switch (Node->Operation) {
2333 case ComplexDeinterleavingOperation::CDot: {
2334 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2335 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2336 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2337 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2338 "Node inputs need to be of the same type"));
2339 ReplacementNode = TL->createComplexDeinterleavingIR(
2340 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2341 break;
2342 }
2343 case ComplexDeinterleavingOperation::CAdd:
2344 case ComplexDeinterleavingOperation::CMulPartial:
2345 case ComplexDeinterleavingOperation::Symmetric: {
2346 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2347 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2348 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2349 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2350 "Node inputs need to be of the same type"));
2352 (Input0->getType() == Accumulator->getType() &&
2353 "Accumulator and input need to be of the same type"));
2354 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2355 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2356 Input0, Input1);
2357 else
2358 ReplacementNode = TL->createComplexDeinterleavingIR(
2359 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2360 Accumulator);
2361 break;
2362 }
2363 case ComplexDeinterleavingOperation::Deinterleave:
2364 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2365 break;
2366 case ComplexDeinterleavingOperation::Splat: {
2368 for (auto &V : Node->Vals) {
2369 Ops.push_back(V.Real);
2370 Ops.push_back(V.Imag);
2371 }
2372 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2373 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2374 if (R && I) {
2375 // Splats that are not constant are interleaved where they are located
2376 Instruction *InsertPoint = R;
2377 for (auto V : Node->Vals) {
2378 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2379 InsertPoint = cast<Instruction>(V.Real);
2380 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2381 InsertPoint = cast<Instruction>(V.Imag);
2382 }
2383 InsertPoint = InsertPoint->getNextNode();
2384 IRBuilder<> IRB(InsertPoint);
2385 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2386 } else {
2387 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2388 }
2389 break;
2390 }
2391 case ComplexDeinterleavingOperation::ReductionPHI: {
2392 // If Operation is ReductionPHI, a new empty PHINode is created.
2393 // It is filled later when the ReductionOperation is processed.
2394 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2395 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2396 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2397 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2398 OldToNewPHI[OldPHI] = NewPHI;
2399 ReplacementNode = NewPHI;
2400 break;
2401 }
2402 case ComplexDeinterleavingOperation::ReductionSingle:
2403 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2404 processReductionSingle(ReplacementNode, Node);
2405 break;
2406 case ComplexDeinterleavingOperation::ReductionOperation:
2407 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2408 processReductionOperation(ReplacementNode, Node);
2409 break;
2410 case ComplexDeinterleavingOperation::ReductionSelect: {
2411 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2412 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2413 auto *A = replaceNode(Builder, Node->Operands[0]);
2414 auto *B = replaceNode(Builder, Node->Operands[1]);
2415 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2416 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2417 break;
2418 }
2419 }
2420
2421 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2422 NumComplexTransformations += 1;
2423 Node->ReplacementNode = ReplacementNode;
2424 return ReplacementNode;
2425}
2426
2427void ComplexDeinterleavingGraph::processReductionSingle(
2428 Value *OperationReplacement, CompositeNode *Node) {
2429 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2430 auto *OldPHI = ReductionInfo[Real].first;
2431 auto *NewPHI = OldToNewPHI[OldPHI];
2432 auto *VTy = cast<VectorType>(Real->getType());
2433 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2434
2435 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2436
2437 IRBuilder<> Builder(Incoming->getTerminator());
2438
2439 Value *NewInit = nullptr;
2440 if (auto *C = dyn_cast<Constant>(Init)) {
2441 if (C->isZeroValue())
2442 NewInit = Constant::getNullValue(NewVTy);
2443 }
2444
2445 if (!NewInit)
2446 NewInit =
2448
2449 NewPHI->addIncoming(NewInit, Incoming);
2450 NewPHI->addIncoming(OperationReplacement, BackEdge);
2451
2452 auto *FinalReduction = ReductionInfo[Real].second;
2453 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2454
2455 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2456 FinalReduction->replaceAllUsesWith(AddReduce);
2457}
2458
2459void ComplexDeinterleavingGraph::processReductionOperation(
2460 Value *OperationReplacement, CompositeNode *Node) {
2461 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2462 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2463 auto *OldPHIReal = ReductionInfo[Real].first;
2464 auto *OldPHIImag = ReductionInfo[Imag].first;
2465 auto *NewPHI = OldToNewPHI[OldPHIReal];
2466
2467 // We have to interleave initial origin values coming from IncomingBlock
2468 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2469 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2470
2471 IRBuilder<> Builder(Incoming->getTerminator());
2472 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2473
2474 NewPHI->addIncoming(NewInit, Incoming);
2475 NewPHI->addIncoming(OperationReplacement, BackEdge);
2476
2477 // Deinterleave complex vector outside of loop so that it can be finally
2478 // reduced
2479 auto *FinalReductionReal = ReductionInfo[Real].second;
2480 auto *FinalReductionImag = ReductionInfo[Imag].second;
2481
2482 Builder.SetInsertPoint(
2483 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2484 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2485 OperationReplacement->getType(),
2486 OperationReplacement);
2487
2488 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2489 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2490
2491 Builder.SetInsertPoint(FinalReductionImag);
2492 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2493 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2494}
2495
2496void ComplexDeinterleavingGraph::replaceNodes() {
2497 SmallVector<Instruction *, 16> DeadInstrRoots;
2498 for (auto *RootInstruction : OrderedRoots) {
2499 // Check if this potential root went through check process and we can
2500 // deinterleave it
2501 if (!RootToNode.count(RootInstruction))
2502 continue;
2503
2504 IRBuilder<> Builder(RootInstruction);
2505 auto RootNode = RootToNode[RootInstruction];
2506 Value *R = replaceNode(Builder, RootNode);
2507
2508 if (RootNode->Operation ==
2509 ComplexDeinterleavingOperation::ReductionOperation) {
2510 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2511 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2512 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2513 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2514 DeadInstrRoots.push_back(RootReal);
2515 DeadInstrRoots.push_back(RootImag);
2516 } else if (RootNode->Operation ==
2517 ComplexDeinterleavingOperation::ReductionSingle) {
2518 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2519 auto &Info = ReductionInfo[RootInst];
2520 Info.first->removeIncomingValue(BackEdge);
2521 DeadInstrRoots.push_back(Info.second);
2522 } else {
2523 assert(R && "Unable to find replacement for RootInstruction");
2524 DeadInstrRoots.push_back(RootInstruction);
2525 RootInstruction->replaceAllUsesWith(R);
2526 }
2527 }
2528
2529 for (auto *I : DeadInstrRoots)
2531}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
Complex Deinterleaving
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
#define DEBUG_TYPE
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
std::optional< std::vector< StOtherPiece > > Other
Definition: ELFYAML.cpp:1328
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:39
Basic Register Allocator
SI optimize exec mask operations pre RA
raw_pwrite_stream & OS
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
#define LLVM_DEBUG(...)
Definition: Debug.h:119
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
Value * RHS
Value * LHS
BinaryOperator * Mul
DEMANGLE_DUMP_METHOD void dump() const
A linked-list with a custom, local allocator.
Definition: AllocatorList.h:33
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:255
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:270
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:147
LLVM Basic Block Representation.
Definition: BasicBlock.h:62
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
Definition: BasicBlock.cpp:337
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:213
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:177
iterator end()
Definition: DenseMap.h:87
This instruction extracts a struct member or array element value from an aggregate value.
bool allowContract() const
Definition: FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition: IRBuilder.h:2618
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Definition: IRBuilder.cpp:1005
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
Definition: IRBuilder.cpp:366
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition: IRBuilder.cpp:834
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:207
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
Definition: IRBuilder.cpp:1135
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2780
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:585
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:82
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:312
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
This class implements a map that also provides access to all stored values in a deterministic order.
Definition: MapVector.h:36
size_type size() const
Definition: MapVector.h:56
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:112
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:85
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition: Analysis.h:132
This instruction constructs a fixed permutation of two input vectors.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:470
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:401
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:541
bool empty() const
Definition: SmallVector.h:82
size_t size() const
Definition: SmallVector.h:79
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:574
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:938
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
A BumpPtrAllocator that allows only elements of a specific type to be allocated.
Definition: Allocator.h:390
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:55
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:83
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:273
Value * getOperand(unsigned i) const
Definition: User.h:232
LLVM Value Representation.
Definition: Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:546
An opaque object representing a hash code.
Definition: Hashing.h:76
const ParentTy * getParent() const
Definition: ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:359
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:126
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
@ BR
Control flow instructions. These all have token chains.
Definition: ISDOpcodes.h:1157
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition: PatternMatch.h:100
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
bind_ty< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
Definition: PatternMatch.h:862
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:444
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition: DWP.cpp:477
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1744
hash_code hash_value(const FixedPointSemantics &Val)
Definition: APFixedPoint.h:137
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:533
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:207
LLVM_ABI void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &)
DWARFExpression::Operation Op
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1777
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1916
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition: STLExtras.h:2127
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition: Hashing.h:595
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:858
#define N
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...
Definition: DenseMapInfo.h:54
Incoming for lane maks phi as machine instruction, incoming register Reg and incoming block Block are...