1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
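// For example, in a loop such as
//
//   for (int i = 0; i < n; ++i)
//     a[i] = 0;
//
// the induction variable i is represented by the recurrence {0,+,1}<%loop>
// (start 0, step 1), the address of a[i] by something like {@a,+,4}<%loop>
// for a 4-byte element type, and the loop's backedge-taken count is then an
// ordinary SCEV expression in %n.
//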
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
160static cl::opt<bool, true> VerifySCEVOpt(
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163static cl::opt<bool> VerifySCEVStrict(
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification when -verify-scev is passed"));
166
167static cl::opt<bool> VerifyIR(
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
172static cl::opt<unsigned> MulOpsInlineThreshold(
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
177static cl::opt<unsigned> AddOpsInlineThreshold(
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
182static cl::opt<unsigned> MaxSCEVCompareDepth(
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
187static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
192static cl::opt<unsigned> MaxValueCompareDepth(
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
202static cl::opt<unsigned> MaxConstantEvolvingDepth(
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
221static cl::opt<unsigned> RangeIterThreshold(
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
226static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
235static cl::opt<bool> UseExpensiveRangeSharpening(
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
241static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum number of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
252static cl::opt<bool> UseContextForNoWrapFlagInference(
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
319 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
345 case scSequentialUMinExpr:
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
376 case scCouldNotCompute:
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403 case scSequentialUMinExpr:
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435 case scSequentialUMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
463SCEVCouldNotCompute::SCEVCouldNotCompute() :
464 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
486ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
487 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
491const SCEV *ScalarEvolution::getVScale(Type *Ty) {
492 FoldingSetNodeID ID;
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
502const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
503 SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
511SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
512 const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
522SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
523 SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
529 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
536 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
543 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
548void SCEVUnknown::deleted() {
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
610 GlobalValue::isInternalLinkage(LT));
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
735 case scSequentialUMinExpr: {
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
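///
/// For instance, given operands of the shapes (%a, 3, %b, %a), sorting by
/// complexity moves the constant first and groups the repeated %a's next to
/// each other, e.g. (3, %a, %a, %b), so that callers such as getAddExpr can
/// fold duplicates by scanning adjacent elements.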
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768 LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
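///
/// For instance, for an add expression with operands (5, %x, -5) and APInt
/// addition as \p Fold, the two constants fold to 0; 0 satisfies \p IsIdentity
/// for addition and is dropped, leaving %x as the single remaining operand.
/// For a multiply, a folded constant of 0 would satisfy \p IsAbsorber and be
/// returned directly.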
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume, K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906 // but this formula uses less than W + K bits. Also, the first formula requires
907 // a division step, whereas this formula only requires multiplies and shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
922
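 // As a concrete illustration, for K = 3 and W = 32: K! = 6 = 2^1 * 3, so
 // T = 1 and OddFactorial = 3. The product It * (It - 1) * (It - 2) is formed
 // at W + T = 33 bits, divided exactly by 2^T = 2, truncated back to 32 bits,
 // and multiplied by the multiplicative inverse of 3 modulo 2^32 (0xAAAAAAAB),
 // which performs the remaining exact division by 3. The result is
 // BC(It, 3) = It*(It-1)*(It-2)/6 modulo 2^32.
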
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
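///
/// For example, {5,+,3,+,2} evaluated at iteration It is
/// 5*BC(It,0) + 3*BC(It,1) + 2*BC(It,2) = 5 + 3*It + It*(It-1),
/// a quadratic in the iteration number.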
980const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
981 ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987 const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008 unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operand during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1018 FoldingSetNodeID ID;
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1039 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
1070
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
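 ///
 /// For instance, a pointer-typed expression of the shape ((4 * %i) + %base)
 /// is rewritten to ((4 * %i) + (ptrtoint %base)): the only remaining
 /// pointer-typed node is the SCEVUnknown %base, and the cast is applied
 /// directly to it.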
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1077 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1097 SmallVector<const SCEV *, 2> Operands;
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1107 SmallVector<const SCEV *, 2> Operands;
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1131const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142 unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1150 FoldingSetNodeID ID;
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1163 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1167 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1171 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
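 //
 // For example, (trunc i64 (%a + (zext i32 %b to i64)) to i32) can become
 // ((trunc i64 %a to i32) + %b); only one operand of the result is a new
 // truncate, so the transformation is considered profitable.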
1186 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1188 SmallVector<const SCEV *, 4> Operands;
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1194 isa<SCEVTruncateExpr>(S))
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked in the beginning that ID is not in the cache, it is
1206 // possible that during recursion and other modifications the ID was inserted
1207 // into the cache. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 SmallVector<const SCEV *, 4> Operands;
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
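//
// For example, for an i8 recurrence with a constant step of 1, the limit is
// SIGNED_MAX - 1 = 126 with predicate SLT: whenever the current value is
// slt 126, adding the step cannot exceed 127 and so cannot sign-overflow.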
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1244 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1245 SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1249 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1250 SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259 ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1264 return SE->getConstant(APInt::getMinValue(BitWidth) -
1265 SE->getUnsignedRangeMax(Step));
1266}
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1302 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
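//
// For example, if both {%n,+,1} and {1 + %n,+,1} are known to be <nuw>, the
// extended recurrence {zext(1 + %n),+,1} can be normalized to
// {1 + zext(%n),+,1}; PreStart below is the candidate start (%n) obtained by
// peeling one copy of Step off the original start value.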
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note, that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1363 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1365 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1366 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
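//
// Concretely, with Start = 1, Step = 4 and Delta = 1, PreStart is 0; if the
// already-cached recurrence {0,+,4} carries <nuw> (condition (2)) and is known
// to be ult -1, so that adding Delta cannot wrap (condition (1)), then
// {1,+,4} == {0,+,4} + 1 cannot wrap either.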
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
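//
// For example, if C = 0x1003 and (x + y + ...) has at least 4 trailing zero
// bits, D is chosen as 0x3 (the low 4 bits of C): C - D = 0x1000 keeps those
// trailing zeros, and adding D back only fills in the low 4 bits without
// carrying into higher bits, so the top-level addition cannot wrap.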
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502 const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1522static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1523 const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1533static void insertFoldCacheEntry(
1534 const ScalarEvolution::FoldID &ID, const SCEV *S,
1535 DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1536 DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1537 &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1569 if (!isa<SCEVZeroExtendExpr>(S))
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575 unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1591 FoldingSetNodeID ID;
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1606 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1610 ConstantRange CR = getUnsignedRange(X);
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1622 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1633 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the latter case, the analysis code will
1644 // cope with a conservative value, and it will take care to purge
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
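// Illustration (editorial assumption): for an i8 addrec {0,+,1} with a max
// backedge-taken count of 100, the sum 0 + 1*100 is recomputed in the
// doubled-width type i16; if zext'ing the i8 result gives the same i16 value,
// the narrow computation cannot have wrapped, which justifies the NUW flag.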
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but it is
1719 // one of two possible causes for a change which was reverted. Be
1720 // conservative for the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1730 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731 getSignedRangeMin(Step));
1732 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1733 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
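// Illustrative instance (editorial, not from the source): with C = 5 and
// Step = 4, D = 1, so zext({5,+,4}) becomes zext(1) + zext({4,+,4}); the
// residual recurrence now shares the step's trailing zero bits, so adding D
// back cannot carry, which is what lets the outer add be tagged <nuw><nsw>.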
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (matchURem(Op, LHS, RHS))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not unsign overflow then we can, by definition,
1791 // commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Often address arithmetics contain expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not unsign overflow then we can, by definition,
1824 // commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
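// Illustrative instantiation (editorial): with K = 2, N = 8, M = 32,
// zext(4 * (trunc X to i8)) to i32 becomes
// 4 * (zext(trunc X to i6) to i32)<nuw>, since shifting a 6-bit value left by
// 2 cannot overflow 8 bits.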
1843 if (SM->getNumOperands() == 2)
1844 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845 if (MulLHS->getAPInt().isPowerOf2())
1846 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848 MulLHS->getAPInt().logBase2();
1849 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850 return getMulExpr(
1851 getZeroExtendExpr(MulLHS, Ty),
1853 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(MinMax))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
1888
1889const SCEV *
1890 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1903 if (!isa<SCEVSignExtendExpr>(S))
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not sign overflow then we can, by definition,
1962 // commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 // in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly casted to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but it is
2090 // one of two possible causes for a change which was reverted. Be
2091 // conservative for the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext, build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(MinMax))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
2150
2151 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
2169 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217 /// the given scale, and update the given map. This is a helper function for
2218 /// getAddExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228 /// This will allow getAddExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236 /// may be exposed. This helps getAddExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239static bool
2240 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
2302
2303 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2323 Signed ? &ScalarEvolution::getSignExtendExpr
2324 : &ScalarEvolution::getZeroExtendExpr;
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2342 // TODO: Lift this limitation.
2343 if (!RHSC)
2344 return false;
2345 APInt C = RHSC->getAPInt();
2346 unsigned NumBits = C.getBitWidth();
2347 if (BinOp == Instruction::Mul) {
2348 // Multiplying by 0 or 1 never overflows
2349 if (C.isZero() || C.isOne())
2350 return true;
2351 if (Signed)
2352 return false;
2353 APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
2354 // To avoid overflow, we need to make sure that LHS <= MAX / C.
2355 return isKnownPredicateAt(ICmpInst::ICMP_ULE, LHS, getConstant(Limit),
2356 CtxI);
2357 }
2358 bool IsSub = (BinOp == Instruction::Sub);
2359 bool IsNegativeConst = (Signed && C.isNegative());
2360 // Compute the direction and magnitude by which we need to check overflow.
2361 bool OverflowDown = IsSub ^ IsNegativeConst;
2362 APInt Magnitude = C;
2363 if (IsNegativeConst) {
2364 if (C == APInt::getSignedMinValue(NumBits))
2365 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2366 // want to deal with that.
2367 return false;
2368 Magnitude = -C;
2369 }
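// Worked example (editorial assumption): for a signed Sub with C = -3, both
// IsSub and IsNegativeConst are true, so OverflowDown is false and Magnitude
// is 3; the check below then proves LHS <= SIGNED_MAX - 3, since
// LHS - (-3) can only overflow upward.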
2370
2371 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2372 if (OverflowDown) {
2373 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2374 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2375 : APInt::getMinValue(NumBits);
2376 APInt Limit = Min + Magnitude;
2377 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2378 } else {
2379 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2380 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2381 : APInt::getMaxValue(NumBits);
2382 APInt Limit = Max - Magnitude;
2383 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2384 }
2385}
2386
2387std::optional<SCEV::NoWrapFlags>
2388 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2389 const OverflowingBinaryOperator *OBO) {
2390 // It cannot be done any better.
2391 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2392 return std::nullopt;
2393
2394 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2395
2396 if (OBO->hasNoUnsignedWrap())
2397 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2398 if (OBO->hasNoSignedWrap())
2399 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2400
2401 bool Deduced = false;
2402
2403 if (OBO->getOpcode() != Instruction::Add &&
2404 OBO->getOpcode() != Instruction::Sub &&
2405 OBO->getOpcode() != Instruction::Mul)
2406 return std::nullopt;
2407
2408 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2409 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2410
2411 const Instruction *CtxI =
2412 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2413 if (!OBO->hasNoUnsignedWrap() &&
2414 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2415 /* Signed */ false, LHS, RHS, CtxI)) {
2416 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2417 Deduced = true;
2418 }
2419
2420 if (!OBO->hasNoSignedWrap() &&
2421 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2422 /* Signed */ true, LHS, RHS, CtxI)) {
2423 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2424 Deduced = true;
2425 }
2426
2427 if (Deduced)
2428 return Flags;
2429 return std::nullopt;
2430}
2431
2432// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2433// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2434// can't-overflow flags for the operation if possible.
2435static SCEV::NoWrapFlags
2436 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2437 const ArrayRef<const SCEV *> Ops,
2438 SCEV::NoWrapFlags Flags) {
2439 using namespace std::placeholders;
2440
2441 using OBO = OverflowingBinaryOperator;
2442
2443 bool CanAnalyze =
2444 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2445 (void)CanAnalyze;
2446 assert(CanAnalyze && "don't call from other places!");
2447
2448 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2449 SCEV::NoWrapFlags SignOrUnsignWrap =
2450 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2451
2452 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2453 auto IsKnownNonNegative = [&](const SCEV *S) {
2454 return SE->isKnownNonNegative(S);
2455 };
2456
2457 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2458 Flags =
2459 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2460
2461 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2462
2463 if (SignOrUnsignWrap != SignOrUnsignMask &&
2464 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2465 isa<SCEVConstant>(Ops[0])) {
2466
2467 auto Opcode = [&] {
2468 switch (Type) {
2469 case scAddExpr:
2470 return Instruction::Add;
2471 case scMulExpr:
2472 return Instruction::Mul;
2473 default:
2474 llvm_unreachable("Unexpected SCEV op.");
2475 }
2476 }();
2477
2478 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
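// Hypothetical example (editorial): for an i8 add with C = 1,
// makeGuaranteedNoWrapRegion returns [-128, 127), so the <nsw> flag is set
// whenever the signed range of the other operand is known to exclude 127.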
2481 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2482 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoSignedWrap);
2484 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2486 }
2487
2488 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2489 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2490 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2491 Opcode, C, OBO::NoUnsignedWrap);
2492 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2493 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2494 }
2495 }
2496
2497 // <0,+,nonnegative><nw> is also nuw
2498 // TODO: Add corresponding nsw case
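// Editorial rationale (informal): the recurrence starts at 0 and steps by a
// non-negative amount, so an unsigned wrap would have to carry it past
// UINT_MAX and back across its own start value, contradicting the <nw>
// (no-self-wrap) flag already established.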
2499 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2500 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2501 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2502 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2503
2504 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
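// Editorial note: (X udiv Y) * Y rounds X down to a multiple of Y, so the
// product is always <= X and therefore cannot wrap the unsigned range.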
2505 if (!ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Type == scMulExpr &&
2506 Ops.size() == 2) {
2507 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2508 if (UDiv->getOperand(1) == Ops[1])
2509 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2510 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2511 if (UDiv->getOperand(1) == Ops[0])
2512 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2513 }
2514
2515 return Flags;
2516}
2517
2518 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2519 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2520}
2521
2522/// Get a canonical add expression, or something simpler if possible.
2523 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2524 SCEV::NoWrapFlags OrigFlags,
2525 unsigned Depth) {
2526 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2527 "only nuw or nsw allowed");
2528 assert(!Ops.empty() && "Cannot get empty add!");
2529 if (Ops.size() == 1) return Ops[0];
2530#ifndef NDEBUG
2531 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2532 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2533 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2534 "SCEVAddExpr operand types don't match!");
2535 unsigned NumPtrs = count_if(
2536 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2537 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2538#endif
2539
2540 const SCEV *Folded = constantFoldAndGroupOps(
2541 *this, LI, DT, Ops,
2542 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2543 [](const APInt &C) { return C.isZero(); }, // identity
2544 [](const APInt &C) { return false; }); // absorber
2545 if (Folded)
2546 return Folded;
2547
2548 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2549
2550 // Delay expensive flag strengthening until necessary.
2551 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2552 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2553 };
2554
2555 // Limit recursion calls depth.
2556 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2557 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2558
2559 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2560 // Don't strengthen flags if we have no new information.
2561 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2562 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2563 Add->setNoWrapFlags(ComputeFlags(Ops));
2564 return S;
2565 }
2566
2567 // Okay, check to see if the same value occurs in the operand list more than
2568 // once. If so, merge them together into a multiply expression. Since we
2569 // sorted the list, these values are required to be adjacent.
2570 Type *Ty = Ops[0]->getType();
2571 bool FoundMatch = false;
2572 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2573 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2574 // Scan ahead to count how many equal operands there are.
2575 unsigned Count = 2;
2576 while (i+Count != e && Ops[i+Count] == Ops[i])
2577 ++Count;
2578 // Merge the values into a multiply.
2579 const SCEV *Scale = getConstant(Ty, Count);
2580 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2581 if (Ops.size() == Count)
2582 return Mul;
2583 Ops[i] = Mul;
2584 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2585 --i; e -= Count - 1;
2586 FoundMatch = true;
2587 }
2588 if (FoundMatch)
2589 return getAddExpr(Ops, OrigFlags, Depth + 1);
2590
2591 // Check for truncates. If all the operands are truncated from the same
2592 // type, see if factoring out the truncate would permit the result to be
2593 // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2594 // if the contents of the resulting outer trunc fold to something simple.
2595 auto FindTruncSrcType = [&]() -> Type * {
2596 // We're ultimately looking to fold an addrec of truncs and muls of only
2597 // constants and truncs, so if we find any other types of SCEV
2598 // as operands of the addrec then we bail and return nullptr here.
2599 // Otherwise, we return the type of the operand of a trunc that we find.
2600 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2601 return T->getOperand()->getType();
2602 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2603 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2604 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2605 return T->getOperand()->getType();
2606 }
2607 return nullptr;
2608 };
2609 if (auto *SrcType = FindTruncSrcType()) {
2610 SmallVector<const SCEV *, 8> LargeOps;
2611 bool Ok = true;
2612 // Check all the operands to see if they can be represented in the
2613 // source type of the truncate.
2614 for (const SCEV *Op : Ops) {
2615 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2616 if (T->getOperand()->getType() != SrcType) {
2617 Ok = false;
2618 break;
2619 }
2620 LargeOps.push_back(T->getOperand());
2621 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2622 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2623 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2624 SmallVector<const SCEV *, 8> LargeMulOps;
2625 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2626 if (const SCEVTruncateExpr *T =
2627 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2628 if (T->getOperand()->getType() != SrcType) {
2629 Ok = false;
2630 break;
2631 }
2632 LargeMulOps.push_back(T->getOperand());
2633 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2634 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2635 } else {
2636 Ok = false;
2637 break;
2638 }
2639 }
2640 if (Ok)
2641 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2642 } else {
2643 Ok = false;
2644 break;
2645 }
2646 }
2647 if (Ok) {
2648 // Evaluate the expression in the larger type.
2649 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2650 // If it folds to something simple, use it. Otherwise, don't.
2651 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2652 return getTruncateExpr(Fold, Ty);
2653 }
2654 }
2655
2656 if (Ops.size() == 2) {
2657 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2658 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2659 // C1).
2660 const SCEV *A = Ops[0];
2661 const SCEV *B = Ops[1];
2662 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2663 auto *C = dyn_cast<SCEVConstant>(A);
2664 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2665 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2666 auto C2 = C->getAPInt();
2667 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2668
2669 APInt ConstAdd = C1 + C2;
2670 auto AddFlags = AddExpr->getNoWrapFlags();
2671 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2673 ConstAdd.ule(C1)) {
2674 PreservedFlags =
2675 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2676 }
2677
2678 // Adding a constant with the same sign and small magnitude is NSW, if the
2679 // original AddExpr was NSW.
2681 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2682 ConstAdd.abs().ule(C1.abs())) {
2683 PreservedFlags =
2685 }
2686
2687 if (PreservedFlags != SCEV::FlagAnyWrap) {
2688 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2689 NewOps[0] = getConstant(ConstAdd);
2690 return getAddExpr(NewOps, PreservedFlags);
2691 }
2692 }
2693
2694 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2695 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2696 const SCEVAddExpr *InnerAdd;
2697 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2698 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2699 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2700 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2701 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2703 SCEV::FlagNUW)) {
2704 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2705 }
2706 }
2707 }
2708
2709 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
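// Editorial note: X - (X urem Y) is exactly the largest multiple of Y that is
// <= X, i.e. (X udiv Y) * Y, which is the canonical form produced below.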
2710 if (Ops.size() == 2) {
2711 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2712 if (Mul && Mul->getNumOperands() == 2 &&
2713 Mul->getOperand(0)->isAllOnesValue()) {
2714 const SCEV *X;
2715 const SCEV *Y;
2716 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2717 return getMulExpr(Y, getUDivExpr(X, Y));
2718 }
2719 }
2720 }
2721
2722 // Skip past any other cast SCEVs.
2723 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2724 ++Idx;
2725
2726 // If there are add operands they would be next.
2727 if (Idx < Ops.size()) {
2728 bool DeletedAdd = false;
2729 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2730 // common NUW flag for expression after inlining. Other flags cannot be
2731 // preserved, because they may depend on the original order of operations.
2732 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2733 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2734 if (Ops.size() > AddOpsInlineThreshold ||
2735 Add->getNumOperands() > AddOpsInlineThreshold)
2736 break;
2737 // If we have an add, expand the add operands onto the end of the operands
2738 // list.
2739 Ops.erase(Ops.begin()+Idx);
2740 append_range(Ops, Add->operands());
2741 DeletedAdd = true;
2742 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2743 }
2744
2745 // If we deleted at least one add, we added operands to the end of the list,
2746 // and they are not necessarily sorted. Recurse to resort and resimplify
2747 // any operands we just acquired.
2748 if (DeletedAdd)
2749 return getAddExpr(Ops, CommonFlags, Depth + 1);
2750 }
2751
2752 // Skip over the add expression until we get to a multiply.
2753 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2754 ++Idx;
2755
2756 // Check to see if there are any folding opportunities present with
2757 // operands multiplied by constant values.
2758 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2759 uint64_t BitWidth = getTypeSizeInBits(Ty);
2760 DenseMap<const SCEV *, APInt> M;
2761 SmallVector<const SCEV *, 8> NewOps;
2762 APInt AccumulatedConstant(BitWidth, 0);
2763 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2764 Ops, APInt(BitWidth, 1), *this)) {
2765 struct APIntCompare {
2766 bool operator()(const APInt &LHS, const APInt &RHS) const {
2767 return LHS.ult(RHS);
2768 }
2769 };
2770
2771 // Some interesting folding opportunity is present, so it's worthwhile to
2772 // re-generate the operands list. Group the operands by constant scale,
2773 // to avoid multiplying by the same constant scale multiple times.
2774 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2775 for (const SCEV *NewOp : NewOps)
2776 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2777 // Re-generate the operands list.
2778 Ops.clear();
2779 if (AccumulatedConstant != 0)
2780 Ops.push_back(getConstant(AccumulatedConstant));
2781 for (auto &MulOp : MulOpLists) {
2782 if (MulOp.first == 1) {
2783 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2784 } else if (MulOp.first != 0) {
2785 Ops.push_back(getMulExpr(
2786 getConstant(MulOp.first),
2787 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2788 SCEV::FlagAnyWrap, Depth + 1));
2789 }
2790 }
2791 if (Ops.empty())
2792 return getZero(Ty);
2793 if (Ops.size() == 1)
2794 return Ops[0];
2795 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2796 }
2797 }
2798
2799 // If we are adding something to a multiply expression, make sure the
2800 // something is not already an operand of the multiply. If so, merge it into
2801 // the multiply.
2802 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2803 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2804 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2805 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2806 if (isa<SCEVConstant>(MulOpSCEV))
2807 continue;
2808 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2809 if (MulOpSCEV == Ops[AddOp]) {
2810 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
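// Hypothetical instance (editorial): A + B + (B * 3 * C) becomes
// A + (B * (3*C + 1)), collapsing the repeated B into a single multiply.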
2811 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2812 if (Mul->getNumOperands() != 2) {
2813 // If the multiply has more than two operands, we must get the
2814 // Y*Z term.
2815 SmallVector<const SCEV *, 4> MulOps(
2816 Mul->operands().take_front(MulOp));
2817 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2818 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2819 }
2820 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2821 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2822 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2824 if (Ops.size() == 2) return OuterMul;
2825 if (AddOp < Idx) {
2826 Ops.erase(Ops.begin()+AddOp);
2827 Ops.erase(Ops.begin()+Idx-1);
2828 } else {
2829 Ops.erase(Ops.begin()+Idx);
2830 Ops.erase(Ops.begin()+AddOp-1);
2831 }
2832 Ops.push_back(OuterMul);
2833 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2834 }
2835
2836 // Check this multiply against other multiplies being added together.
2837 for (unsigned OtherMulIdx = Idx+1;
2838 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2839 ++OtherMulIdx) {
2840 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2841 // If MulOp occurs in OtherMul, we can fold the two multiplies
2842 // together.
2843 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2844 OMulOp != e; ++OMulOp)
2845 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2846 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2847 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2848 if (Mul->getNumOperands() != 2) {
2849 SmallVector<const SCEV *, 4> MulOps(
2850 Mul->operands().take_front(MulOp));
2851 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2852 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2853 }
2854 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2855 if (OtherMul->getNumOperands() != 2) {
2856 SmallVector<const SCEV *, 4> MulOps(
2857 OtherMul->operands().take_front(OMulOp));
2858 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2859 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2860 }
2861 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2862 const SCEV *InnerMulSum =
2863 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2864 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2866 if (Ops.size() == 2) return OuterMul;
2867 Ops.erase(Ops.begin()+Idx);
2868 Ops.erase(Ops.begin()+OtherMulIdx-1);
2869 Ops.push_back(OuterMul);
2870 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2871 }
2872 }
2873 }
2874 }
2875
2876 // If there are any add recurrences in the operands list, see if any other
2877 // added values are loop invariant. If so, we can fold them into the
2878 // recurrence.
2879 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2880 ++Idx;
2881
2882 // Scan over all recurrences, trying to fold loop invariants into them.
2883 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2884 // Scan all of the other operands to this add and add them to the vector if
2885 // they are loop invariant w.r.t. the recurrence.
2886 SmallVector<const SCEV *, 8> LIOps;
2887 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2888 const Loop *AddRecLoop = AddRec->getLoop();
2889 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2890 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2891 LIOps.push_back(Ops[i]);
2892 Ops.erase(Ops.begin()+i);
2893 --i; --e;
2894 }
2895
2896 // If we found some loop invariants, fold them into the recurrence.
2897 if (!LIOps.empty()) {
2898 // Compute nowrap flags for the addition of the loop-invariant ops and
2899 // the addrec. Temporarily push it as an operand for that purpose. These
2900 // flags are valid in the scope of the addrec only.
2901 LIOps.push_back(AddRec);
2902 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2903 LIOps.pop_back();
2904
2905 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2906 LIOps.push_back(AddRec->getStart());
2907
2908 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2909
2910 // It is not in general safe to propagate flags valid on an add within
2911 // the addrec scope to one outside it. We must prove that the inner
2912 // scope is guaranteed to execute if the outer one does to be able to
2913 // safely propagate. We know the program is undefined if poison is
2914 // produced on the inner scoped addrec. We also know that *for this use*
2915 // the outer scoped add can't overflow (because of the flags we just
2916 // computed for the inner scoped add) without the program being undefined.
2917 // Proving that entry to the outer scope necessitates entry to the inner
2918 // scope, thus proves the program undefined if the flags would be violated
2919 // in the outer scope.
2920 SCEV::NoWrapFlags AddFlags = Flags;
2921 if (AddFlags != SCEV::FlagAnyWrap) {
2922 auto *DefI = getDefiningScopeBound(LIOps);
2923 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2924 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2925 AddFlags = SCEV::FlagAnyWrap;
2926 }
2927 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2928
2929 // Build the new addrec. Propagate the NUW and NSW flags if both the
2930 // outer add and the inner addrec are guaranteed to have no overflow.
2931 // Always propagate NW.
2932 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2933 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2934
2935 // If all of the other operands were loop invariant, we are done.
2936 if (Ops.size() == 1) return NewRec;
2937
2938 // Otherwise, add the folded AddRec by the non-invariant parts.
2939 for (unsigned i = 0;; ++i)
2940 if (Ops[i] == AddRec) {
2941 Ops[i] = NewRec;
2942 break;
2943 }
2944 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2945 }
2946
2947 // Okay, if there weren't any loop invariants to be folded, check to see if
2948 // there are multiple AddRec's with the same loop induction variable being
2949 // added together. If so, we can fold them.
2950 for (unsigned OtherIdx = Idx+1;
2951 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 ++OtherIdx) {
2953 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2954 // so that the 1st found AddRecExpr is dominated by all others.
2955 assert(DT.dominates(
2956 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2957 AddRec->getLoop()->getHeader()) &&
2958 "AddRecExprs are not sorted in reverse dominance order?");
2959 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2960 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2961 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2962 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2963 ++OtherIdx) {
2964 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2965 if (OtherAddRec->getLoop() == AddRecLoop) {
2966 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2967 i != e; ++i) {
2968 if (i >= AddRecOps.size()) {
2969 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2970 break;
2971 }
2972 SmallVector<const SCEV *, 2> TwoOps = {
2973 AddRecOps[i], OtherAddRec->getOperand(i)};
2974 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2975 }
2976 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2977 }
2978 }
2979 // Step size has changed, so we cannot guarantee no self-wraparound.
2980 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2981 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2982 }
2983 }
2984
2985 // Otherwise couldn't fold anything into this recurrence. Move onto the
2986 // next one.
2987 }
2988
2989 // Okay, it looks like we really DO need an add expr. Check to see if we
2990 // already have one, otherwise create a new one.
2991 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2992}
2993
2994const SCEV *
2995ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2996 SCEV::NoWrapFlags Flags) {
2997 FoldingSetNodeID ID;
2998 ID.AddInteger(scAddExpr);
2999 for (const SCEV *Op : Ops)
3000 ID.AddPointer(Op);
3001 void *IP = nullptr;
3002 SCEVAddExpr *S =
3003 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3004 if (!S) {
3005 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3007 S = new (SCEVAllocator)
3008 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3009 UniqueSCEVs.InsertNode(S, IP);
3010 registerUser(S, Ops);
3011 }
3012 S->setNoWrapFlags(Flags);
3013 return S;
3014}
3015
3016const SCEV *
3017ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3018 const Loop *L, SCEV::NoWrapFlags Flags) {
3019 FoldingSetNodeID ID;
3020 ID.AddInteger(scAddRecExpr);
3021 for (const SCEV *Op : Ops)
3022 ID.AddPointer(Op);
3023 ID.AddPointer(L);
3024 void *IP = nullptr;
3025 SCEVAddRecExpr *S =
3026 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3027 if (!S) {
3028 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3030 S = new (SCEVAllocator)
3031 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3032 UniqueSCEVs.InsertNode(S, IP);
3033 LoopUsers[L].push_back(S);
3034 registerUser(S, Ops);
3035 }
3036 setNoWrapFlags(S, Flags);
3037 return S;
3038}
3039
3040const SCEV *
3041ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3042 SCEV::NoWrapFlags Flags) {
3043 FoldingSetNodeID ID;
3044 ID.AddInteger(scMulExpr);
3045 for (const SCEV *Op : Ops)
3046 ID.AddPointer(Op);
3047 void *IP = nullptr;
3048 SCEVMulExpr *S =
3049 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3050 if (!S) {
3051 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3053 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3054 O, Ops.size());
3055 UniqueSCEVs.InsertNode(S, IP);
3056 registerUser(S, Ops);
3057 }
3058 S->setNoWrapFlags(Flags);
3059 return S;
3060}
3061
3062static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3063 uint64_t k = i*j;
3064 if (j > 1 && k / j != i) Overflow = true;
3065 return k;
3066}
3067
3068/// Compute the result of "n choose k", the binomial coefficient. If an
3069/// intermediate computation overflows, Overflow will be set and the return will
3070/// be garbage. Overflow is not cleared on absence of overflow.
3071static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3072 // We use the multiplicative formula:
3073 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3074 // At each iteration, we take the n-th term of the numerator and divide by the
3075 // (k-n)th term of the denominator. This division will always produce an
3076 // integral result, and helps reduce the chance of overflow in the
3077 // intermediate computations. However, we can still overflow even when the
3078 // final result would fit.
3079
3080 if (n == 0 || n == k) return 1;
3081 if (k > n) return 0;
3082
3083 if (k > n/2)
3084 k = n-k;
3085
3086 uint64_t r = 1;
3087 for (uint64_t i = 1; i <= k; ++i) {
3088 r = umul_ov(r, n-(i-1), Overflow);
3089 r /= i;
3090 }
3091 return r;
3092}
3093
3094/// Determine if any of the operands in this SCEV are a constant or if
3095/// any of the add or multiply expressions in this SCEV contain a constant.
3096static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3097 struct FindConstantInAddMulChain {
3098 bool FoundConstant = false;
3099
3100 bool follow(const SCEV *S) {
3101 FoundConstant |= isa<SCEVConstant>(S);
3102 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3103 }
3104
3105 bool isDone() const {
3106 return FoundConstant;
3107 }
3108 };
3109
3110 FindConstantInAddMulChain F;
3111 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3112 ST.visitAll(StartExpr);
3113 return F.FoundConstant;
3114}
3115
3116/// Get a canonical multiply expression, or something simpler if possible.
3117 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3118 SCEV::NoWrapFlags OrigFlags,
3119 unsigned Depth) {
3120 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3121 "only nuw or nsw allowed");
3122 assert(!Ops.empty() && "Cannot get empty mul!");
3123 if (Ops.size() == 1) return Ops[0];
3124#ifndef NDEBUG
3125 Type *ETy = Ops[0]->getType();
3126 assert(!ETy->isPointerTy());
3127 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3128 assert(Ops[i]->getType() == ETy &&
3129 "SCEVMulExpr operand types don't match!");
3130#endif
3131
3132 const SCEV *Folded = constantFoldAndGroupOps(
3133 *this, LI, DT, Ops,
3134 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3135 [](const APInt &C) { return C.isOne(); }, // identity
3136 [](const APInt &C) { return C.isZero(); }); // absorber
3137 if (Folded)
3138 return Folded;
3139
3140 // Delay expensive flag strengthening until necessary.
3141 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3142 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3143 };
3144
3145 // Limit recursion depth.
3146 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3147 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3148
3149 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3150 // Don't strengthen flags if we have no new information.
3151 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3152 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3153 Mul->setNoWrapFlags(ComputeFlags(Ops));
3154 return S;
3155 }
3156
3157 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3158 if (Ops.size() == 2) {
3159 // C1*(C2+V) -> C1*C2 + C1*V
3160 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3161 // If any of Add's ops are Adds or Muls with a constant, apply this
3162 // transformation as well.
3163 //
3164 // TODO: There are some cases where this transformation is not
3165 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3166 // this transformation should be narrowed down.
3167 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3168 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3169 SCEV::FlagAnyWrap, Depth + 1);
3170 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3171 SCEV::FlagAnyWrap, Depth + 1);
3172 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3173 }
3174
3175 if (Ops[0]->isAllOnesValue()) {
3176 // If we have a mul by -1 of an add, try distributing the -1 among the
3177 // add operands.
3178 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3179 SmallVector<const SCEV *, 4> NewOps;
3180 bool AnyFolded = false;
3181 for (const SCEV *AddOp : Add->operands()) {
3182 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3183 Depth + 1);
3184 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3185 NewOps.push_back(Mul);
3186 }
3187 if (AnyFolded)
3188 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3189 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3190 // Negation preserves a recurrence's no self-wrap property.
3191 SmallVector<const SCEV *, 4> Operands;
3192 for (const SCEV *AddRecOp : AddRec->operands())
3193 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3194 Depth + 1));
3195 // Let M be the minimum representable signed value. AddRec with nsw
3196 // multiplied by -1 can have signed overflow if and only if it takes a
3197 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3198 // maximum signed value. In all other cases signed overflow is
3199 // impossible.
3200 auto FlagsMask = SCEV::FlagNW;
3201 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3202 auto MinInt =
3203 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3204 if (getSignedRangeMin(AddRec) != MinInt)
3205 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3206 }
3207 return getAddRecExpr(Operands, AddRec->getLoop(),
3208 AddRec->getNoWrapFlags(FlagsMask));
3209 }
3210 }
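// Illustrative example (editorial addition, not in the original source):
// for an i8 recurrence {-128,+,1}<nsw>, negating the start -128 wraps back
// to -128, so FlagNSW cannot be transferred; only when the signed range of
// the AddRec is known to exclude -128 does the code above keep FlagNSW.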
3211
3212 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3213 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3214 const SCEVAddExpr *InnerAdd;
3215 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3216 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3217 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3218 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3219 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3220 SCEV::FlagAnyWrap),
3221 SCEV::FlagNUW)) {
3222 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3223 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3224 };
3225 }
3226 }
3227 }
3228
3229 // Skip over the add expression until we get to a multiply.
3230 unsigned Idx = 0;
3231 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3232 ++Idx;
3233
3234 // If there are mul operands inline them all into this expression.
3235 if (Idx < Ops.size()) {
3236 bool DeletedMul = false;
3237 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3238 if (Ops.size() > MulOpsInlineThreshold)
3239 break;
3240 // If we have a mul, expand the mul operands onto the end of the
3241 // operands list.
3242 Ops.erase(Ops.begin()+Idx);
3243 append_range(Ops, Mul->operands());
3244 DeletedMul = true;
3245 }
3246
3247 // If we deleted at least one mul, we added operands to the end of the
3248 // list, and they are not necessarily sorted. Recurse to resort and
3249 // resimplify any operands we just acquired.
3250 if (DeletedMul)
3251 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3252 }
3253
3254 // If there are any add recurrences in the operands list, see if any other
3255 // added values are loop invariant. If so, we can fold them into the
3256 // recurrence.
3257 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3258 ++Idx;
3259
3260 // Scan over all recurrences, trying to fold loop invariants into them.
3261 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3262 // Scan all of the other operands to this mul and add them to the vector
3263 // if they are loop invariant w.r.t. the recurrence.
3264 SmallVector<const SCEV *, 8> LIOps;
3265 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3266 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3267 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3268 LIOps.push_back(Ops[i]);
3269 Ops.erase(Ops.begin()+i);
3270 --i; --e;
3271 }
3272
3273 // If we found some loop invariants, fold them into the recurrence.
3274 if (!LIOps.empty()) {
3275 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3276 SmallVector<const SCEV *, 4> NewOps;
3277 NewOps.reserve(AddRec->getNumOperands());
3278 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3279
3280 // If both the mul and addrec are nuw, we can preserve nuw.
3281 // If both the mul and addrec are nsw, we can only preserve nsw if either
3282 // a) they are also nuw, or
3283 // b) all multiplications of addrec operands with scale are nsw.
3284 SCEV::NoWrapFlags Flags =
3285 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3286
3287 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3288 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3289 SCEV::FlagAnyWrap, Depth + 1));
3290
3291 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3292 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3293 Instruction::Mul, getSignedRange(Scale),
3294 OverflowingBinaryOperator::NoSignedWrap);
3295 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3296 Flags = clearFlags(Flags, SCEV::FlagNSW);
3297 }
3298 }
3299
3300 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3301
3302 // If all of the other operands were loop invariant, we are done.
3303 if (Ops.size() == 1) return NewRec;
3304
3305 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3306 for (unsigned i = 0;; ++i)
3307 if (Ops[i] == AddRec) {
3308 Ops[i] = NewRec;
3309 break;
3310 }
3311 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3312 }
3313
3314 // Okay, if there weren't any loop invariants to be folded, check to see
3315 // if there are multiple AddRec's with the same loop induction variable
3316 // being multiplied together. If so, we can fold them.
3317
3318 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3319 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3320 // choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
3321 // ]]],+,...up to x=2n}.
3322 // Note that the arguments to choose() are always integers with values
3323 // known at compile time, never SCEV objects.
3324 //
3325 // The implementation avoids pointless extra computations when the two
3326 // addrec's are of different length (mathematically, it's equivalent to
3327 // an infinite stream of zeros on the right).
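// Worked low-degree instance (editorial addition, not in the original
// source): for two affine recurrences the formula above reduces to
//   {A1,+,A2}<L> * {B1,+,B2}<L> = {A1*B1,+,A1*B2 + A2*B1 + A2*B2,+,2*A2*B2}<L>
// i.e. the product of two degree-1 chains is a degree-2 chain.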
3328 bool OpsModified = false;
3329 for (unsigned OtherIdx = Idx+1;
3330 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3331 ++OtherIdx) {
3332 const SCEVAddRecExpr *OtherAddRec =
3333 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3334 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3335 continue;
3336
3337 // Limit max number of arguments to avoid creation of unreasonably big
3338 // SCEVAddRecs with very complex operands.
3339 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3340 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3341 continue;
3342
3343 bool Overflow = false;
3344 Type *Ty = AddRec->getType();
3345 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3346 SmallVector<const SCEV*, 7> AddRecOps;
3347 for (int x = 0, xe = AddRec->getNumOperands() +
3348 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3349 SmallVector <const SCEV *, 7> SumOps;
3350 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3351 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3352 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3353 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3354 z < ze && !Overflow; ++z) {
3355 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3356 uint64_t Coeff;
3357 if (LargerThan64Bits)
3358 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3359 else
3360 Coeff = Coeff1*Coeff2;
3361 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3362 const SCEV *Term1 = AddRec->getOperand(y-z);
3363 const SCEV *Term2 = OtherAddRec->getOperand(z);
3364 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3365 SCEV::FlagAnyWrap, Depth + 1));
3366 }
3367 }
3368 if (SumOps.empty())
3369 SumOps.push_back(getZero(Ty));
3370 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3371 }
3372 if (!Overflow) {
3373 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3374 SCEV::FlagAnyWrap);
3375 if (Ops.size() == 2) return NewAddRec;
3376 Ops[Idx] = NewAddRec;
3377 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3378 OpsModified = true;
3379 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3380 if (!AddRec)
3381 break;
3382 }
3383 }
3384 if (OpsModified)
3385 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3386
3387 // Otherwise couldn't fold anything into this recurrence. Move onto the
3388 // next one.
3389 }
3390
3391 // Okay, it looks like we really DO need a mul expr. Check to see if we
3392 // already have one, otherwise create a new one.
3393 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3394}
3395
3396/// Represents an unsigned remainder expression based on unsigned division.
3397const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3398 const SCEV *RHS) {
3399 assert(getEffectiveSCEVType(LHS->getType()) ==
3400 getEffectiveSCEVType(RHS->getType()) &&
3401 "SCEVURemExpr operand types don't match!");
3402
3403 // Short-circuit easy cases
3404 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3405 // If constant is one, the result is trivial
3406 if (RHSC->getValue()->isOne())
3407 return getZero(LHS->getType()); // X urem 1 --> 0
3408
3409 // If constant is a power of two, fold into a zext(trunc(LHS)).
3410 if (RHSC->getAPInt().isPowerOf2()) {
3411 Type *FullTy = LHS->getType();
3412 Type *TruncTy =
3413 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3414 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3415 }
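// Illustrative example (editorial addition, not in the original source):
// for an i32 %x and RHS == 8, logBase2() is 3, so the fold above yields
// (zext i3 (trunc i32 %x to i3) to i32), i.e. the low three bits of %x.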
3416 }
3417
3418 // Fall back to the identity: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3419 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3420 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3421 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3422}
3423
3424/// Get a canonical unsigned division expression, or something simpler if
3425/// possible.
3426const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3427 const SCEV *RHS) {
3428 assert(!LHS->getType()->isPointerTy() &&
3429 "SCEVUDivExpr operand can't be pointer!");
3430 assert(LHS->getType() == RHS->getType() &&
3431 "SCEVUDivExpr operand types don't match!");
3432
3434 ID.AddInteger(scUDivExpr);
3435 ID.AddPointer(LHS);
3436 ID.AddPointer(RHS);
3437 void *IP = nullptr;
3438 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3439 return S;
3440
3441 // 0 udiv Y == 0
3442 if (match(LHS, m_scev_Zero()))
3443 return LHS;
3444
3445 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3446 if (RHSC->getValue()->isOne())
3447 return LHS; // X udiv 1 --> x
3448 // If the denominator is zero, the result of the udiv is undefined. Don't
3449 // try to analyze it, because the resolution chosen here may differ from
3450 // the resolution chosen in other parts of the compiler.
3451 if (!RHSC->getValue()->isZero()) {
3452 // Determine if the division can be folded into the operands of
3453 // the LHS.
3454 // TODO: Generalize this to non-constants by using known-bits information.
3455 Type *Ty = LHS->getType();
3456 unsigned LZ = RHSC->getAPInt().countl_zero();
3457 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3458 // For non-power-of-two values, effectively round the value up to the
3459 // nearest power of two.
3460 if (!RHSC->getAPInt().isPowerOf2())
3461 ++MaxShiftAmt;
3462 IntegerType *ExtTy =
3463 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3464 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3465 if (const SCEVConstant *Step =
3466 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3467 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3468 const APInt &StepInt = Step->getAPInt();
3469 const APInt &DivInt = RHSC->getAPInt();
3470 if (!StepInt.urem(DivInt) &&
3471 getZeroExtendExpr(AR, ExtTy) ==
3472 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3473 getZeroExtendExpr(Step, ExtTy),
3474 AR->getLoop(), SCEV::FlagAnyWrap)) {
3475 SmallVector<const SCEV *, 4> Operands;
3476 for (const SCEV *Op : AR->operands())
3477 Operands.push_back(getUDivExpr(Op, RHS));
3478 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3479 }
3480 /// Get a canonical UDivExpr for a recurrence.
3481 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3482 // We can currently only fold X%N if X is constant.
3483 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3484 if (StartC && !DivInt.urem(StepInt) &&
3485 getZeroExtendExpr(AR, ExtTy) ==
3486 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3487 getZeroExtendExpr(Step, ExtTy),
3488 AR->getLoop(), SCEV::FlagAnyWrap)) {
3489 const APInt &StartInt = StartC->getAPInt();
3490 const APInt &StartRem = StartInt.urem(StepInt);
3491 if (StartRem != 0) {
3492 const SCEV *NewLHS =
3493 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3494 AR->getLoop(), SCEV::FlagNW);
3495 if (LHS != NewLHS) {
3496 LHS = NewLHS;
3497
3498 // Reset the ID to include the new LHS, and check if it is
3499 // already cached.
3500 ID.clear();
3501 ID.AddInteger(scUDivExpr);
3502 ID.AddPointer(LHS);
3503 ID.AddPointer(RHS);
3504 IP = nullptr;
3505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3506 return S;
3507 }
3508 }
3509 }
3510 }
3511 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3512 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3513 SmallVector<const SCEV *, 4> Operands;
3514 for (const SCEV *Op : M->operands())
3515 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3516 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3517 // Find an operand that's safely divisible.
3518 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3519 const SCEV *Op = M->getOperand(i);
3520 const SCEV *Div = getUDivExpr(Op, RHSC);
3521 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3522 Operands = SmallVector<const SCEV *, 4>(M->operands());
3523 Operands[i] = Div;
3524 return getMulExpr(Operands);
3525 }
3526 }
3527 }
3528
3529 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3530 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3531 if (auto *DivisorConstant =
3532 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3533 bool Overflow = false;
3534 APInt NewRHS =
3535 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3536 if (Overflow) {
3537 return getConstant(RHSC->getType(), 0, false);
3538 }
3539 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3540 }
3541 }
3542
3543 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3544 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3545 SmallVector<const SCEV *, 4> Operands;
3546 for (const SCEV *Op : A->operands())
3547 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3548 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3549 Operands.clear();
3550 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3551 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3552 if (isa<SCEVUDivExpr>(Op) ||
3553 getMulExpr(Op, RHS) != A->getOperand(i))
3554 break;
3555 Operands.push_back(Op);
3556 }
3557 if (Operands.size() == A->getNumOperands())
3558 return getAddExpr(Operands);
3559 }
3560 }
3561
3562 // Fold if both operands are constant.
3563 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3564 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3565 }
3566 }
3567
3568 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3569 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3570 AE && AE->getNumOperands() == 2) {
3571 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3572 const APInt &NegC = VC->getAPInt();
3573 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3574 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3575 if (MME && MME->getNumOperands() == 2 &&
3576 isa<SCEVConstant>(MME->getOperand(0)) &&
3577 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3578 MME->getOperand(1) == RHS)
3579 return getZero(LHS->getType());
3580 }
3581 }
3582 }
3583
3584 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3585 // changes). Make sure we get a new one.
3586 IP = nullptr;
3587 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3588 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3589 LHS, RHS);
3590 UniqueSCEVs.InsertNode(S, IP);
3591 registerUser(S, {LHS, RHS});
3592 return S;
3593}
3594
3595APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3596 APInt A = C1->getAPInt().abs();
3597 APInt B = C2->getAPInt().abs();
3598 uint32_t ABW = A.getBitWidth();
3599 uint32_t BBW = B.getBitWidth();
3600
3601 if (ABW > BBW)
3602 B = B.zext(ABW);
3603 else if (ABW < BBW)
3604 A = A.zext(BBW);
3605
3606 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3607}
3608
3609/// Get a canonical unsigned division expression, or something simpler if
3610/// possible. There is no representation for an exact udiv in SCEV IR, but we
3611/// can attempt to remove factors from the LHS and RHS. We can't do this when
3612/// it's not exact because the udiv may be clearing bits.
3613const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3614 const SCEV *RHS) {
3615 // TODO: we could try to find factors in all sorts of things, but for now we
3616 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3617 // end of this file for inspiration.
3618
3619 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3620 if (!Mul || !Mul->hasNoUnsignedWrap())
3621 return getUDivExpr(LHS, RHS);
3622
3623 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3624 // If the mulexpr multiplies by a constant, then that constant must be the
3625 // first element of the mulexpr.
3626 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3627 if (LHSCst == RHSCst) {
3628 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3629 return getMulExpr(Operands);
3630 }
3631
3632 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3633 // that there's a factor provided by one of the other terms. We need to
3634 // check.
3635 APInt Factor = gcd(LHSCst, RHSCst);
3636 if (!Factor.isIntN(1)) {
3637 LHSCst =
3638 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3639 RHSCst =
3640 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3641 SmallVector<const SCEV *, 2> Operands;
3642 Operands.push_back(LHSCst);
3643 append_range(Operands, Mul->operands().drop_front());
3644 LHS = getMulExpr(Operands);
3645 RHS = RHSCst;
3646 Mul = dyn_cast<SCEVMulExpr>(LHS);
3647 if (!Mul)
3648 return getUDivExactExpr(LHS, RHS);
3649 }
3650 }
3651 }
3652
3653 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3654 if (Mul->getOperand(i) == RHS) {
3655 SmallVector<const SCEV *, 2> Operands;
3656 append_range(Operands, Mul->operands().take_front(i));
3657 append_range(Operands, Mul->operands().drop_front(i + 1));
3658 return getMulExpr(Operands);
3659 }
3660 }
3661
3662 return getUDivExpr(LHS, RHS);
3663}
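// Illustrative examples (editorial addition, not in the original source):
// for a no-unsigned-wrap multiply, (4 * %x)<nuw> /u 4 folds to %x by dropping
// the matching constant, and (6 * %x)<nuw> /u 3 folds to (2 * %x) via the gcd
// factoring above.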
3664
3665/// Get an add recurrence expression for the specified loop. Simplify the
3666/// expression as much as possible.
3667const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3668 const Loop *L,
3669 SCEV::NoWrapFlags Flags) {
3670 SmallVector<const SCEV *, 4> Operands;
3671 Operands.push_back(Start);
3672 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3673 if (StepChrec->getLoop() == L) {
3674 append_range(Operands, StepChrec->operands());
3675 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3676 }
3677
3678 Operands.push_back(Step);
3679 return getAddRecExpr(Operands, L, Flags);
3680}
3681
3682/// Get an add recurrence expression for the specified loop. Simplify the
3683/// expression as much as possible.
3684const SCEV *
3685ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3686 const Loop *L, SCEV::NoWrapFlags Flags) {
3687 if (Operands.size() == 1) return Operands[0];
3688#ifndef NDEBUG
3689 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3690 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3691 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3692 "SCEVAddRecExpr operand types don't match!");
3693 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3694 }
3695 for (const SCEV *Op : Operands)
3696 assert(isAvailableAtLoopEntry(Op, L) &&
3697 "SCEVAddRecExpr operand is not available at loop entry!");
3698#endif
3699
3700 if (Operands.back()->isZero()) {
3701 Operands.pop_back();
3702 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3703 }
3704
3705 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3706 // use that information to infer NUW and NSW flags. However, computing a
3707 // BE count requires calling getAddRecExpr, so we may not yet have a
3708 // meaningful BE count at this point (and if we don't, we'd be stuck
3709 // with a SCEVCouldNotCompute as the cached BE count).
3710
3711 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3712
3713 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3714 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3715 const Loop *NestedLoop = NestedAR->getLoop();
3716 if (L->contains(NestedLoop)
3717 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3718 : (!NestedLoop->contains(L) &&
3719 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3720 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3721 Operands[0] = NestedAR->getStart();
3722 // AddRecs require their operands be loop-invariant with respect to their
3723 // loops. Don't perform this transformation if it would break this
3724 // requirement.
3725 bool AllInvariant = all_of(
3726 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3727
3728 if (AllInvariant) {
3729 // Create a recurrence for the outer loop with the same step size.
3730 //
3731 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3732 // inner recurrence has the same property.
3733 SCEV::NoWrapFlags OuterFlags =
3734 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3735
3736 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3737 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3738 return isLoopInvariant(Op, NestedLoop);
3739 });
3740
3741 if (AllInvariant) {
3742 // Ok, both add recurrences are valid after the transformation.
3743 //
3744 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3745 // the outer recurrence has the same property.
3746 SCEV::NoWrapFlags InnerFlags =
3747 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3748 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3749 }
3750 }
3751 // Reset Operands to its original state.
3752 Operands[0] = NestedAR;
3753 }
3754 }
3755
3756 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3757 // already have one, otherwise create a new one.
3758 return getOrCreateAddRecExpr(Operands, L, Flags);
3759}
3760
3761const SCEV *
3762ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3763 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3764 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3765 // getSCEV(Base)->getType() has the same address space as Base->getType()
3766 // because SCEV::getType() preserves the address space.
3767 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3768 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3769 if (NW != GEPNoWrapFlags::none()) {
3770 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3771 // but to do that, we have to ensure that said flag is valid in the entire
3772 // defined scope of the SCEV.
3773 // TODO: non-instructions have global scope. We might be able to prove
3774 // some global scope cases
3775 auto *GEPI = dyn_cast<Instruction>(GEP);
3776 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3777 NW = GEPNoWrapFlags::none();
3778 }
3779
3780 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3781 if (NW.hasNoUnsignedSignedWrap())
3782 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3783 if (NW.hasNoUnsignedWrap())
3784 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3785
3786 Type *CurTy = GEP->getType();
3787 bool FirstIter = true;
3788 SmallVector<const SCEV *, 4> Offsets;
3789 for (const SCEV *IndexExpr : IndexExprs) {
3790 // Compute the (potentially symbolic) offset in bytes for this index.
3791 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3792 // For a struct, add the member offset.
3793 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3794 unsigned FieldNo = Index->getZExtValue();
3795 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3796 Offsets.push_back(FieldOffset);
3797
3798 // Update CurTy to the type of the field at Index.
3799 CurTy = STy->getTypeAtIndex(Index);
3800 } else {
3801 // Update CurTy to its element type.
3802 if (FirstIter) {
3803 assert(isa<PointerType>(CurTy) &&
3804 "The first index of a GEP indexes a pointer");
3805 CurTy = GEP->getSourceElementType();
3806 FirstIter = false;
3807 } else {
3808 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3809 }
3810 // For an array, add the element offset, explicitly scaled.
3811 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3812 // Getelementptr indices are signed.
3813 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3814
3815 // Multiply the index by the element size to compute the element offset.
3816 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3817 Offsets.push_back(LocalOffset);
3818 }
3819 }
3820
3821 // Handle degenerate case of GEP without offsets.
3822 if (Offsets.empty())
3823 return BaseExpr;
3824
3825 // Add the offsets together, assuming nsw if inbounds.
3826 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3827 // Add the base address and the offset. We cannot use the nsw flag, as the
3828 // base address is unsigned. However, if we know that the offset is
3829 // non-negative, we can use nuw.
3830 bool NUW = NW.hasNoUnsignedWrap() ||
3831 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3832 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3833 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3834 assert(BaseExpr->getType() == GEPExpr->getType() &&
3835 "GEP should not change type mid-flight.");
3836 return GEPExpr;
3837}
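// Illustrative example (editorial addition, not in the original source): for
// a GEP like "getelementptr inbounds %struct.S, ptr %p, i64 %i, i32 1" this
// builds getSCEV(%p) + %i * sizeof(%struct.S) + offsetof(%struct.S, 1), with
// nsw on the offset arithmetic and nuw on the final add only if the offset is
// known non-negative.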
3838
3839SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3840 ArrayRef<const SCEV *> Ops) {
3841 FoldingSetNodeID ID;
3842 ID.AddInteger(SCEVType);
3843 for (const SCEV *Op : Ops)
3844 ID.AddPointer(Op);
3845 void *IP = nullptr;
3846 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3847}
3848
3849const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3850 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3851 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3852}
3853
3854const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3855 SmallVectorImpl<const SCEV *> &Ops) {
3856 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3857 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3858 if (Ops.size() == 1) return Ops[0];
3859#ifndef NDEBUG
3860 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3861 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3862 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3863 "Operand types don't match!");
3864 assert(Ops[0]->getType()->isPointerTy() ==
3865 Ops[i]->getType()->isPointerTy() &&
3866 "min/max should be consistently pointerish");
3867 }
3868#endif
3869
3870 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3871 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3872
3873 const SCEV *Folded = constantFoldAndGroupOps(
3874 *this, LI, DT, Ops,
3875 [&](const APInt &C1, const APInt &C2) {
3876 switch (Kind) {
3877 case scSMaxExpr:
3878 return APIntOps::smax(C1, C2);
3879 case scSMinExpr:
3880 return APIntOps::smin(C1, C2);
3881 case scUMaxExpr:
3882 return APIntOps::umax(C1, C2);
3883 case scUMinExpr:
3884 return APIntOps::umin(C1, C2);
3885 default:
3886 llvm_unreachable("Unknown SCEV min/max opcode");
3887 }
3888 },
3889 [&](const APInt &C) {
3890 // identity
3891 if (IsMax)
3892 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3893 else
3894 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3895 },
3896 [&](const APInt &C) {
3897 // absorber
3898 if (IsMax)
3899 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3900 else
3901 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3902 });
3903 if (Folded)
3904 return Folded;
3905
3906 // Check if we have created the same expression before.
3907 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3908 return S;
3909 }
3910
3911 // Find the first operation of the same kind
3912 unsigned Idx = 0;
3913 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3914 ++Idx;
3915
3916 // Check to see if one of the operands is of the same kind. If so, expand its
3917 // operands onto our operand list, and recurse to simplify.
3918 if (Idx < Ops.size()) {
3919 bool DeletedAny = false;
3920 while (Ops[Idx]->getSCEVType() == Kind) {
3921 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3922 Ops.erase(Ops.begin()+Idx);
3923 append_range(Ops, SMME->operands());
3924 DeletedAny = true;
3925 }
3926
3927 if (DeletedAny)
3928 return getMinMaxExpr(Kind, Ops);
3929 }
3930
3931 // Okay, check to see if the same value occurs in the operand list twice. If
3932 // so, delete one. Since we sorted the list, these values are required to
3933 // be adjacent.
3934 llvm::CmpInst::Predicate GEPred =
3935 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3936 llvm::CmpInst::Predicate LEPred =
3937 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3938 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3939 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3940 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3941 if (Ops[i] == Ops[i + 1] ||
3942 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3943 // X op Y op Y --> X op Y
3944 // X op Y --> X, if we know X, Y are ordered appropriately
3945 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3946 --i;
3947 --e;
3948 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3949 Ops[i + 1])) {
3950 // X op Y --> Y, if we know X, Y are ordered appropriately
3951 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3952 --i;
3953 --e;
3954 }
3955 }
3956
3957 if (Ops.size() == 1) return Ops[0];
3958
3959 assert(!Ops.empty() && "Reduced smax down to nothing!");
3960
3961 // Okay, it looks like we really DO need an expr. Check to see if we
3962 // already have one, otherwise create a new one.
3963 FoldingSetNodeID ID;
3964 ID.AddInteger(Kind);
3965 for (const SCEV *Op : Ops)
3966 ID.AddPointer(Op);
3967 void *IP = nullptr;
3968 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3969 if (ExistingSCEV)
3970 return ExistingSCEV;
3971 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3972 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3973 SCEV *S = new (SCEVAllocator)
3974 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3975
3976 UniqueSCEVs.InsertNode(S, IP);
3977 registerUser(S, Ops);
3978 return S;
3979}
3980
3981namespace {
3982
3983class SCEVSequentialMinMaxDeduplicatingVisitor final
3984 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3985 std::optional<const SCEV *>> {
3986 using RetVal = std::optional<const SCEV *>;
3987 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3988
3989 ScalarEvolution &SE;
3990 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3991 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3992 SmallPtrSet<const SCEV *, 16> SeenOps;
3993
3994 bool canRecurseInto(SCEVTypes Kind) const {
3995 // We can only recurse into the SCEV expression of the same effective type
3996 // as the type of our root SCEV expression.
3997 return RootKind == Kind || NonSequentialRootKind == Kind;
3998 };
3999
4000 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4001 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
4002 "Only for min/max expressions.");
4003 SCEVTypes Kind = S->getSCEVType();
4004
4005 if (!canRecurseInto(Kind))
4006 return S;
4007
4008 auto *NAry = cast<SCEVNAryExpr>(S);
4009 SmallVector<const SCEV *> NewOps;
4010 bool Changed = visit(Kind, NAry->operands(), NewOps);
4011
4012 if (!Changed)
4013 return S;
4014 if (NewOps.empty())
4015 return std::nullopt;
4016
4017 return isa<SCEVSequentialMinMaxExpr>(S)
4018 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4019 : SE.getMinMaxExpr(Kind, NewOps);
4020 }
4021
4022 RetVal visit(const SCEV *S) {
4023 // Has the whole operand been seen already?
4024 if (!SeenOps.insert(S).second)
4025 return std::nullopt;
4026 return Base::visit(S);
4027 }
4028
4029public:
4030 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4031 SCEVTypes RootKind)
4032 : SE(SE), RootKind(RootKind),
4033 NonSequentialRootKind(
4034 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4035 RootKind)) {}
4036
4037 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4038 SmallVectorImpl<const SCEV *> &NewOps) {
4039 bool Changed = false;
4040 SmallVector<const SCEV *> Ops;
4041 Ops.reserve(OrigOps.size());
4042
4043 for (const SCEV *Op : OrigOps) {
4044 RetVal NewOp = visit(Op);
4045 if (NewOp != Op)
4046 Changed = true;
4047 if (NewOp)
4048 Ops.emplace_back(*NewOp);
4049 }
4050
4051 if (Changed)
4052 NewOps = std::move(Ops);
4053 return Changed;
4054 }
4055
4056 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4057
4058 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4059
4060 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4061
4062 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4063
4064 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4065
4066 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4067
4068 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4069
4070 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4071
4072 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4073
4074 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4075
4076 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4077 return visitAnyMinMaxExpr(Expr);
4078 }
4079
4080 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4081 return visitAnyMinMaxExpr(Expr);
4082 }
4083
4084 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4085 return visitAnyMinMaxExpr(Expr);
4086 }
4087
4088 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4089 return visitAnyMinMaxExpr(Expr);
4090 }
4091
4092 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4093 return visitAnyMinMaxExpr(Expr);
4094 }
4095
4096 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4097
4098 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4099};
4100
4101} // namespace
4102
4103static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4104 switch (Kind) {
4105 case scConstant:
4106 case scVScale:
4107 case scTruncate:
4108 case scZeroExtend:
4109 case scSignExtend:
4110 case scPtrToInt:
4111 case scAddExpr:
4112 case scMulExpr:
4113 case scUDivExpr:
4114 case scAddRecExpr:
4115 case scUMaxExpr:
4116 case scSMaxExpr:
4117 case scUMinExpr:
4118 case scSMinExpr:
4119 case scUnknown:
4120 // If any operand is poison, the whole expression is poison.
4121 return true;
4122 case scSequentialUMinExpr:
4123 // FIXME: if the *first* operand is poison, the whole expression is poison.
4124 return false; // Pessimistically, say that it does not propagate poison.
4125 case scCouldNotCompute:
4126 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4127 }
4128 llvm_unreachable("Unknown SCEV kind!");
4129}
4130
4131namespace {
4132// The only way poison may be introduced in a SCEV expression is from a
4133// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4134// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4135// introduce poison -- they encode guaranteed, non-speculated knowledge.
4136//
4137// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4138// with the notable exception of umin_seq, where only poison from the first
4139// operand is (unconditionally) propagated.
4140struct SCEVPoisonCollector {
4141 bool LookThroughMaybePoisonBlocking;
4142 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4143 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4144 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4145
4146 bool follow(const SCEV *S) {
4147 if (!LookThroughMaybePoisonBlocking &&
4148 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4149 return false;
4150
4151 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4152 if (!isGuaranteedNotToBePoison(SU->getValue()))
4153 MaybePoison.insert(SU);
4154 }
4155 return true;
4156 }
4157 bool isDone() const { return false; }
4158};
4159} // namespace
4160
4161/// Return true if V is poison given that AssumedPoison is already poison.
4162static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4163 // First collect all SCEVs that might result in AssumedPoison to be poison.
4164 // We need to look through potentially poison-blocking operations here,
4165 // because we want to find all SCEVs that *might* result in poison, not only
4166 // those that are *required* to.
4167 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4168 visitAll(AssumedPoison, PC1);
4169
4170 // AssumedPoison is never poison. As the assumption is false, the implication
4171 // is true. Don't bother walking the other SCEV in this case.
4172 if (PC1.MaybePoison.empty())
4173 return true;
4174
4175 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4176 // as well. We cannot look through potentially poison-blocking operations
4177 // here, as their arguments only *may* make the result poison.
4178 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4179 visitAll(S, PC2);
4180
4181 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4182 // it will also make S poison by being part of PC2.MaybePoison.
4183 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4184}
4185
4186void ScalarEvolution::getPoisonGeneratingValues(
4187 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4188 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4189 visitAll(S, PC);
4190 for (const SCEVUnknown *SU : PC.MaybePoison)
4191 Result.insert(SU->getValue());
4192}
4193
4194bool ScalarEvolution::canReuseInstruction(
4195 const SCEV *S, Instruction *I,
4196 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4197 // If the instruction cannot be poison, it's always safe to reuse.
4199 return true;
4200
4201 // Otherwise, it is possible that I is more poisonous than S. Collect the
4202 // poison-contributors of S, and then check whether I has any additional
4203 // poison-contributors. Poison that is contributed through poison-generating
4204 // flags is handled by dropping those flags instead.
4205 SmallPtrSet<const Value *, 8> PoisonVals;
4206 getPoisonGeneratingValues(PoisonVals, S);
4207
4208 SmallVector<Value *> Worklist;
4209 SmallPtrSet<Value *, 8> Visited;
4210 Worklist.push_back(I);
4211 while (!Worklist.empty()) {
4212 Value *V = Worklist.pop_back_val();
4213 if (!Visited.insert(V).second)
4214 continue;
4215
4216 // Avoid walking large instruction graphs.
4217 if (Visited.size() > 16)
4218 return false;
4219
4220 // Either the value can't be poison, or the S would also be poison if it
4221 // is.
4222 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4223 continue;
4224
4225 auto *I = dyn_cast<Instruction>(V);
4226 if (!I)
4227 return false;
4228
4229 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4230 // can't replace an arbitrary add with disjoint or, even if we drop the
4231 // flag. We would need to convert the or into an add.
4232 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4233 if (PDI->isDisjoint())
4234 return false;
4235
4236 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4237 // because SCEV currently assumes it can't be poison. Remove this special
4238 // case once we properly model when vscale can be poison.
4239 if (auto *II = dyn_cast<IntrinsicInst>(I);
4240 II && II->getIntrinsicID() == Intrinsic::vscale)
4241 continue;
4242
4243 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4244 return false;
4245
4246 // If the instruction can't create poison, we can recurse to its operands.
4247 if (I->hasPoisonGeneratingAnnotations())
4248 DropPoisonGeneratingInsts.push_back(I);
4249
4250 llvm::append_range(Worklist, I->operands());
4251 }
4252 return true;
4253}
4254
4255const SCEV *
4258 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4259 "Not a SCEVSequentialMinMaxExpr!");
4260 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4261 if (Ops.size() == 1)
4262 return Ops[0];
4263#ifndef NDEBUG
4264 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4265 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4266 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4267 "Operand types don't match!");
4268 assert(Ops[0]->getType()->isPointerTy() ==
4269 Ops[i]->getType()->isPointerTy() &&
4270 "min/max should be consistently pointerish");
4271 }
4272#endif
4273
4274 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4275 // so we can *NOT* do any kind of sorting of the expressions!
4276
4277 // Check if we have created the same expression before.
4278 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4279 return S;
4280
4281 // FIXME: there are *some* simplifications that we can do here.
4282
4283 // Keep only the first instance of an operand.
4284 {
4285 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4286 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4287 if (Changed)
4288 return getSequentialMinMaxExpr(Kind, Ops);
4289 }
4290
4291 // Check to see if one of the operands is of the same kind. If so, expand its
4292 // operands onto our operand list, and recurse to simplify.
4293 {
4294 unsigned Idx = 0;
4295 bool DeletedAny = false;
4296 while (Idx < Ops.size()) {
4297 if (Ops[Idx]->getSCEVType() != Kind) {
4298 ++Idx;
4299 continue;
4300 }
4301 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4302 Ops.erase(Ops.begin() + Idx);
4303 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4304 SMME->operands().end());
4305 DeletedAny = true;
4306 }
4307
4308 if (DeletedAny)
4309 return getSequentialMinMaxExpr(Kind, Ops);
4310 }
4311
4312 const SCEV *SaturationPoint;
4313 ICmpInst::Predicate Pred;
4314 switch (Kind) {
4315 case scSequentialUMinExpr:
4316 SaturationPoint = getZero(Ops[0]->getType());
4317 Pred = ICmpInst::ICMP_ULE;
4318 break;
4319 default:
4320 llvm_unreachable("Not a sequential min/max type.");
4321 }
4322
4323 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4324 if (!isGuaranteedNotToCauseUB(Ops[i]))
4325 continue;
4326 // We can replace %x umin_seq %y with %x umin %y if either:
4327 // * %y being poison implies %x is also poison.
4328 // * %x cannot be the saturating value (e.g. zero for umin).
4329 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4330 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4331 SaturationPoint)) {
4332 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4333 Ops[i - 1] = getMinMaxExpr(
4334 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4335 SeqOps);
4336 Ops.erase(Ops.begin() + i);
4337 return getSequentialMinMaxExpr(Kind, Ops);
4338 }
4339 // Fold %x umin_seq %y to %x if %x ule %y.
4340 // TODO: We might be able to prove the predicate for a later operand.
4341 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4342 Ops.erase(Ops.begin() + i);
4343 return getSequentialMinMaxExpr(Kind, Ops);
4344 }
4345 }
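// Illustrative example (editorial addition, not in the original source): in
// (%x umin_seq 5) the constant 5 can never be poison, so the implication
// above holds vacuously and the pair becomes the plain (%x umin 5); likewise
// (%x umin_seq %y) can drop the sequential property when %x is known nonzero,
// since %x then cannot be the saturating value.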
4346
4347 // Okay, it looks like we really DO need an expr. Check to see if we
4348 // already have one, otherwise create a new one.
4349 FoldingSetNodeID ID;
4350 ID.AddInteger(Kind);
4351 for (const SCEV *Op : Ops)
4352 ID.AddPointer(Op);
4353 void *IP = nullptr;
4354 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4355 if (ExistingSCEV)
4356 return ExistingSCEV;
4357
4358 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4359 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4360 SCEV *S = new (SCEVAllocator)
4361 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4362
4363 UniqueSCEVs.InsertNode(S, IP);
4364 registerUser(S, Ops);
4365 return S;
4366}
4367
4368const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4369 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4370 return getSMaxExpr(Ops);
4371}
4372
4373const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4374 return getMinMaxExpr(scSMaxExpr, Ops);
4375}
4376
4377const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4378 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4379 return getUMaxExpr(Ops);
4380}
4381
4382const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4383 return getMinMaxExpr(scUMaxExpr, Ops);
4384}
4385
4386const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4387 const SCEV *RHS) {
4388 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4389 return getSMinExpr(Ops);
4390}
4391
4392const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4393 return getMinMaxExpr(scSMinExpr, Ops);
4394}
4395
4396const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4397 bool Sequential) {
4398 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4399 return getUMinExpr(Ops, Sequential);
4400}
4401
4402const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4403 bool Sequential) {
4404 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4405 : getMinMaxExpr(scUMinExpr, Ops);
4406}
4407
4408const SCEV *
4409ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4410 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4411 if (Size.isScalable())
4412 Res = getMulExpr(Res, getVScale(IntTy));
4413 return Res;
4414}
4415
4416const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4417 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4418}
4419
4420const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4421 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4422}
4423
4424const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4425 StructType *STy,
4426 unsigned FieldNo) {
4427 // We can bypass creating a target-independent constant expression and then
4428 // folding it back into a ConstantInt. This is just a compile-time
4429 // optimization.
4430 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4431 assert(!SL->getSizeInBits().isScalable() &&
4432 "Cannot get offset for structure containing scalable vector types");
4433 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4434}
4435
4436const SCEV *ScalarEvolution::getUnknown(Value *V) {
4437 // Don't attempt to do anything other than create a SCEVUnknown object
4438 // here. createSCEV only calls getUnknown after checking for all other
4439 // interesting possibilities, and any other code that calls getUnknown
4440 // is doing so in order to hide a value from SCEV canonicalization.
4441
4442 FoldingSetNodeID ID;
4443 ID.AddInteger(scUnknown);
4444 ID.AddPointer(V);
4445 void *IP = nullptr;
4446 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4447 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4448 "Stale SCEVUnknown in uniquing map!");
4449 return S;
4450 }
4451 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4452 FirstUnknown);
4453 FirstUnknown = cast<SCEVUnknown>(S);
4454 UniqueSCEVs.InsertNode(S, IP);
4455 return S;
4456}
4457
4458//===----------------------------------------------------------------------===//
4459// Basic SCEV Analysis and PHI Idiom Recognition Code
4460//
4461
4462/// Test if values of the given type are analyzable within the SCEV
4463/// framework. This primarily includes integer types, and it can optionally
4464/// include pointer types if the ScalarEvolution class has access to
4465/// target-specific information.
4466bool ScalarEvolution::isSCEVable(Type *Ty) const {
4467 // Integers and pointers are always SCEVable.
4468 return Ty->isIntOrPtrTy();
4469}
4470
4471/// Return the size in bits of the specified type, for which isSCEVable must
4472/// return true.
4473uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4474 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4475 if (Ty->isPointerTy())
4476 return getDataLayout().getIndexTypeSizeInBits(Ty);
4477 return getDataLayout().getTypeSizeInBits(Ty);
4478}
4479
4480/// Return a type with the same bitwidth as the given type and which represents
4481/// how SCEV will treat the given type, for which isSCEVable must return
4482/// true. For pointer types, this is the pointer index sized integer type.
4483Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4484 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4485
4486 if (Ty->isIntegerTy())
4487 return Ty;
4488
4489 // The only other supported type is pointer.
4490 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4491 return getDataLayout().getIndexType(Ty);
4492}
4493
4494Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4495 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4496}
4497
4498bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4499 const SCEV *B) {
4500 /// For a valid use point to exist, the defining scope of one operand
4501 /// must dominate the other.
4502 bool PreciseA, PreciseB;
4503 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4504 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4505 if (!PreciseA || !PreciseB)
4506 // Can't tell.
4507 return false;
4508 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4509 DT.dominates(ScopeB, ScopeA);
4510}
4511
4512const SCEV *ScalarEvolution::getCouldNotCompute() {
4513 return CouldNotCompute.get();
4514}
4515
4516bool ScalarEvolution::checkValidity(const SCEV *S) const {
4517 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4518 auto *SU = dyn_cast<SCEVUnknown>(S);
4519 return SU && SU->getValue() == nullptr;
4520 });
4521
4522 return !ContainsNulls;
4523}
4524
4525bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4526 HasRecMapType::iterator I = HasRecMap.find(S);
4527 if (I != HasRecMap.end())
4528 return I->second;
4529
4530 bool FoundAddRec =
4531 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4532 HasRecMap.insert({S, FoundAddRec});
4533 return FoundAddRec;
4534}
4535
4536/// Return the ValueOffsetPair set for \p S. \p S can be represented
4537/// by the value and offset from any ValueOffsetPair in the set.
4538ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4539 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4540 if (SI == ExprValueMap.end())
4541 return {};
4542 return SI->second.getArrayRef();
4543}
4544
4545/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4546/// cannot be used separately. eraseValueFromMap should be used to remove
4547/// V from ValueExprMap and ExprValueMap at the same time.
4548void ScalarEvolution::eraseValueFromMap(Value *V) {
4549 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4550 if (I != ValueExprMap.end()) {
4551 auto EVIt = ExprValueMap.find(I->second);
4552 bool Removed = EVIt->second.remove(V);
4553 (void) Removed;
4554 assert(Removed && "Value not in ExprValueMap?");
4555 ValueExprMap.erase(I);
4556 }
4557}
4558
4559void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4560 // A recursive query may have already computed the SCEV. It should be
4561 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4562 // inferred nowrap flags.
4563 auto It = ValueExprMap.find_as(V);
4564 if (It == ValueExprMap.end()) {
4565 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4566 ExprValueMap[S].insert(V);
4567 }
4568}
4569
4570/// Return an existing SCEV if it exists, otherwise analyze the expression and
4571/// create a new one.
4572const SCEV *ScalarEvolution::getSCEV(Value *V) {
4573 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4574
4575 if (const SCEV *S = getExistingSCEV(V))
4576 return S;
4577 return createSCEVIter(V);
4578}
4579
4580const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4581 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4582
4583 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4584 if (I != ValueExprMap.end()) {
4585 const SCEV *S = I->second;
4586 assert(checkValidity(S) &&
4587 "existing SCEV has not been properly invalidated");
4588 return S;
4589 }
4590 return nullptr;
4591}
4592
4593/// Return a SCEV corresponding to -V = -1*V
4594const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4595 SCEV::NoWrapFlags Flags) {
4596 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4597 return getConstant(
4598 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4599
4600 Type *Ty = V->getType();
4601 Ty = getEffectiveSCEVType(Ty);
4602 return getMulExpr(V, getMinusOne(Ty), Flags);
4603}
4604
4605/// If Expr computes ~A, return A else return nullptr
4606static const SCEV *MatchNotExpr(const SCEV *Expr) {
4607 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4608 if (!Add || Add->getNumOperands() != 2 ||
4609 !Add->getOperand(0)->isAllOnesValue())
4610 return nullptr;
4611
4612 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4613 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4614 !AddRHS->getOperand(0)->isAllOnesValue())
4615 return nullptr;
4616
4617 return AddRHS->getOperand(1);
4618}
4619
4620/// Return a SCEV corresponding to ~V = -1-V
4621const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4622 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4623
4624 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4625 return getConstant(
4626 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4627
4628 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4629 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4630 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4631 SmallVector<const SCEV *, 2> MatchedOperands;
4632 for (const SCEV *Operand : MME->operands()) {
4633 const SCEV *Matched = MatchNotExpr(Operand);
4634 if (!Matched)
4635 return (const SCEV *)nullptr;
4636 MatchedOperands.push_back(Matched);
4637 }
4638 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4639 MatchedOperands);
4640 };
4641 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4642 return Replaced;
4643 }
4644
4645 Type *Ty = V->getType();
4646 Ty = getEffectiveSCEVType(Ty);
4647 return getMinusSCEV(getMinusOne(Ty), V);
4648}
4649
4650const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4651 assert(P->getType()->isPointerTy());
4652
4653 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4654 // The base of an AddRec is the first operand.
4655 SmallVector<const SCEV *> Ops{AddRec->operands()};
4656 Ops[0] = removePointerBase(Ops[0]);
4657 // Don't try to transfer nowrap flags for now. We could in some cases
4658 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4659 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4660 }
4661 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4662 // The base of an Add is the pointer operand.
4663 SmallVector<const SCEV *> Ops{Add->operands()};
4664 const SCEV **PtrOp = nullptr;
4665 for (const SCEV *&AddOp : Ops) {
4666 if (AddOp->getType()->isPointerTy()) {
4667 assert(!PtrOp && "Cannot have multiple pointer ops");
4668 PtrOp = &AddOp;
4669 }
4670 }
4671 *PtrOp = removePointerBase(*PtrOp);
4672 // Don't try to transfer nowrap flags for now. We could in some cases
4673 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4674 return getAddExpr(Ops);
4675 }
4676 // Any other expression must be a pointer base.
4677 return getZero(P->getType());
4678}
4679
4680const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4681 SCEV::NoWrapFlags Flags,
4682 unsigned Depth) {
4683 // Fast path: X - X --> 0.
4684 if (LHS == RHS)
4685 return getZero(LHS->getType());
4686
4687 // If we subtract two pointers with different pointer bases, bail.
4688 // Eventually, we're going to add an assertion to getMulExpr that we
4689 // can't multiply by a pointer.
4690 if (RHS->getType()->isPointerTy()) {
4691 if (!LHS->getType()->isPointerTy() ||
4692 getPointerBase(LHS) != getPointerBase(RHS))
4693 return getCouldNotCompute();
4694 LHS = removePointerBase(LHS);
4695 RHS = removePointerBase(RHS);
4696 }
4697
4698 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4699 // makes it so that we cannot make much use of NUW.
4700 auto AddFlags = SCEV::FlagAnyWrap;
4701 const bool RHSIsNotMinSigned =
4702 !getSignedRangeMin(RHS).isMinSignedValue();
4703 if (hasFlags(Flags, SCEV::FlagNSW)) {
4704 // Let M be the minimum representable signed value. Then (-1)*RHS
4705 // signed-wraps if and only if RHS is M. That can happen even for
4706 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4707 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4708 // (-1)*RHS, we need to prove that RHS != M.
4709 //
4710 // If LHS is non-negative and we know that LHS - RHS does not
4711 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4712 // either by proving that RHS > M or that LHS >= 0.
4713 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4714 AddFlags = SCEV::FlagNSW;
4715 }
4716 }
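// Illustrative example (editorial addition, not in the original source): in
// i8, M = -128 and (-1) * (-128) wraps back to -128, so NSW is transferred to
// LHS + (-1)*RHS only once RHS is known to differ from -128, either directly
// or via LHS being non-negative as checked above.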
4717
4718 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4719 // RHS is NSW and LHS >= 0.
4720 //
4721 // The difficulty here is that the NSW flag may have been proven
4722 // relative to a loop that is to be found in a recurrence in LHS and
4723 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4724 // larger scope than intended.
4725 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4726
4727 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4728}
4729
4730const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4731 unsigned Depth) {
4732 Type *SrcTy = V->getType();
4733 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4734 "Cannot truncate or zero extend with non-integer arguments!");
4735 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4736 return V; // No conversion
4737 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4738 return getTruncateExpr(V, Ty, Depth);
4739 return getZeroExtendExpr(V, Ty, Depth);
4740}
4741
4742const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4743 unsigned Depth) {
4744 Type *SrcTy = V->getType();
4745 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4746 "Cannot truncate or zero extend with non-integer arguments!");
4747 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4748 return V; // No conversion
4749 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4750 return getTruncateExpr(V, Ty, Depth);
4751 return getSignExtendExpr(V, Ty, Depth);
4752}
4753
4754const SCEV *
4755ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4756 Type *SrcTy = V->getType();
4757 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4758 "Cannot noop or zero extend with non-integer arguments!");
4760 "getNoopOrZeroExtend cannot truncate!");
4761 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4762 return V; // No conversion
4763 return getZeroExtendExpr(V, Ty);
4764}
4765
4766const SCEV *
4767ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4768 Type *SrcTy = V->getType();
4769 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4770 "Cannot noop or sign extend with non-integer arguments!");
4772 "getNoopOrSignExtend cannot truncate!");
4773 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4774 return V; // No conversion
4775 return getSignExtendExpr(V, Ty);
4776}
4777
4778const SCEV *
4779ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4780 Type *SrcTy = V->getType();
4781 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4782 "Cannot noop or any extend with non-integer arguments!");
4784 "getNoopOrAnyExtend cannot truncate!");
4785 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4786 return V; // No conversion
4787 return getAnyExtendExpr(V, Ty);
4788}
4789
4790const SCEV *
4791ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4792 Type *SrcTy = V->getType();
4793 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4794 "Cannot truncate or noop with non-integer arguments!");
4796 "getTruncateOrNoop cannot extend!");
4797 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4798 return V; // No conversion
4799 return getTruncateExpr(V, Ty);
4800}
4801
4802const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4803 const SCEV *RHS) {
4804 const SCEV *PromotedLHS = LHS;
4805 const SCEV *PromotedRHS = RHS;
4806
4807 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4808 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4809 else
4810 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4811
4812 return getUMaxExpr(PromotedLHS, PromotedRHS);
4813}
4814
4815const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4816 const SCEV *RHS,
4817 bool Sequential) {
4818 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4819 return getUMinFromMismatchedTypes(Ops, Sequential);
4820}
4821
4822const SCEV *
4823ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4824 bool Sequential) {
4825 assert(!Ops.empty() && "At least one operand must be!");
4826 // Trivial case.
4827 if (Ops.size() == 1)
4828 return Ops[0];
4829
4830 // Find the max type first.
4831 Type *MaxType = nullptr;
4832 for (const auto *S : Ops)
4833 if (MaxType)
4834 MaxType = getWiderType(MaxType, S->getType());
4835 else
4836 MaxType = S->getType();
4837 assert(MaxType && "Failed to find maximum type!");
4838
4839 // Extend all ops to max type.
4840 SmallVector<const SCEV *, 2> PromotedOps;
4841 for (const auto *S : Ops)
4842 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4843
4844 // Generate umin.
4845 return getUMinExpr(PromotedOps, Sequential);
4846}
4847
4848const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4849 // A pointer operand may evaluate to a nonpointer expression, such as null.
4850 if (!V->getType()->isPointerTy())
4851 return V;
4852
4853 while (true) {
4854 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4855 V = AddRec->getStart();
4856 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4857 const SCEV *PtrOp = nullptr;
4858 for (const SCEV *AddOp : Add->operands()) {
4859 if (AddOp->getType()->isPointerTy()) {
4860 assert(!PtrOp && "Cannot have multiple pointer ops");
4861 PtrOp = AddOp;
4862 }
4863 }
4864 assert(PtrOp && "Must have pointer op");
4865 V = PtrOp;
4866 } else // Not something we can look further into.
4867 return V;
4868 }
4869}
4870
4871/// Push users of the given Instruction onto the given Worklist.
4872static void PushDefUseChildren(Instruction *I,
4873 SmallVectorImpl<Instruction *> &Worklist,
4874 SmallPtrSetImpl<Instruction *> &Visited) {
4875 // Push the def-use children onto the Worklist stack.
4876 for (User *U : I->users()) {
4877 auto *UserInsn = cast<Instruction>(U);
4878 if (Visited.insert(UserInsn).second)
4879 Worklist.push_back(UserInsn);
4880 }
4881}
4882
4883namespace {
4884
4885/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4886/// expression if its loop is L. If its loop is not L, then: if
4887/// IgnoreOtherLoops is true, use the AddRec itself; otherwise the rewrite
4888/// cannot be done.
4889/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
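/// For example, {%a,+,%b}<L> is rewritten to %a, while an AddRec over some
/// other loop is either kept as-is (IgnoreOtherLoops == true) or makes the
/// whole rewrite return SCEVCouldNotCompute.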
4890class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4891public:
4892 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4893 bool IgnoreOtherLoops = true) {
4894 SCEVInitRewriter Rewriter(L, SE);
4895 const SCEV *Result = Rewriter.visit(S);
4896 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4897 return SE.getCouldNotCompute();
4898 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4899 ? SE.getCouldNotCompute()
4900 : Result;
4901 }
4902
4903 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4904 if (!SE.isLoopInvariant(Expr, L))
4905 SeenLoopVariantSCEVUnknown = true;
4906 return Expr;
4907 }
4908
4909 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4910 // Only re-write AddRecExprs for this loop.
4911 if (Expr->getLoop() == L)
4912 return Expr->getStart();
4913 SeenOtherLoops = true;
4914 return Expr;
4915 }
4916
4917 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4918
4919 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4920
4921private:
4922 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4923 : SCEVRewriteVisitor(SE), L(L) {}
4924
4925 const Loop *L;
4926 bool SeenLoopVariantSCEVUnknown = false;
4927 bool SeenOtherLoops = false;
4928};
4929
4930/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its
4931/// post-increment expression if its loop is L. If its loop is not L, use
4932/// the AddRec itself.
4933/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
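/// For example, {%a,+,%b}<L> is rewritten to {%a + %b,+,%b}<L>, i.e. the
/// value of the recurrence after one more increment of loop L.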
4934class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4935public:
4936 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4937 SCEVPostIncRewriter Rewriter(L, SE);
4938 const SCEV *Result = Rewriter.visit(S);
4939 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4940 ? SE.getCouldNotCompute()
4941 : Result;
4942 }
4943
4944 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4945 if (!SE.isLoopInvariant(Expr, L))
4946 SeenLoopVariantSCEVUnknown = true;
4947 return Expr;
4948 }
4949
4950 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4951 // Only re-write AddRecExprs for this loop.
4952 if (Expr->getLoop() == L)
4953 return Expr->getPostIncExpr(SE);
4954 SeenOtherLoops = true;
4955 return Expr;
4956 }
4957
4958 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4959
4960 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4961
4962private:
4963 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4964 : SCEVRewriteVisitor(SE), L(L) {}
4965
4966 const Loop *L;
4967 bool SeenLoopVariantSCEVUnknown = false;
4968 bool SeenOtherLoops = false;
4969};
4970
4971/// This class evaluates the compare condition by matching it against the
4972/// condition of loop latch. If there is a match we assume a true value
4973/// for the condition while building SCEV nodes.
4974class SCEVBackedgeConditionFolder
4975 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4976public:
4977 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4978 ScalarEvolution &SE) {
4979 bool IsPosBECond = false;
4980 Value *BECond = nullptr;
4981 if (BasicBlock *Latch = L->getLoopLatch()) {
4982 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4983 if (BI && BI->isConditional()) {
4984 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4985 "Both outgoing branches should not target same header!");
4986 BECond = BI->getCondition();
4987 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4988 } else {
4989 return S;
4990 }
4991 }
4992 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4993 return Rewriter.visit(S);
4994 }
4995
4996 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4997 const SCEV *Result = Expr;
4998 bool InvariantF = SE.isLoopInvariant(Expr, L);
4999
5000 if (!InvariantF) {
5001 Instruction *I = cast<Instruction>(Expr->getValue());
5002 switch (I->getOpcode()) {
5003 case Instruction::Select: {
5004 SelectInst *SI = cast<SelectInst>(I);
5005 std::optional<const SCEV *> Res =
5006 compareWithBackedgeCondition(SI->getCondition());
5007 if (Res) {
5008 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5009 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5010 }
5011 break;
5012 }
5013 default: {
5014 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5015 if (Res)
5016 Result = *Res;
5017 break;
5018 }
5019 }
5020 }
5021 return Result;
5022 }
5023
5024private:
5025 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5026 bool IsPosBECond, ScalarEvolution &SE)
5027 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5028 IsPositiveBECond(IsPosBECond) {}
5029
5030 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5031
5032 const Loop *L;
5033 /// Loop back condition.
5034 Value *BackedgeCond = nullptr;
5035 /// Set to true if loop back is on positive branch condition.
5036 bool IsPositiveBECond;
5037};
5038
5039std::optional<const SCEV *>
5040SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5041
5042 // If value matches the backedge condition for loop latch,
5043 // then return a constant evolution node based on loopback
5044 // branch taken.
5045 if (BackedgeCond == IC)
5046 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5047 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5048 return std::nullopt;
5049}
5050
5051class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5052public:
5053 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5054 ScalarEvolution &SE) {
5055 SCEVShiftRewriter Rewriter(L, SE);
5056 const SCEV *Result = Rewriter.visit(S);
5057 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5058 }
5059
5060 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5061 // Only allow AddRecExprs for this loop.
5062 if (!SE.isLoopInvariant(Expr, L))
5063 Valid = false;
5064 return Expr;
5065 }
5066
5067 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5068 if (Expr->getLoop() == L && Expr->isAffine())
5069 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5070 Valid = false;
5071 return Expr;
5072 }
5073
5074 bool isValid() { return Valid; }
5075
5076private:
5077 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5078 : SCEVRewriteVisitor(SE), L(L) {}
5079
5080 const Loop *L;
5081 bool Valid = true;
5082};
5083
5084} // end anonymous namespace
5085
5086SCEV::NoWrapFlags
5087ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5088 if (!AR->isAffine())
5089 return SCEV::FlagAnyWrap;
5090
5091 using OBO = OverflowingBinaryOperator;
5092
5093 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5094
5095 if (!AR->hasNoSelfWrap()) {
5096 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5097 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5098 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5099 const APInt &BECountAP = BECountMax->getAPInt();
5100 unsigned NoOverflowBitWidth =
5101 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5102 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5103 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5104 }
5105 }
5106
5107 if (!AR->hasNoSignedWrap()) {
5108 ConstantRange AddRecRange = getSignedRange(AR);
5109 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5110
5111 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5112 Instruction::Add, IncRange, OBO::NoSignedWrap);
5113 if (NSWRegion.contains(AddRecRange))
5114 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5115 }
5116
5117 if (!AR->hasNoUnsignedWrap()) {
5118 ConstantRange AddRecRange = getUnsignedRange(AR);
5119 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5120
5121 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5122 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5123 if (NUWRegion.contains(AddRecRange))
5124 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5125 }
5126
5127 return Result;
5128}
5129
5130SCEV::NoWrapFlags
5131ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5132 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5133
5134 if (AR->hasNoSignedWrap())
5135 return Result;
5136
5137 if (!AR->isAffine())
5138 return Result;
5139
5140 // This function can be expensive, only try to prove NSW once per AddRec.
5141 if (!SignedWrapViaInductionTried.insert(AR).second)
5142 return Result;
5143
5144 const SCEV *Step = AR->getStepRecurrence(*this);
5145 const Loop *L = AR->getLoop();
5146
5147 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5148 // Note that this serves two purposes: It filters out loops that are
5149 // simply not analyzable, and it covers the case where this code is
5150 // being called from within backedge-taken count analysis, such that
5151 // attempting to ask for the backedge-taken count would likely result
5152 // in infinite recursion. In the latter case, the analysis code will
5153 // cope with a conservative value, and it will take care to purge
5154 // that value once it has finished.
5155 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5156
5157 // Normally, in the cases we can prove no-overflow via a
5158 // backedge guarding condition, we can also compute a backedge
5159 // taken count for the loop. The exceptions are assumptions and
5160 // guards present in the loop -- SCEV is not great at exploiting
5161 // these to compute max backedge taken counts, but can still use
5162 // these to prove lack of overflow. Use this fact to avoid
5163 // doing extra work that may not pay off.
5164
5165 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5166 AC.assumptions().empty())
5167 return Result;
5168
5169 // If the backedge is guarded by a comparison with the pre-inc value the
5170 // addrec is safe. Also, if the entry is guarded by a comparison with the
5171 // start value and the backedge is guarded by a comparison with the post-inc
5172 // value, the addrec is safe.
5173 ICmpInst::Predicate Pred;
5174 const SCEV *OverflowLimit =
5175 getSignedOverflowLimitForStep(Step, &Pred, this);
5176 if (OverflowLimit &&
5177 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5178 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5179 Result = setFlags(Result, SCEV::FlagNSW);
5180 }
5181 return Result;
5182}
5183SCEV::NoWrapFlags
5184ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5185 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5186
5187 if (AR->hasNoUnsignedWrap())
5188 return Result;
5189
5190 if (!AR->isAffine())
5191 return Result;
5192
5193 // This function can be expensive, only try to prove NUW once per AddRec.
5194 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5195 return Result;
5196
5197 const SCEV *Step = AR->getStepRecurrence(*this);
5198 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5199 const Loop *L = AR->getLoop();
5200
5201 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5202 // Note that this serves two purposes: It filters out loops that are
5203 // simply not analyzable, and it covers the case where this code is
5204 // being called from within backedge-taken count analysis, such that
5205 // attempting to ask for the backedge-taken count would likely result
5206 // in infinite recursion. In the latter case, the analysis code will
5207 // cope with a conservative value, and it will take care to purge
5208 // that value once it has finished.
5209 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5210
5211 // Normally, in the cases we can prove no-overflow via a
5212 // backedge guarding condition, we can also compute a backedge
5213 // taken count for the loop. The exceptions are assumptions and
5214 // guards present in the loop -- SCEV is not great at exploiting
5215 // these to compute max backedge taken counts, but can still use
5216 // these to prove lack of overflow. Use this fact to avoid
5217 // doing extra work that may not pay off.
5218
5219 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5220 AC.assumptions().empty())
5221 return Result;
5222
5223 // If the backedge is guarded by a comparison with the pre-inc value the
5224 // addrec is safe. Also, if the entry is guarded by a comparison with the
5225 // start value and the backedge is guarded by a comparison with the post-inc
5226 // value, the addrec is safe.
5227 if (isKnownPositive(Step)) {
5228 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5229 getUnsignedRangeMax(Step));
5230 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5231 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5232 Result = setFlags(Result, SCEV::FlagNUW);
5233 }
5234 }
5235
5236 return Result;
5237}
5238
5239namespace {
5240
5241/// Represents an abstract binary operation. This may exist as a
5242/// normal instruction or constant expression, or may have been
5243/// derived from an expression tree.
5244struct BinaryOp {
5245 unsigned Opcode;
5246 Value *LHS;
5247 Value *RHS;
5248 bool IsNSW = false;
5249 bool IsNUW = false;
5250
5251 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5252 /// constant expression.
5253 Operator *Op = nullptr;
5254
5255 explicit BinaryOp(Operator *Op)
5256 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5257 Op(Op) {
5258 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5259 IsNSW = OBO->hasNoSignedWrap();
5260 IsNUW = OBO->hasNoUnsignedWrap();
5261 }
5262 }
5263
5264 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5265 bool IsNUW = false)
5266 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5267};
5268
5269} // end anonymous namespace
5270
5271/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5272static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5273 AssumptionCache &AC,
5274 const DominatorTree &DT,
5275 const Instruction *CxtI) {
5276 auto *Op = dyn_cast<Operator>(V);
5277 if (!Op)
5278 return std::nullopt;
5279
5280 // Implementation detail: all the cleverness here should happen without
5281 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5282 // SCEV expressions when possible, and we should not break that.
5283
5284 switch (Op->getOpcode()) {
5285 case Instruction::Add:
5286 case Instruction::Sub:
5287 case Instruction::Mul:
5288 case Instruction::UDiv:
5289 case Instruction::URem:
5290 case Instruction::And:
5291 case Instruction::AShr:
5292 case Instruction::Shl:
5293 return BinaryOp(Op);
5294
5295 case Instruction::Or: {
5296 // Convert or disjoint into add nuw nsw.
5297 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5298 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5299 /*IsNSW=*/true, /*IsNUW=*/true);
5300 return BinaryOp(Op);
5301 }
5302
5303 case Instruction::Xor:
5304 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5305 // If the RHS of the xor is a signmask, then this is just an add.
5306 // Instcombine turns add of signmask into xor as a strength reduction step.
5307 if (RHSC->getValue().isSignMask())
5308 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5309 // Binary `xor` is a bit-wise `add`.
5310 if (V->getType()->isIntegerTy(1))
5311 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5312 return BinaryOp(Op);
5313
5314 case Instruction::LShr:
5315 // Turn a logical shift right by a constant into an unsigned divide.
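 // For example, "lshr i32 %x, 3" is modeled as "udiv i32 %x, 8".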
5316 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5317 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5318
5319 // If the shift count is not less than the bitwidth, the result of
5320 // the shift is undefined. Don't try to analyze it, because the
5321 // resolution chosen here may differ from the resolution chosen in
5322 // other parts of the compiler.
5323 if (SA->getValue().ult(BitWidth)) {
5324 Constant *X =
5325 ConstantInt::get(SA->getContext(),
5326 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5327 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5328 }
5329 }
5330 return BinaryOp(Op);
5331
5332 case Instruction::ExtractValue: {
5333 auto *EVI = cast<ExtractValueInst>(Op);
5334 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5335 break;
5336
5337 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5338 if (!WO)
5339 break;
5340
5341 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5342 bool Signed = WO->isSigned();
5343 // TODO: Should add nuw/nsw flags for mul as well.
5344 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5345 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5346
5347 // Now that we know that all uses of the arithmetic-result component of
5348 // CI are guarded by the overflow check, we can go ahead and pretend
5349 // that the arithmetic is non-overflowing.
5350 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5351 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5352 }
5353
5354 default:
5355 break;
5356 }
5357
5358 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5359 // semantics as a Sub, return a binary sub expression.
5360 if (auto *II = dyn_cast<IntrinsicInst>(V))
5361 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5362 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5363
5364 return std::nullopt;
5365}
5366
5367/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5368/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5369/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5370/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5371/// follows one of the following patterns:
5372/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5373/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5374/// If the SCEV expression of \p Op conforms with one of the expected patterns
5375/// we return the type of the truncation operation, and indicate whether the
5376/// truncated type should be treated as signed/unsigned by setting
5377/// \p Signed to true/false, respectively.
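/// For example, with an i64 %SymbolicPHI and
/// Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64),
/// this returns the i32 type and sets \p Signed to true.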
5378static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5379 bool &Signed, ScalarEvolution &SE) {
5380 // The case where Op == SymbolicPHI (that is, with no type conversions on
5381 // the way) is handled by the regular add recurrence creating logic and
5382 // would have already been triggered in createAddRecForPHI. Reaching it here
5383 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5384 // because one of the other operands of the SCEVAddExpr updating this PHI is
5385 // not invariant).
5386 //
5387 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5388 // this case predicates that allow us to prove that Op == SymbolicPHI will
5389 // be added.
5390 if (Op == SymbolicPHI)
5391 return nullptr;
5392
5393 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5394 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5395 if (SourceBits != NewBits)
5396 return nullptr;
5397
5398 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5399 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5400 if (!SExt && !ZExt)
5401 return nullptr;
5402 const SCEVTruncateExpr *Trunc =
5403 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5404 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5405 if (!Trunc)
5406 return nullptr;
5407 const SCEV *X = Trunc->getOperand();
5408 if (X != SymbolicPHI)
5409 return nullptr;
5410 Signed = SExt != nullptr;
5411 return Trunc->getType();
5412}
5413
5414static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5415 if (!PN->getType()->isIntegerTy())
5416 return nullptr;
5417 const Loop *L = LI.getLoopFor(PN->getParent());
5418 if (!L || L->getHeader() != PN->getParent())
5419 return nullptr;
5420 return L;
5421}
5422
5423// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5424// computation that updates the phi follows the following pattern:
5425// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5426// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5427// If so, try to see if it can be rewritten as an AddRecExpr under some
5428// Predicates. If successful, return them as a pair. Also cache the results
5429// of the analysis.
5430//
5431// Example usage scenario:
5432// Say the Rewriter is called for the following SCEV:
5433// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5434// where:
5435// %X = phi i64 (%Start, %BEValue)
5436// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5437// and call this function with %SymbolicPHI = %X.
5438//
5439// The analysis will find that the value coming around the backedge has
5440// the following SCEV:
5441// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5442// Upon concluding that this matches the desired pattern, the function
5443// will return the pair {NewAddRec, SmallPredsVec} where:
5444// NewAddRec = {%Start,+,%Step}
5445// SmallPredsVec = {P1, P2, P3} as follows:
5446// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5447// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5448// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5449// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5450// under the predicates {P1,P2,P3}.
5451// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5452// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5453//
5454// TODO's:
5455//
5456// 1) Extend the Induction descriptor to also support inductions that involve
5457// casts: When needed (namely, when we are called in the context of the
5458// vectorizer induction analysis), a Set of cast instructions will be
5459// populated by this method, and provided back to isInductionPHI. This is
5460// needed to allow the vectorizer to properly record them to be ignored by
5461// the cost model and to avoid vectorizing them (otherwise these casts,
5462// which are redundant under the runtime overflow checks, will be
5463// vectorized, which can be costly).
5464//
5465// 2) Support additional induction/PHISCEV patterns: We also want to support
5466// inductions where the sext-trunc / zext-trunc operations (partly) occur
5467// after the induction update operation (the induction increment):
5468//
5469// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5470// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5471//
5472// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5473// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5474//
5475// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5476std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5477ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5478 SmallVector<const SCEVPredicate *, 3> Predicates;
5479
5480 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5481 // return an AddRec expression under some predicate.
5482
5483 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5484 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5485 assert(L && "Expecting an integer loop header phi");
5486
5487 // The loop may have multiple entrances or multiple exits; we can analyze
5488 // this phi as an addrec if it has a unique entry value and a unique
5489 // backedge value.
5490 Value *BEValueV = nullptr, *StartValueV = nullptr;
5491 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5492 Value *V = PN->getIncomingValue(i);
5493 if (L->contains(PN->getIncomingBlock(i))) {
5494 if (!BEValueV) {
5495 BEValueV = V;
5496 } else if (BEValueV != V) {
5497 BEValueV = nullptr;
5498 break;
5499 }
5500 } else if (!StartValueV) {
5501 StartValueV = V;
5502 } else if (StartValueV != V) {
5503 StartValueV = nullptr;
5504 break;
5505 }
5506 }
5507 if (!BEValueV || !StartValueV)
5508 return std::nullopt;
5509
5510 const SCEV *BEValue = getSCEV(BEValueV);
5511
5512 // If the value coming around the backedge is an add with the symbolic
5513 // value we just inserted, possibly with casts that we can ignore under
5514 // an appropriate runtime guard, then we found a simple induction variable!
5515 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5516 if (!Add)
5517 return std::nullopt;
5518
5519 // If there is a single occurrence of the symbolic value, possibly
5520 // casted, replace it with a recurrence.
5521 unsigned FoundIndex = Add->getNumOperands();
5522 Type *TruncTy = nullptr;
5523 bool Signed;
5524 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5525 if ((TruncTy =
5526 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5527 if (FoundIndex == e) {
5528 FoundIndex = i;
5529 break;
5530 }
5531
5532 if (FoundIndex == Add->getNumOperands())
5533 return std::nullopt;
5534
5535 // Create an add with everything but the specified operand.
5536 SmallVector<const SCEV *, 8> Ops;
5537 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5538 if (i != FoundIndex)
5539 Ops.push_back(Add->getOperand(i));
5540 const SCEV *Accum = getAddExpr(Ops);
5541
5542 // The runtime checks will not be valid if the step amount is
5543 // varying inside the loop.
5544 if (!isLoopInvariant(Accum, L))
5545 return std::nullopt;
5546
5547 // *** Part2: Create the predicates
5548
5549 // Analysis was successful: we have a phi-with-cast pattern for which we
5550 // can return an AddRec expression under the following predicates:
5551 //
5552 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5553 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5554 // P2: An Equal predicate that guarantees that
5555 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5556 // P3: An Equal predicate that guarantees that
5557 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5558 //
5559 // As we next prove, the above predicates guarantee that:
5560 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5561 //
5562 //
5563 // More formally, we want to prove that:
5564 // Expr(i+1) = Start + (i+1) * Accum
5565 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5566 //
5567 // Given that:
5568 // 1) Expr(0) = Start
5569 // 2) Expr(1) = Start + Accum
5570 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5571 // 3) Induction hypothesis (step i):
5572 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5573 //
5574 // Proof:
5575 // Expr(i+1) =
5576 // = Start + (i+1)*Accum
5577 // = (Start + i*Accum) + Accum
5578 // = Expr(i) + Accum
5579 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5580 // :: from step i
5581 //
5582 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5583 //
5584 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5585 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5586 // + Accum :: from P3
5587 //
5588 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5589 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5590 //
5591 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5592 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5593 //
5594 // By induction, the same applies to all iterations 1<=i<n:
5595 //
5596
5597 // Create a truncated addrec for which we will add a no overflow check (P1).
5598 const SCEV *StartVal = getSCEV(StartValueV);
5599 const SCEV *PHISCEV =
5600 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5601 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5602
5603 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5604 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5605 // will be constant.
5606 //
5607 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5608 // add P1.
5609 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5610 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5611 Signed ? SCEVWrapPredicate::IncrementNSSW
5612 : SCEVWrapPredicate::IncrementNUSW;
5613 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5614 Predicates.push_back(AddRecPred);
5615 }
5616
5617 // Create the Equal Predicates P2,P3:
5618
5619 // It is possible that the predicates P2 and/or P3 are computable at
5620 // compile time due to StartVal and/or Accum being constants.
5621 // If either one is, then we can check that now and escape if either P2
5622 // or P3 is false.
5623
5624 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5625 // for each of StartVal and Accum
5626 auto getExtendedExpr = [&](const SCEV *Expr,
5627 bool CreateSignExtend) -> const SCEV * {
5628 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5629 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5630 const SCEV *ExtendedExpr =
5631 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5632 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5633 return ExtendedExpr;
5634 };
5635
5636 // Given:
5637 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5638 // = getExtendedExpr(Expr)
5639 // Determine whether the predicate P: Expr == ExtendedExpr
5640 // is known to be false at compile time
5641 auto PredIsKnownFalse = [&](const SCEV *Expr,
5642 const SCEV *ExtendedExpr) -> bool {
5643 return Expr != ExtendedExpr &&
5644 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5645 };
5646
5647 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5648 if (PredIsKnownFalse(StartVal, StartExtended)) {
5649 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5650 return std::nullopt;
5651 }
5652
5653 // The Step is always Signed (because the overflow checks are either
5654 // NSSW or NUSW)
5655 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5656 if (PredIsKnownFalse(Accum, AccumExtended)) {
5657 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5658 return std::nullopt;
5659 }
5660
5661 auto AppendPredicate = [&](const SCEV *Expr,
5662 const SCEV *ExtendedExpr) -> void {
5663 if (Expr != ExtendedExpr &&
5664 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5665 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5666 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5667 Predicates.push_back(Pred);
5668 }
5669 };
5670
5671 AppendPredicate(StartVal, StartExtended);
5672 AppendPredicate(Accum, AccumExtended);
5673
5674 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5675 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5676 // into NewAR if it will also add the runtime overflow checks specified in
5677 // Predicates.
5678 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5679
5680 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5681 std::make_pair(NewAR, Predicates);
5682 // Remember the result of the analysis for this SCEV at this location.
5683 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5684 return PredRewrite;
5685}
5686
5687std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5688ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5689 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5690 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5691 if (!L)
5692 return std::nullopt;
5693
5694 // Check to see if we already analyzed this PHI.
5695 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5696 if (I != PredicatedSCEVRewrites.end()) {
5697 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5698 I->second;
5699 // Analysis was done before and failed to create an AddRec:
5700 if (Rewrite.first == SymbolicPHI)
5701 return std::nullopt;
5702 // Analysis was done before and succeeded to create an AddRec under
5703 // a predicate:
5704 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5705 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5706 return Rewrite;
5707 }
5708
5709 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5710 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5711
5712 // Record in the cache that the analysis failed
5713 if (!Rewrite) {
5714 SmallVector<const SCEVPredicate *, 3> Predicates;
5715 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5716 return std::nullopt;
5717 }
5718
5719 return Rewrite;
5720}
5721
5722// FIXME: This utility is currently required because the Rewriter currently
5723// does not rewrite this expression:
5724// {0, +, (sext ix (trunc iy to ix) to iy)}
5725// into {0, +, %step},
5726// even when the following Equal predicate exists:
5727// "%step == (sext ix (trunc iy to ix) to iy)".
5728bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5729 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5730 if (AR1 == AR2)
5731 return true;
5732
5733 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5734 if (Expr1 != Expr2 &&
5735 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5736 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5737 return false;
5738 return true;
5739 };
5740
5741 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5742 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5743 return false;
5744 return true;
5745}
5746
5747/// A helper function for createAddRecFromPHI to handle simple cases.
5748///
5749/// This function tries to find an AddRec expression for the simplest (yet most
5750/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5751/// If it fails, createAddRecFromPHI will use a more general, but slow,
5752/// technique for finding the AddRec expression.
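/// For example, for
///   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i64 %iv, %step   ; %step loop-invariant
/// this produces the AddRec {0,+,%step}<nuw><nsw> for loop %loop.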
5753const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5754 Value *BEValueV,
5755 Value *StartValueV) {
5756 const Loop *L = LI.getLoopFor(PN->getParent());
5757 assert(L && L->getHeader() == PN->getParent());
5758 assert(BEValueV && StartValueV);
5759
5760 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5761 if (!BO)
5762 return nullptr;
5763
5764 if (BO->Opcode != Instruction::Add)
5765 return nullptr;
5766
5767 const SCEV *Accum = nullptr;
5768 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5769 Accum = getSCEV(BO->RHS);
5770 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5771 Accum = getSCEV(BO->LHS);
5772
5773 if (!Accum)
5774 return nullptr;
5775
5776 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5777 if (BO->IsNUW)
5778 Flags = setFlags(Flags, SCEV::FlagNUW);
5779 if (BO->IsNSW)
5780 Flags = setFlags(Flags, SCEV::FlagNSW);
5781
5782 const SCEV *StartVal = getSCEV(StartValueV);
5783 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5784 insertValueToMap(PN, PHISCEV);
5785
5786 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5787 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5788 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5789 proveNoWrapViaConstantRanges(AR)));
5790 }
5791
5792 // We can add Flags to the post-inc expression only if we
5793 // know that it is *undefined behavior* for BEValueV to
5794 // overflow.
5795 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5796 assert(isLoopInvariant(Accum, L) &&
5797 "Accum is defined outside L, but is not invariant?");
5798 if (isAddRecNeverPoison(BEInst, L))
5799 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5800 }
5801
5802 return PHISCEV;
5803}
5804
5805const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5806 const Loop *L = LI.getLoopFor(PN->getParent());
5807 if (!L || L->getHeader() != PN->getParent())
5808 return nullptr;
5809
5810 // The loop may have multiple entrances or multiple exits; we can analyze
5811 // this phi as an addrec if it has a unique entry value and a unique
5812 // backedge value.
5813 Value *BEValueV = nullptr, *StartValueV = nullptr;
5814 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5815 Value *V = PN->getIncomingValue(i);
5816 if (L->contains(PN->getIncomingBlock(i))) {
5817 if (!BEValueV) {
5818 BEValueV = V;
5819 } else if (BEValueV != V) {
5820 BEValueV = nullptr;
5821 break;
5822 }
5823 } else if (!StartValueV) {
5824 StartValueV = V;
5825 } else if (StartValueV != V) {
5826 StartValueV = nullptr;
5827 break;
5828 }
5829 }
5830 if (!BEValueV || !StartValueV)
5831 return nullptr;
5832
5833 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5834 "PHI node already processed?");
5835
5836 // First, try to find AddRec expression without creating a fictitious symbolic
5837 // value for PN.
5838 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5839 return S;
5840
5841 // Handle PHI node value symbolically.
5842 const SCEV *SymbolicName = getUnknown(PN);
5843 insertValueToMap(PN, SymbolicName);
5844
5845 // Using this symbolic name for the PHI, analyze the value coming around
5846 // the back-edge.
5847 const SCEV *BEValue = getSCEV(BEValueV);
5848
5849 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5850 // has a special value for the first iteration of the loop.
5851
5852 // If the value coming around the backedge is an add with the symbolic
5853 // value we just inserted, then we found a simple induction variable!
5854 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5855 // If there is a single occurrence of the symbolic value, replace it
5856 // with a recurrence.
5857 unsigned FoundIndex = Add->getNumOperands();
5858 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5859 if (Add->getOperand(i) == SymbolicName)
5860 if (FoundIndex == e) {
5861 FoundIndex = i;
5862 break;
5863 }
5864
5865 if (FoundIndex != Add->getNumOperands()) {
5866 // Create an add with everything but the specified operand.
5867 SmallVector<const SCEV *, 8> Ops;
5868 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5869 if (i != FoundIndex)
5870 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5871 L, *this));
5872 const SCEV *Accum = getAddExpr(Ops);
5873
5874 // This is not a valid addrec if the step amount is varying each
5875 // loop iteration, but is not itself an addrec in this loop.
5876 if (isLoopInvariant(Accum, L) ||
5877 (isa<SCEVAddRecExpr>(Accum) &&
5878 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5879 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5880
5881 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5882 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5883 if (BO->IsNUW)
5884 Flags = setFlags(Flags, SCEV::FlagNUW);
5885 if (BO->IsNSW)
5886 Flags = setFlags(Flags, SCEV::FlagNSW);
5887 }
5888 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5889 if (GEP->getOperand(0) == PN) {
5890 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5891 // If the increment has any nowrap flags, then we know the address
5892 // space cannot be wrapped around.
5893 if (NW != GEPNoWrapFlags::none())
5894 Flags = setFlags(Flags, SCEV::FlagNW);
5895 // If the GEP is nuw or nusw with non-negative offset, we know that
5896 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5897 // offset is treated as signed, while the base is unsigned.
5898 if (NW.hasNoUnsignedWrap() ||
5899 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5900 Flags = setFlags(Flags, SCEV::FlagNUW);
5901 }
5902
5903 // We cannot transfer nuw and nsw flags from subtraction
5904 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5905 // for instance.
5906 }
5907
5908 const SCEV *StartVal = getSCEV(StartValueV);
5909 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5910
5911 // Okay, for the entire analysis of this edge we assumed the PHI
5912 // to be symbolic. We now need to go back and purge all of the
5913 // entries for the scalars that use the symbolic expression.
5914 forgetMemoizedResults(SymbolicName);
5915 insertValueToMap(PN, PHISCEV);
5916
5917 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5918 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5919 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5920 proveNoWrapViaConstantRanges(AR)));
5921 }
5922
5923 // We can add Flags to the post-inc expression only if we
5924 // know that it is *undefined behavior* for BEValueV to
5925 // overflow.
5926 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5927 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5928 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5929
5930 return PHISCEV;
5931 }
5932 }
5933 } else {
5934 // Otherwise, this could be a loop like this:
5935 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5936 // In this case, j = {1,+,1} and BEValue is j.
5937 // Because the other in-value of i (0) fits the evolution of BEValue
5938 // i really is an addrec evolution.
5939 //
5940 // We can generalize this saying that i is the shifted value of BEValue
5941 // by one iteration:
5942 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5943
5944 // Do not allow refinement in rewriting of BEValue.
5945 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5946 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5947 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5948 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5949 const SCEV *StartVal = getSCEV(StartValueV);
5950 if (Start == StartVal) {
5951 // Okay, for the entire analysis of this edge we assumed the PHI
5952 // to be symbolic. We now need to go back and purge all of the
5953 // entries for the scalars that use the symbolic expression.
5954 forgetMemoizedResults(SymbolicName);
5955 insertValueToMap(PN, Shifted);
5956 return Shifted;
5957 }
5958 }
5959 }
5960
5961 // Remove the temporary PHI node SCEV that has been inserted while intending
5962 // to create an AddRecExpr for this PHI node. We can not keep this temporary
5963 // as it will prevent later (possibly simpler) SCEV expressions to be added
5964 // to the ValueExprMap.
5965 eraseValueFromMap(PN);
5966
5967 return nullptr;
5968}
5969
5970// Try to match a control flow sequence that branches out at BI and merges back
5971// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5972// match.
5973static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5974 Value *&C, Value *&LHS, Value *&RHS) {
5975 C = BI->getCondition();
5976
5977 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5978 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5979
5980 if (!LeftEdge.isSingleEdge())
5981 return false;
5982
5983 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5984
5985 Use &LeftUse = Merge->getOperandUse(0);
5986 Use &RightUse = Merge->getOperandUse(1);
5987
5988 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5989 LHS = LeftUse;
5990 RHS = RightUse;
5991 return true;
5992 }
5993
5994 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5995 LHS = RightUse;
5996 RHS = LeftUse;
5997 return true;
5998 }
5999
6000 return false;
6001}
6002
6003const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6004 auto IsReachable =
6005 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6006 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6007 // Try to match
6008 //
6009 // br %cond, label %left, label %right
6010 // left:
6011 // br label %merge
6012 // right:
6013 // br label %merge
6014 // merge:
6015 // V = phi [ %x, %left ], [ %y, %right ]
6016 //
6017 // as "select %cond, %x, %y"
6018
6019 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6020 assert(IDom && "At least the entry block should dominate PN");
6021
6022 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6023 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6024
6025 if (BI && BI->isConditional() &&
6026 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6027 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6028 properlyDominates(getSCEV(RHS), PN->getParent()))
6029 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6030 }
6031
6032 return nullptr;
6033}
6034
6035/// Returns SCEV for the first operand of a phi if all phi operands have
6036/// identical opcodes and operands
6037/// eg.
6038/// a: %add = %a + %b
6039/// br %c
6040/// b: %add1 = %a + %b
6041/// br %c
6042/// c: %phi = phi [%add, a], [%add1, b]
6043/// scev(%phi) => scev(%add)
6044const SCEV *
6045ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6046 BinaryOperator *CommonInst = nullptr;
6047 // Check if instructions are identical.
6048 for (Value *Incoming : PN->incoming_values()) {
6049 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6050 if (!IncomingInst)
6051 return nullptr;
6052 if (CommonInst) {
6053 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6054 return nullptr; // Not identical, give up
6055 } else {
6056 // Remember binary operator
6057 CommonInst = IncomingInst;
6058 }
6059 }
6060 if (!CommonInst)
6061 return nullptr;
6062
6063 // Check if SCEV exprs for instructions are identical.
6064 const SCEV *CommonSCEV = getSCEV(CommonInst);
6065 bool SCEVExprsIdentical =
6066 all_of(drop_begin(PN->incoming_values()),
6067 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6068 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6069}
6070
6071const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6072 if (const SCEV *S = createAddRecFromPHI(PN))
6073 return S;
6074
6075 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6076 // phi node for X.
6077 if (Value *V = simplifyInstruction(
6078 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6079 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6080 return getSCEV(V);
6081
6082 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6083 return S;
6084
6085 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6086 return S;
6087
6088 // If it's not a loop phi, we can't handle it yet.
6089 return getUnknown(PN);
6090}
6091
6092bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6093 SCEVTypes RootKind) {
6094 struct FindClosure {
6095 const SCEV *OperandToFind;
6096 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6097 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6098
6099 bool Found = false;
6100
6101 bool canRecurseInto(SCEVTypes Kind) const {
6102 // We can only recurse into the SCEV expression of the same effective type
6103 // as the type of our root SCEV expression, and into zero-extensions.
6104 return RootKind == Kind || NonSequentialRootKind == Kind ||
6105 scZeroExtend == Kind;
6106 };
6107
6108 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6109 : OperandToFind(OperandToFind), RootKind(RootKind),
6110 NonSequentialRootKind(
6111 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6112 RootKind)) {}
6113
6114 bool follow(const SCEV *S) {
6115 Found = S == OperandToFind;
6116
6117 return !isDone() && canRecurseInto(S->getSCEVType());
6118 }
6119
6120 bool isDone() const { return Found; }
6121 };
6122
6123 FindClosure FC(OperandToFind, RootKind);
6124 visitAll(Root, FC);
6125 return FC.Found;
6126}
6127
6128std::optional<const SCEV *>
6129ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6130 ICmpInst *Cond,
6131 Value *TrueVal,
6132 Value *FalseVal) {
6133 // Try to match some simple smax or umax patterns.
6134 auto *ICI = Cond;
6135
6136 Value *LHS = ICI->getOperand(0);
6137 Value *RHS = ICI->getOperand(1);
6138
6139 switch (ICI->getPredicate()) {
6140 case ICmpInst::ICMP_SLT:
6141 case ICmpInst::ICMP_SLE:
6142 case ICmpInst::ICMP_ULT:
6143 case ICmpInst::ICMP_ULE:
6144 std::swap(LHS, RHS);
6145 [[fallthrough]];
6146 case ICmpInst::ICMP_SGT:
6147 case ICmpInst::ICMP_SGE:
6148 case ICmpInst::ICMP_UGT:
6149 case ICmpInst::ICMP_UGE:
6150 // a > b ? a+x : b+x -> max(a, b)+x
6151 // a > b ? b+x : a+x -> min(a, b)+x
6152 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6153 bool Signed = ICI->isSigned();
6154 const SCEV *LA = getSCEV(TrueVal);
6155 const SCEV *RA = getSCEV(FalseVal);
6156 const SCEV *LS = getSCEV(LHS);
6157 const SCEV *RS = getSCEV(RHS);
6158 if (LA->getType()->isPointerTy()) {
6159 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6160 // Need to make sure we can't produce weird expressions involving
6161 // negated pointers.
6162 if (LA == LS && RA == RS)
6163 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6164 if (LA == RS && RA == LS)
6165 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6166 }
6167 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6168 if (Op->getType()->isPointerTy()) {
6169 Op = getLosslessPtrToIntExpr(Op);
6170 if (isa<SCEVCouldNotCompute>(Op))
6171 return Op;
6172 }
6173 if (Signed)
6174 Op = getNoopOrSignExtend(Op, Ty);
6175 else
6176 Op = getNoopOrZeroExtend(Op, Ty);
6177 return Op;
6178 };
6179 LS = CoerceOperand(LS);
6180 RS = CoerceOperand(RS);
6181 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6182 break;
6183 const SCEV *LDiff = getMinusSCEV(LA, LS);
6184 const SCEV *RDiff = getMinusSCEV(RA, RS);
6185 if (LDiff == RDiff)
6186 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6187 LDiff);
6188 LDiff = getMinusSCEV(LA, RS);
6189 RDiff = getMinusSCEV(RA, LS);
6190 if (LDiff == RDiff)
6191 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6192 LDiff);
6193 }
6194 break;
6195 case ICmpInst::ICMP_NE:
6196 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6197 std::swap(TrueVal, FalseVal);
6198 [[fallthrough]];
6199 case ICmpInst::ICMP_EQ:
6200 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
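 // (If x == 0, umax(x, C) == C; if x != 0, then x u>= 1 u>= C, so
 // umax(x, C) == x. Either way the select equals umax(x, C)+y.)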
6201 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6202 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6203 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6204 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6205 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6206 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6207 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6208 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6209 return getAddExpr(getUMaxExpr(X, C), Y);
6210 }
6211 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6212 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6213 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6214 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6215 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6216 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6217 const SCEV *X = getSCEV(LHS);
6218 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6219 X = ZExt->getOperand();
6220 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6221 const SCEV *FalseValExpr = getSCEV(FalseVal);
6222 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6223 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6224 /*Sequential=*/true);
6225 }
6226 }
6227 break;
6228 default:
6229 break;
6230 }
6231
6232 return std::nullopt;
6233}
6234
6235static std::optional<const SCEV *>
6236createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6237 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6238 assert(CondExpr->getType()->isIntegerTy(1) &&
6239 TrueExpr->getType() == FalseExpr->getType() &&
6240 TrueExpr->getType()->isIntegerTy(1) &&
6241 "Unexpected operands of a select.");
6242
6243 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6244 // --> C + (umin_seq cond, x - C)
6245 //
6246 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6247 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6248 // --> C + (umin_seq ~cond, x - C)
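 // (Checking the first identity in i1 arithmetic: if cond is 0, the umin_seq
 // is 0 and the result is C; if cond is 1, it is x - C, and C + (x - C) == x.)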
6249
6250 // FIXME: while we can't legally model the case where both of the hands
6251 // are fully variable, we only require that the *difference* is constant.
6252 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6253 return std::nullopt;
6254
6255 const SCEV *X, *C;
6256 if (isa<SCEVConstant>(TrueExpr)) {
6257 CondExpr = SE->getNotSCEV(CondExpr);
6258 X = FalseExpr;
6259 C = TrueExpr;
6260 } else {
6261 X = TrueExpr;
6262 C = FalseExpr;
6263 }
6264 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6265 /*Sequential=*/true));
6266}
6267
6268static std::optional<const SCEV *>
6269createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6270 Value *FalseVal) {
6271 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6272 return std::nullopt;
6273
6274 const auto *SECond = SE->getSCEV(Cond);
6275 const auto *SETrue = SE->getSCEV(TrueVal);
6276 const auto *SEFalse = SE->getSCEV(FalseVal);
6277 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6278}
6279
6280const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6281 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6282 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6283 assert(TrueVal->getType() == FalseVal->getType() &&
6284 V->getType() == TrueVal->getType() &&
6285 "Types of select hands and of the result must match.");
6286
6287 // For now, only deal with i1-typed `select`s.
6288 if (!V->getType()->isIntegerTy(1))
6289 return getUnknown(V);
6290
6291 if (std::optional<const SCEV *> S =
6292 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6293 return *S;
6294
6295 return getUnknown(V);
6296}
6297
6298const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6299 Value *TrueVal,
6300 Value *FalseVal) {
6301 // Handle "constant" branch or select. This can occur for instance when a
6302 // loop pass transforms an inner loop and moves on to process the outer loop.
6303 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6304 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6305
6306 if (auto *I = dyn_cast<Instruction>(V)) {
6307 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6308 if (std::optional<const SCEV *> S =
6309 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6310 TrueVal, FalseVal))
6311 return *S;
6312 }
6313 }
6314
6315 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6316}
6317
6318/// Expand GEP instructions into add and multiply operations. This allows them
6319/// to be analyzed by regular SCEV code.
6320const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6321 assert(GEP->getSourceElementType()->isSized() &&
6322 "GEP source element type must be sized");
6323
6324  SmallVector<const SCEV *, 4> IndexExprs;
6325  for (Value *Index : GEP->indices())
6326 IndexExprs.push_back(getSCEV(Index));
6327 return getGEPExpr(GEP, IndexExprs);
6328}
6329
6330APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6331  uint32_t BitWidth = getTypeSizeInBits(S->getType());
6332  auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6333 return TrailingZeros >= BitWidth
6334               ? APInt::getZero(BitWidth)
6335               : APInt::getOneBitSet(BitWidth, TrailingZeros);
6336 };
6337 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6338 // The result is GCD of all operands results.
6339 APInt Res = getConstantMultiple(N->getOperand(0));
6340 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6341      Res = APIntOps::GreatestCommonDivisor(
6342          Res, getConstantMultiple(N->getOperand(I)));
6343 return Res;
6344 };
6345
6346 switch (S->getSCEVType()) {
6347 case scConstant:
6348 return cast<SCEVConstant>(S)->getAPInt();
6349 case scPtrToInt:
6350 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6351 case scUDivExpr:
6352 case scVScale:
6353 return APInt(BitWidth, 1);
6354 case scTruncate: {
6355 // Only multiples that are a power of 2 will hold after truncation.
6356 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6357 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6358 return GetShiftedByZeros(TZ);
6359 }
6360 case scZeroExtend: {
6361 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6362 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6363 }
6364 case scSignExtend: {
6365 // Only multiples that are a power of 2 will hold after sext.
6366 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6367    uint32_t TZ = getMinTrailingZeros(E->getOperand());
6368    return GetShiftedByZeros(TZ);
6369 }
6370 case scMulExpr: {
6371 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6372 if (M->hasNoUnsignedWrap()) {
6373 // The result is the product of all operand results.
6374 APInt Res = getConstantMultiple(M->getOperand(0));
6375 for (const SCEV *Operand : M->operands().drop_front())
6376 Res = Res * getConstantMultiple(Operand);
6377 return Res;
6378 }
6379
6380    // If there are no wrap guarantees, find the trailing zeros, which is the
6381    // sum of the trailing zeros of all its operands.
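    // Added worked example (illustrative, not from the source): if one operand
    // is a multiple of 4 (2 trailing zeros) and another is a multiple of 8
    // (3 trailing zeros), their product has at least 2 + 3 = 5 trailing zeros,
    // i.e. it is a multiple of 32, even if the multiplication wraps.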
6382 uint32_t TZ = 0;
6383 for (const SCEV *Operand : M->operands())
6384 TZ += getMinTrailingZeros(Operand);
6385 return GetShiftedByZeros(TZ);
6386 }
6387 case scAddExpr:
6388 case scAddRecExpr: {
6389 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6390 if (N->hasNoUnsignedWrap())
6391 return GetGCDMultiple(N);
6392    // Find the trailing zeros, which is the minimum over all operands.
6393 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6394 for (const SCEV *Operand : N->operands().drop_front())
6395 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6396 return GetShiftedByZeros(TZ);
6397 }
6398 case scUMaxExpr:
6399 case scSMaxExpr:
6400 case scUMinExpr:
6401 case scSMinExpr:
6402  case scSequentialUMinExpr:
6403    return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6404 case scUnknown: {
6405 // ask ValueTracking for known bits
6406 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6407 unsigned Known =
6408 computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
6409 .countMinTrailingZeros();
6410 return GetShiftedByZeros(Known);
6411 }
6412 case scCouldNotCompute:
6413 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6414 }
6415 llvm_unreachable("Unknown SCEV kind!");
6416}
6417
6418APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6419  auto I = ConstantMultipleCache.find(S);
6420 if (I != ConstantMultipleCache.end())
6421 return I->second;
6422
6423 APInt Result = getConstantMultipleImpl(S);
6424 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6425 assert(InsertPair.second && "Should insert a new key");
6426 return InsertPair.first->second;
6427}
6428
6429APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6430  APInt Multiple = getConstantMultiple(S);
6431 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6432}
6433
6434uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6435  return std::min(getConstantMultiple(S).countTrailingZeros(),
6436 (unsigned)getTypeSizeInBits(S->getType()));
6437}
6438
6439/// Helper method to assign a range to V from metadata present in the IR.
6440static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6441 if (Instruction *I = dyn_cast<Instruction>(V)) {
6442 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6443 return getConstantRangeFromMetadata(*MD);
6444 if (const auto *CB = dyn_cast<CallBase>(V))
6445 if (std::optional<ConstantRange> Range = CB->getRange())
6446 return Range;
6447 }
6448 if (auto *A = dyn_cast<Argument>(V))
6449 if (std::optional<ConstantRange> Range = A->getRange())
6450 return Range;
6451
6452 return std::nullopt;
6453}
6454
6455void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6456                                     SCEV::NoWrapFlags Flags) {
6457 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6458 AddRec->setNoWrapFlags(Flags);
6459 UnsignedRanges.erase(AddRec);
6460 SignedRanges.erase(AddRec);
6461 ConstantMultipleCache.erase(AddRec);
6462 }
6463}
6464
6465ConstantRange ScalarEvolution::
6466getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6467 const DataLayout &DL = getDataLayout();
6468
6469 unsigned BitWidth = getTypeSizeInBits(U->getType());
6470 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6471
6472  // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6473  // use information about the trip count to improve our available range. Note
6474  // that the trip-count-independent cases are already handled by known bits.
6475  // WARNING: The definition of recurrence used here is subtly different from
6476  // the one used by AddRec (and thus most of this file). Step is allowed to
6477  // be arbitrarily loop-varying here, whereas AddRec allows only loop-invariant
6478  // steps and other addrecs in the same loop (for non-affine addrecs). The code
6479  // below intentionally handles the case where the step is not loop invariant.
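  // Illustrative sketch of the kind of IR matched below (editorial example,
  // not from the original source):
  //   %iv      = phi i8 [ %start, %preheader ], [ %iv.next, %loop ]
  //   %iv.next = lshr i8 %iv, %step   ; %step may itself vary per iteration
  // With a constant maximum trip count, the total number of bits shifted is
  // bounded, which in turn bounds the range of %iv.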
6480 auto *P = dyn_cast<PHINode>(U->getValue());
6481 if (!P)
6482 return FullSet;
6483
6484 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6485 // even the values that are not available in these blocks may come from them,
6486  // and this would lead to a false-positive recurrence test.
6487 for (auto *Pred : predecessors(P->getParent()))
6488 if (!DT.isReachableFromEntry(Pred))
6489 return FullSet;
6490
6491 BinaryOperator *BO;
6492 Value *Start, *Step;
6493 if (!matchSimpleRecurrence(P, BO, Start, Step))
6494 return FullSet;
6495
6496 // If we found a recurrence in reachable code, we must be in a loop. Note
6497 // that BO might be in some subloop of L, and that's completely okay.
6498 auto *L = LI.getLoopFor(P->getParent());
6499 assert(L && L->getHeader() == P->getParent());
6500 if (!L->contains(BO->getParent()))
6501    // NOTE: This bailout should be an assert instead. However, asserting
6502    // the condition here exposes a case where LoopFusion queries SCEV
6503    // with malformed loop information in the midst of the transform.
6504    // There doesn't appear to be an obvious fix, so for the moment bail out
6505    // until the caller issue can be fixed. PR49566 tracks the bug.
6506 return FullSet;
6507
6508  // TODO: Extend to other opcodes such as mul and div.
6509 switch (BO->getOpcode()) {
6510 default:
6511 return FullSet;
6512 case Instruction::AShr:
6513 case Instruction::LShr:
6514 case Instruction::Shl:
6515 break;
6516 };
6517
6518 if (BO->getOperand(0) != P)
6519 // TODO: Handle the power function forms some day.
6520 return FullSet;
6521
6522 unsigned TC = getSmallConstantMaxTripCount(L);
6523 if (!TC || TC >= BitWidth)
6524 return FullSet;
6525
6526 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6527 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6528 assert(KnownStart.getBitWidth() == BitWidth &&
6529 KnownStep.getBitWidth() == BitWidth);
6530
6531 // Compute total shift amount, being careful of overflow and bitwidths.
6532 auto MaxShiftAmt = KnownStep.getMaxValue();
6533 APInt TCAP(BitWidth, TC-1);
6534 bool Overflow = false;
6535 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6536 if (Overflow)
6537 return FullSet;
6538
6539 switch (BO->getOpcode()) {
6540 default:
6541 llvm_unreachable("filtered out above");
6542 case Instruction::AShr: {
6543 // For each ashr, three cases:
6544 // shift = 0 => unchanged value
6545 // saturation => 0 or -1
6546 // other => a value closer to zero (of the same sign)
6547 // Thus, the end value is closer to zero than the start.
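    // Added example (illustrative, not from the source): roughly speaking, for
    // an i8 start known to lie in [16, 63] and a total arithmetic shift of 2,
    // the final value lies in [4, 63]: the lower bound comes from the shifted
    // minimum, the upper bound from the unshifted maximum, matching the range
    // constructed below.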
6548 auto KnownEnd = KnownBits::ashr(KnownStart,
6549 KnownBits::makeConstant(TotalShift));
6550 if (KnownStart.isNonNegative())
6551 // Analogous to lshr (simply not yet canonicalized)
6552 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6553 KnownStart.getMaxValue() + 1);
6554 if (KnownStart.isNegative())
6555 // End >=u Start && End <=s Start
6556 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6557 KnownEnd.getMaxValue() + 1);
6558 break;
6559 }
6560 case Instruction::LShr: {
6561 // For each lshr, three cases:
6562 // shift = 0 => unchanged value
6563 // saturation => 0
6564 // other => a smaller positive number
6565 // Thus, the low end of the unsigned range is the last value produced.
6566 auto KnownEnd = KnownBits::lshr(KnownStart,
6567 KnownBits::makeConstant(TotalShift));
6568 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6569 KnownStart.getMaxValue() + 1);
6570 }
6571 case Instruction::Shl: {
6572 // Iff no bits are shifted out, value increases on every shift.
6573 auto KnownEnd = KnownBits::shl(KnownStart,
6574 KnownBits::makeConstant(TotalShift));
6575 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6576 return ConstantRange(KnownStart.getMinValue(),
6577 KnownEnd.getMaxValue() + 1);
6578 break;
6579 }
6580 };
6581 return FullSet;
6582}
6583
6584const ConstantRange &
6585ScalarEvolution::getRangeRefIter(const SCEV *S,
6586 ScalarEvolution::RangeSignHint SignHint) {
6587  DenseMap<const SCEV *, ConstantRange> &Cache =
6588      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6589                                                       : SignedRanges;
6590  SmallVector<const SCEV *> WorkList;
6591  SmallPtrSet<const SCEV *, 8> Seen;
6592
6593 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6594 // SCEVUnknown PHI node.
6595 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6596 if (!Seen.insert(Expr).second)
6597 return;
6598 if (Cache.contains(Expr))
6599 return;
6600 switch (Expr->getSCEVType()) {
6601 case scUnknown:
6602 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6603 break;
6604 [[fallthrough]];
6605 case scConstant:
6606 case scVScale:
6607 case scTruncate:
6608 case scZeroExtend:
6609 case scSignExtend:
6610 case scPtrToInt:
6611 case scAddExpr:
6612 case scMulExpr:
6613 case scUDivExpr:
6614 case scAddRecExpr:
6615 case scUMaxExpr:
6616 case scSMaxExpr:
6617 case scUMinExpr:
6618 case scSMinExpr:
6619    case scSequentialUMinExpr:
6620      WorkList.push_back(Expr);
6621 break;
6622 case scCouldNotCompute:
6623 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6624 }
6625 };
6626 AddToWorklist(S);
6627
6628 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6629 for (unsigned I = 0; I != WorkList.size(); ++I) {
6630 const SCEV *P = WorkList[I];
6631 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6632 // If it is not a `SCEVUnknown`, just recurse into operands.
6633 if (!UnknownS) {
6634 for (const SCEV *Op : P->operands())
6635 AddToWorklist(Op);
6636 continue;
6637 }
6638 // `SCEVUnknown`'s require special treatment.
6639 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6640 if (!PendingPhiRangesIter.insert(P).second)
6641 continue;
6642 for (auto &Op : reverse(P->operands()))
6643 AddToWorklist(getSCEV(Op));
6644 }
6645 }
6646
6647 if (!WorkList.empty()) {
6648 // Use getRangeRef to compute ranges for items in the worklist in reverse
6649 // order. This will force ranges for earlier operands to be computed before
6650 // their users in most cases.
6651 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6652 getRangeRef(P, SignHint);
6653
6654 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6655 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6656 PendingPhiRangesIter.erase(P);
6657 }
6658 }
6659
6660 return getRangeRef(S, SignHint, 0);
6661}
6662
6663/// Determine the range for a particular SCEV. If SignHint is
6664/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6665/// with a "cleaner" unsigned (resp. signed) representation.
6666const ConstantRange &ScalarEvolution::getRangeRef(
6667 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6668  DenseMap<const SCEV *, ConstantRange> &Cache =
6669      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6670                                                       : SignedRanges;
6671  ConstantRange::PreferredRangeType RangeType =
6672      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6673                                                       : ConstantRange::Signed;
6674
6675 // See if we've computed this range already.
6676  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6677  if (I != Cache.end())
6678 return I->second;
6679
6680 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6681 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6682
6683 // Switch to iteratively computing the range for S, if it is part of a deeply
6684 // nested expression.
6685  if (Depth > RangeIterThreshold)
6686    return getRangeRefIter(S, SignHint);
6687
6688 unsigned BitWidth = getTypeSizeInBits(S->getType());
6689 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6690 using OBO = OverflowingBinaryOperator;
6691
6692  // If the value is known to be a multiple of a constant, the extreme values
6693  // of the range can be clamped to multiples of that constant as well.
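  // Added arithmetic example (editorial): for an i8 value known to be a
  // non-zero multiple of 4, the remainder 255 urem 4 is 3, so the range is
  // clamped to [0, 253), i.e. the largest attainable value is 252.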
6694 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6695 APInt Multiple = getNonZeroConstantMultiple(S);
6696 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6697 if (!Remainder.isZero())
6698 ConservativeResult =
6699            ConstantRange(APInt::getZero(BitWidth),
6700                          APInt::getMaxValue(BitWidth) - Remainder + 1);
6701 }
6702 else {
6703    uint32_t TZ = getMinTrailingZeros(S);
6704    if (TZ != 0) {
6705      ConservativeResult = ConstantRange(
6706          APInt::getSignedMinValue(BitWidth).ashr(TZ).shl(TZ),
6707          APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6708 }
6709 }
6710
6711 switch (S->getSCEVType()) {
6712 case scConstant:
6713 llvm_unreachable("Already handled above.");
6714 case scVScale:
6715 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6716 case scTruncate: {
6717 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6718 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6719 return setRange(
6720 Trunc, SignHint,
6721 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6722 }
6723 case scZeroExtend: {
6724 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6725 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6726 return setRange(
6727 ZExt, SignHint,
6728 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6729 }
6730 case scSignExtend: {
6731 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6732 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6733 return setRange(
6734 SExt, SignHint,
6735 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6736 }
6737 case scPtrToInt: {
6738 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6739 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6740 return setRange(PtrToInt, SignHint, X);
6741 }
6742 case scAddExpr: {
6743 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6744 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6745 unsigned WrapType = OBO::AnyWrap;
6746 if (Add->hasNoSignedWrap())
6747 WrapType |= OBO::NoSignedWrap;
6748 if (Add->hasNoUnsignedWrap())
6749 WrapType |= OBO::NoUnsignedWrap;
6750 for (const SCEV *Op : drop_begin(Add->operands()))
6751 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6752 RangeType);
6753 return setRange(Add, SignHint,
6754 ConservativeResult.intersectWith(X, RangeType));
6755 }
6756 case scMulExpr: {
6757 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6758 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6759 for (const SCEV *Op : drop_begin(Mul->operands()))
6760 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6761 return setRange(Mul, SignHint,
6762 ConservativeResult.intersectWith(X, RangeType));
6763 }
6764 case scUDivExpr: {
6765 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6766 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6767 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6768 return setRange(UDiv, SignHint,
6769 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6770 }
6771 case scAddRecExpr: {
6772 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6773 // If there's no unsigned wrap, the value will never be less than its
6774 // initial value.
6775 if (AddRec->hasNoUnsignedWrap()) {
6776 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6777 if (!UnsignedMinValue.isZero())
6778 ConservativeResult = ConservativeResult.intersectWith(
6779 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6780 }
6781
6782    // If there's no signed wrap, and all the operands except the initial value
6783    // have the same sign or are zero, the value won't ever be:
6784    // 1: smaller than the initial value if the operands are non-negative,
6785    // 2: bigger than the initial value if the operands are non-positive.
6786    // In both cases, the value cannot cross the signed min/max boundary.
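    // Added example (editorial): for {10,+,1}<nsw> all step operands are
    // non-negative, so the value never drops below its start (10) and never
    // wraps past the signed maximum; the intersections below encode exactly
    // that.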
6787 if (AddRec->hasNoSignedWrap()) {
6788 bool AllNonNeg = true;
6789 bool AllNonPos = true;
6790 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6791 if (!isKnownNonNegative(AddRec->getOperand(i)))
6792 AllNonNeg = false;
6793 if (!isKnownNonPositive(AddRec->getOperand(i)))
6794 AllNonPos = false;
6795 }
6796 if (AllNonNeg)
6797 ConservativeResult = ConservativeResult.intersectWith(
6798            ConstantRange(getSignedRangeMin(AddRec->getStart()),
6799                          APInt::getSignedMinValue(BitWidth)),
6800            RangeType);
6801 else if (AllNonPos)
6802 ConservativeResult = ConservativeResult.intersectWith(
6803            ConstantRange(APInt::getSignedMinValue(BitWidth),
6804                          getSignedRangeMax(AddRec->getStart()) +
6805 1),
6806 RangeType);
6807 }
6808
6809 // TODO: non-affine addrec
6810 if (AddRec->isAffine()) {
6811 const SCEV *MaxBEScev =
6812          getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6813      if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6814 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6815
6816 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6817        // the number of MaxBECount's active bits is <= AddRec's bit width.
6818 if (MaxBECount.getBitWidth() > BitWidth &&
6819 MaxBECount.getActiveBits() <= BitWidth)
6820 MaxBECount = MaxBECount.trunc(BitWidth);
6821 else if (MaxBECount.getBitWidth() < BitWidth)
6822 MaxBECount = MaxBECount.zext(BitWidth);
6823
6824 if (MaxBECount.getBitWidth() == BitWidth) {
6825 auto RangeFromAffine = getRangeForAffineAR(
6826 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6827 ConservativeResult =
6828 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6829
6830 auto RangeFromFactoring = getRangeViaFactoring(
6831 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6832 ConservativeResult =
6833 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6834 }
6835 }
6836
6837 // Now try symbolic BE count and more powerful methods.
6838      if (UseExpensiveRangeSharpening) {
6839        const SCEV *SymbolicMaxBECount =
6840            getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6841        if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6842 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6843 AddRec->hasNoSelfWrap()) {
6844 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6845 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6846 ConservativeResult =
6847 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6848 }
6849 }
6850 }
6851
6852 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6853 }
6854 case scUMaxExpr:
6855 case scSMaxExpr:
6856 case scUMinExpr:
6857 case scSMinExpr:
6858 case scSequentialUMinExpr: {
6859    Intrinsic::ID ID;
6860    switch (S->getSCEVType()) {
6861 case scUMaxExpr:
6862 ID = Intrinsic::umax;
6863 break;
6864 case scSMaxExpr:
6865 ID = Intrinsic::smax;
6866 break;
6867 case scUMinExpr:
6868    case scSequentialUMinExpr:
6869      ID = Intrinsic::umin;
6870 break;
6871 case scSMinExpr:
6872 ID = Intrinsic::smin;
6873 break;
6874 default:
6875 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6876 }
6877
6878 const auto *NAry = cast<SCEVNAryExpr>(S);
6879 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6880 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6881 X = X.intrinsic(
6882 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6883 return setRange(S, SignHint,
6884 ConservativeResult.intersectWith(X, RangeType));
6885 }
6886 case scUnknown: {
6887 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6888 Value *V = U->getValue();
6889
6890 // Check if the IR explicitly contains !range metadata.
6891 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6892 if (MDRange)
6893 ConservativeResult =
6894 ConservativeResult.intersectWith(*MDRange, RangeType);
6895
6896 // Use facts about recurrences in the underlying IR. Note that add
6897 // recurrences are AddRecExprs and thus don't hit this path. This
6898 // primarily handles shift recurrences.
6899 auto CR = getRangeForUnknownRecurrence(U);
6900 ConservativeResult = ConservativeResult.intersectWith(CR);
6901
6902 // See if ValueTracking can give us a useful range.
6903 const DataLayout &DL = getDataLayout();
6904 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6905 if (Known.getBitWidth() != BitWidth)
6906 Known = Known.zextOrTrunc(BitWidth);
6907
6908 // ValueTracking may be able to compute a tighter result for the number of
6909 // sign bits than for the value of those sign bits.
6910 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6911 if (U->getType()->isPointerTy()) {
6912      // If the pointer size is larger than the index size, this can cause
6913 // NS to be larger than BitWidth. So compensate for this.
6914 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6915 int ptrIdxDiff = ptrSize - BitWidth;
6916 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6917 NS -= ptrIdxDiff;
6918 }
6919
6920 if (NS > 1) {
6921 // If we know any of the sign bits, we know all of the sign bits.
6922 if (!Known.Zero.getHiBits(NS).isZero())
6923 Known.Zero.setHighBits(NS);
6924 if (!Known.One.getHiBits(NS).isZero())
6925 Known.One.setHighBits(NS);
6926 }
6927
6928 if (Known.getMinValue() != Known.getMaxValue() + 1)
6929 ConservativeResult = ConservativeResult.intersectWith(
6930 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6931 RangeType);
6932 if (NS > 1)
6933 ConservativeResult = ConservativeResult.intersectWith(
6935 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6936 RangeType);
6937
6938 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6939 // Strengthen the range if the underlying IR value is a
6940 // global/alloca/heap allocation using the size of the object.
6941 bool CanBeNull, CanBeFreed;
6942 uint64_t DerefBytes =
6943 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6944 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6945        // The highest address at which the object can start is DerefBytes bytes
6946        // before the end (the unsigned max value). If this value is not a
6947        // multiple of the alignment, the last possible start address is the
6948        // next lowest multiple of the alignment. Note: the computations below
6949        // cannot overflow, because if they did, there would be no possible start
6950        // address for the object.
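        // Added numeric sketch (editorial): if the object must be dereferenceable
        // for 64 bytes and is 16-byte aligned, its start can be at most
        // (unsigned max - 64) rounded down to a multiple of 16; if the pointer is
        // also known non-null, the minimum start is the alignment (16).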
6951 APInt MaxVal =
6952 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6953 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6954 uint64_t Rem = MaxVal.urem(Align);
6955 MaxVal -= APInt(BitWidth, Rem);
6956 APInt MinVal = APInt::getZero(BitWidth);
6957 if (llvm::isKnownNonZero(V, DL))
6958 MinVal = Align;
6959 ConservativeResult = ConservativeResult.intersectWith(
6960 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6961 }
6962 }
6963
6964    // The range of a phi is a subset of the union of the ranges of its inputs.
6965 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6966 // Make sure that we do not run over cycled Phis.
6967 if (PendingPhiRanges.insert(Phi).second) {
6968 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6969
6970 for (const auto &Op : Phi->operands()) {
6971 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6972 RangeFromOps = RangeFromOps.unionWith(OpRange);
6973          // No point in continuing if we already have a full set.
6974 if (RangeFromOps.isFullSet())
6975 break;
6976 }
6977 ConservativeResult =
6978 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6979 bool Erased = PendingPhiRanges.erase(Phi);
6980 assert(Erased && "Failed to erase Phi properly?");
6981 (void)Erased;
6982 }
6983 }
6984
6985 // vscale can't be equal to zero
6986 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6987 if (II->getIntrinsicID() == Intrinsic::vscale) {
6988        ConstantRange Disallowed = APInt::getZero(BitWidth);
6989        ConservativeResult = ConservativeResult.difference(Disallowed);
6990 }
6991
6992 return setRange(U, SignHint, std::move(ConservativeResult));
6993 }
6994 case scCouldNotCompute:
6995 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6996 }
6997
6998 return setRange(S, SignHint, std::move(ConservativeResult));
6999}
7000
7001// Given a StartRange, Step and MaxBECount for an expression compute a range of
7002// values that the expression can take. Initially, the expression has a value
7003// from StartRange and then is changed by Step up to MaxBECount times. Signed
7004// argument defines if we treat Step as signed or unsigned.
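// Added worked example (editorial, not from the source): with an i8 Step of 4,
// StartRange [10, 20) and MaxBECount 3, the value can grow by at most
// 4 * 3 = 12, so the resulting range is [10, 20 + 12) = [10, 32).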
7005static ConstantRange getRangeForAffineARHelper(APInt Step,
7006                                               const ConstantRange &StartRange,
7007 const APInt &MaxBECount,
7008 bool Signed) {
7009 unsigned BitWidth = Step.getBitWidth();
7010 assert(BitWidth == StartRange.getBitWidth() &&
7011 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7012 // If either Step or MaxBECount is 0, then the expression won't change, and we
7013 // just need to return the initial range.
7014 if (Step == 0 || MaxBECount == 0)
7015 return StartRange;
7016
7017 // If we don't know anything about the initial value (i.e. StartRange is
7018 // FullRange), then we don't know anything about the final range either.
7019 // Return FullRange.
7020 if (StartRange.isFullSet())
7021 return ConstantRange::getFull(BitWidth);
7022
7023 // If Step is signed and negative, then we use its absolute value, but we also
7024 // note that we're moving in the opposite direction.
7025 bool Descending = Signed && Step.isNegative();
7026
7027 if (Signed)
7028 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7029 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7030    // These equations hold due to the well-defined wrap-around behavior of
7031    // APInt.
7032 Step = Step.abs();
7033
7034  // Check whether the offset (Step * MaxBECount) exceeds the full span of
7035  // BitWidth. If it does, the expression is guaranteed to overflow.
7036 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7037 return ConstantRange::getFull(BitWidth);
7038
7039 // Offset is by how much the expression can change. Checks above guarantee no
7040 // overflow here.
7041 APInt Offset = Step * MaxBECount;
7042
7043 // Minimum value of the final range will match the minimal value of StartRange
7044 // if the expression is increasing and will be decreased by Offset otherwise.
7045 // Maximum value of the final range will match the maximal value of StartRange
7046 // if the expression is decreasing and will be increased by Offset otherwise.
7047 APInt StartLower = StartRange.getLower();
7048 APInt StartUpper = StartRange.getUpper() - 1;
7049 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7050 : (StartUpper + std::move(Offset));
7051
7052 // It's possible that the new minimum/maximum value will fall into the initial
7053 // range (due to wrap around). This means that the expression can take any
7054 // value in this bitwidth, and we have to return full range.
7055 if (StartRange.contains(MovedBoundary))
7056 return ConstantRange::getFull(BitWidth);
7057
7058 APInt NewLower =
7059 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7060 APInt NewUpper =
7061 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7062 NewUpper += 1;
7063
7064 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7065 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7066}
7067
7068ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7069 const SCEV *Step,
7070 const APInt &MaxBECount) {
7071 assert(getTypeSizeInBits(Start->getType()) ==
7072 getTypeSizeInBits(Step->getType()) &&
7073 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7074 "mismatched bit widths");
7075
7076 // First, consider step signed.
7077 ConstantRange StartSRange = getSignedRange(Start);
7078 ConstantRange StepSRange = getSignedRange(Step);
7079
7080 // If Step can be both positive and negative, we need to find ranges for the
7081 // maximum absolute step values in both directions and union them.
7082  ConstantRange SR = getRangeForAffineARHelper(
7083      StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7084  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7085                                              StartSRange, MaxBECount,
7086                                              /* Signed = */ true));
7087
7088  // Next, consider step unsigned.
7089  ConstantRange UR = getRangeForAffineARHelper(
7090      getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7091      /* Signed = */ false);
7092
7093  // Finally, intersect signed and unsigned ranges.
7094  return SR.intersectWith(UR, ConstantRange::Smallest);
7095}
7096
7097ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7098 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7099 ScalarEvolution::RangeSignHint SignHint) {
7100  assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7101 assert(AddRec->hasNoSelfWrap() &&
7102 "This only works for non-self-wrapping AddRecs!");
7103 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7104 const SCEV *Step = AddRec->getStepRecurrence(*this);
7105 // Only deal with constant step to save compile time.
7106 if (!isa<SCEVConstant>(Step))
7107 return ConstantRange::getFull(BitWidth);
7108 // Let's make sure that we can prove that we do not self-wrap during
7109 // MaxBECount iterations. We need this because MaxBECount is a maximum
7110 // iteration count estimate, and we might infer nw from some exit for which we
7111 // do not know max exit count (or any other side reasoning).
7112 // TODO: Turn into assert at some point.
7113 if (getTypeSizeInBits(MaxBECount->getType()) >
7114 getTypeSizeInBits(AddRec->getType()))
7115 return ConstantRange::getFull(BitWidth);
7116 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7117 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7118 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7119 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7120 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7121 MaxItersWithoutWrap))
7122 return ConstantRange::getFull(BitWidth);
7123
7124  ICmpInst::Predicate LEPred =
7125      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7126  ICmpInst::Predicate GEPred =
7127      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7128  const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7129
7130 // We know that there is no self-wrap. Let's take Start and End values and
7131 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7132 // the iteration. They either lie inside the range [Min(Start, End),
7133 // Max(Start, End)] or outside it:
7134 //
7135 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7136 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7137 //
7138 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7139 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7140 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7141 // Start <= End and step is positive, or Start >= End and step is negative.
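  // Added example (editorial): for {0,+,1}<nw> with Start = 0, a positive
  // step, and End = 100, proving Start <= End places every intermediate value
  // inside [0, 100] (Case 1), so the union of the start and end ranges
  // computed below is a sound bound.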
7142 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7143 ConstantRange StartRange = getRangeRef(Start, SignHint);
7144 ConstantRange EndRange = getRangeRef(End, SignHint);
7145 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7146 // If they already cover full iteration space, we will know nothing useful
7147 // even if we prove what we want to prove.
7148 if (RangeBetween.isFullSet())
7149 return RangeBetween;
7150 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7151 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7152 : RangeBetween.isWrappedSet();
7153 if (IsWrappedSet)
7154 return ConstantRange::getFull(BitWidth);
7155
7156 if (isKnownPositive(Step) &&
7157 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7158 return RangeBetween;
7159 if (isKnownNegative(Step) &&
7160 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7161 return RangeBetween;
7162 return ConstantRange::getFull(BitWidth);
7163}
7164
7165ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7166 const SCEV *Step,
7167 const APInt &MaxBECount) {
7168 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7169 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
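  // Added example (editorial): for {c ? 0 : 10,+,c ? 1 : 2}, with the same
  // condition c in both positions, the range is the union of the ranges of
  // {0,+,1} and {10,+,2}, which is what the SelectPattern logic below computes.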
7170
7171 unsigned BitWidth = MaxBECount.getBitWidth();
7172 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7173 getTypeSizeInBits(Step->getType()) == BitWidth &&
7174 "mismatched bit widths");
7175
7176 struct SelectPattern {
7177 Value *Condition = nullptr;
7178 APInt TrueValue;
7179 APInt FalseValue;
7180
7181 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7182 const SCEV *S) {
7183 std::optional<unsigned> CastOp;
7184 APInt Offset(BitWidth, 0);
7185
7186      assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7187             "Should be!");
7188
7189 // Peel off a constant offset. In the future we could consider being
7190 // smarter here and handle {Start+Step,+,Step} too.
7191 const APInt *Off;
7192 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7193 Offset = *Off;
7194
7195 // Peel off a cast operation
7196 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7197 CastOp = SCast->getSCEVType();
7198 S = SCast->getOperand();
7199 }
7200
7201 using namespace llvm::PatternMatch;
7202
7203 auto *SU = dyn_cast<SCEVUnknown>(S);
7204 const APInt *TrueVal, *FalseVal;
7205 if (!SU ||
7206 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7207 m_APInt(FalseVal)))) {
7208 Condition = nullptr;
7209 return;
7210 }
7211
7212 TrueValue = *TrueVal;
7213 FalseValue = *FalseVal;
7214
7215 // Re-apply the cast we peeled off earlier
7216 if (CastOp)
7217 switch (*CastOp) {
7218 default:
7219 llvm_unreachable("Unknown SCEV cast type!");
7220
7221 case scTruncate:
7222 TrueValue = TrueValue.trunc(BitWidth);
7223 FalseValue = FalseValue.trunc(BitWidth);
7224 break;
7225 case scZeroExtend:
7226 TrueValue = TrueValue.zext(BitWidth);
7227 FalseValue = FalseValue.zext(BitWidth);
7228 break;
7229 case scSignExtend:
7230 TrueValue = TrueValue.sext(BitWidth);
7231 FalseValue = FalseValue.sext(BitWidth);
7232 break;
7233 }
7234
7235 // Re-apply the constant offset we peeled off earlier
7236 TrueValue += Offset;
7237 FalseValue += Offset;
7238 }
7239
7240 bool isRecognized() { return Condition != nullptr; }
7241 };
7242
7243 SelectPattern StartPattern(*this, BitWidth, Start);
7244 if (!StartPattern.isRecognized())
7245 return ConstantRange::getFull(BitWidth);
7246
7247 SelectPattern StepPattern(*this, BitWidth, Step);
7248 if (!StepPattern.isRecognized())
7249 return ConstantRange::getFull(BitWidth);
7250
7251 if (StartPattern.Condition != StepPattern.Condition) {
7252 // We don't handle this case today; but we could, by considering four
7253 // possibilities below instead of two. I'm not sure if there are cases where
7254 // that will help over what getRange already does, though.
7255 return ConstantRange::getFull(BitWidth);
7256 }
7257
7258 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7259 // construct arbitrary general SCEV expressions here. This function is called
7260 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7261 // say) can end up caching a suboptimal value.
7262
7263 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7264 // C2352 and C2512 (otherwise it isn't needed).
7265
7266 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7267 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7268 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7269 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7270
7271 ConstantRange TrueRange =
7272 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7273 ConstantRange FalseRange =
7274 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7275
7276 return TrueRange.unionWith(FalseRange);
7277}
7278
7279SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7280 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7281 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7282
7283 // Return early if there are no flags to propagate to the SCEV.
7284  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7285  if (BinOp->hasNoUnsignedWrap())
7286    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7287  if (BinOp->hasNoSignedWrap())
7288    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7289 if (Flags == SCEV::FlagAnyWrap)
7290 return SCEV::FlagAnyWrap;
7291
7292 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7293}
7294
7295const Instruction *
7296ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7297 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7298 return &*AddRec->getLoop()->getHeader()->begin();
7299 if (auto *U = dyn_cast<SCEVUnknown>(S))
7300 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7301 return I;
7302 return nullptr;
7303}
7304
7305const Instruction *
7306ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7307 bool &Precise) {
7308 Precise = true;
7309 // Do a bounded search of the def relation of the requested SCEVs.
7310  SmallPtrSet<const SCEV *, 16> Visited;
7311  SmallVector<const SCEV *> Worklist;
7312  auto pushOp = [&](const SCEV *S) {
7313 if (!Visited.insert(S).second)
7314 return;
7315 // Threshold of 30 here is arbitrary.
7316 if (Visited.size() > 30) {
7317 Precise = false;
7318 return;
7319 }
7320 Worklist.push_back(S);
7321 };
7322
7323 for (const auto *S : Ops)
7324 pushOp(S);
7325
7326 const Instruction *Bound = nullptr;
7327 while (!Worklist.empty()) {
7328 auto *S = Worklist.pop_back_val();
7329 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7330 if (!Bound || DT.dominates(Bound, DefI))
7331 Bound = DefI;
7332 } else {
7333 for (const auto *Op : S->operands())
7334 pushOp(Op);
7335 }
7336 }
7337 return Bound ? Bound : &*F.getEntryBlock().begin();
7338}
7339
7340const Instruction *
7341ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7342 bool Discard;
7343 return getDefiningScopeBound(Ops, Discard);
7344}
7345
7346bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7347 const Instruction *B) {
7348 if (A->getParent() == B->getParent() &&
7349      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7350                                                 B->getIterator()))
7351 return true;
7352
7353 auto *BLoop = LI.getLoopFor(B->getParent());
7354 if (BLoop && BLoop->getHeader() == B->getParent() &&
7355 BLoop->getLoopPreheader() == A->getParent() &&
7356      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7357                                                 A->getParent()->end()) &&
7358 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7359 B->getIterator()))
7360 return true;
7361 return false;
7362}
7363
7364bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7365 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7366 visitAll(Op, PC);
7367 return PC.MaybePoison.empty();
7368}
7369
7370bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7371 return !SCEVExprContains(Op, [this](const SCEV *S) {
7372 const SCEV *Op1;
7373 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7374 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7375 // is a non-zero constant, we have to assume the UDiv may be UB.
7376 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7377 });
7378}
7379
7380bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7381 // Only proceed if we can prove that I does not yield poison.
7382  if (!programUndefinedIfPoison(I))
7383    return false;
7384
7385 // At this point we know that if I is executed, then it does not wrap
7386 // according to at least one of NSW or NUW. If I is not executed, then we do
7387 // not know if the calculation that I represents would wrap. Multiple
7388 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7389 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7390 // derived from other instructions that map to the same SCEV. We cannot make
7391  // that guarantee for cases where I is not executed. So we need to find an
7392  // upper bound on the defining scope for the SCEV, and prove that I is
7393 // executed every time we enter that scope. When the bounding scope is a
7394 // loop (the common case), this is equivalent to proving I executes on every
7395 // iteration of that loop.
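  // Added illustration (editorial, hypothetical IR): %a = add nsw i32 %x, 1 in
  // one block and %b = add i32 %x, 1 in another both map to the SCEV (%x + 1);
  // nsw may only be attached to that SCEV if %a is executed whenever the
  // defining scope computed below is entered.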
7396  SmallVector<const SCEV *, 4> SCEVOps;
7397  for (const Use &Op : I->operands()) {
7398 // I could be an extractvalue from a call to an overflow intrinsic.
7399 // TODO: We can do better here in some cases.
7400 if (isSCEVable(Op->getType()))
7401 SCEVOps.push_back(getSCEV(Op));
7402 }
7403 auto *DefI = getDefiningScopeBound(SCEVOps);
7404 return isGuaranteedToTransferExecutionTo(DefI, I);
7405}
7406
7407bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7408 // If we know that \c I can never be poison period, then that's enough.
7409 if (isSCEVExprNeverPoison(I))
7410 return true;
7411
7412 // If the loop only has one exit, then we know that, if the loop is entered,
7413 // any instruction dominating that exit will be executed. If any such
7414 // instruction would result in UB, the addrec cannot be poison.
7415 //
7416 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7417 // also handles uses outside the loop header (they just need to dominate the
7418 // single exit).
7419
7420 auto *ExitingBB = L->getExitingBlock();
7421 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7422 return false;
7423
7424  SmallPtrSet<const Instruction *, 16> KnownPoison;
7425  SmallVector<const Instruction *, 8> Worklist;
7426
7427 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7428 // things that are known to be poison under that assumption go on the
7429 // Worklist.
7430 KnownPoison.insert(I);
7431 Worklist.push_back(I);
7432
7433 while (!Worklist.empty()) {
7434 const Instruction *Poison = Worklist.pop_back_val();
7435
7436 for (const Use &U : Poison->uses()) {
7437 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7438 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7439 DT.dominates(PoisonUser->getParent(), ExitingBB))
7440 return true;
7441
7442 if (propagatesPoison(U) && L->contains(PoisonUser))
7443 if (KnownPoison.insert(PoisonUser).second)
7444 Worklist.push_back(PoisonUser);
7445 }
7446 }
7447
7448 return false;
7449}
7450
7451ScalarEvolution::LoopProperties
7452ScalarEvolution::getLoopProperties(const Loop *L) {
7453 using LoopProperties = ScalarEvolution::LoopProperties;
7454
7455 auto Itr = LoopPropertiesCache.find(L);
7456 if (Itr == LoopPropertiesCache.end()) {
7457 auto HasSideEffects = [](Instruction *I) {
7458 if (auto *SI = dyn_cast<StoreInst>(I))
7459 return !SI->isSimple();
7460
7461 if (I->mayThrow())
7462 return true;
7463
7464      // Non-volatile memset / memcpy do not count as a side effect for forward
7465 // progress.
7466 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7467 return false;
7468
7469 return I->mayWriteToMemory();
7470 };
7471
7472 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7473 /*HasNoSideEffects*/ true};
7474
7475 for (auto *BB : L->getBlocks())
7476 for (auto &I : *BB) {
7477        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7478          LP.HasNoAbnormalExits = false;
7479 if (HasSideEffects(&I))
7480 LP.HasNoSideEffects = false;
7481 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7482 break; // We're already as pessimistic as we can get.
7483 }
7484
7485 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7486 assert(InsertPair.second && "We just checked!");
7487 Itr = InsertPair.first;
7488 }
7489
7490 return Itr->second;
7491}
7492
7493bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7494  // A mustprogress loop without side effects must be finite.
7495  // TODO: The check used here is very conservative. Only *specific*
7496  // side effects are well defined in infinite loops.
7497 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7498}
7499
7500const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7501 // Worklist item with a Value and a bool indicating whether all operands have
7502 // been visited already.
7505
7506 Stack.emplace_back(V, true);
7507 Stack.emplace_back(V, false);
7508 while (!Stack.empty()) {
7509 auto E = Stack.pop_back_val();
7510 Value *CurV = E.getPointer();
7511
7512 if (getExistingSCEV(CurV))
7513 continue;
7514
7515    SmallVector<Value *> Ops;
7516    const SCEV *CreatedSCEV = nullptr;
7517 // If all operands have been visited already, create the SCEV.
7518 if (E.getInt()) {
7519 CreatedSCEV = createSCEV(CurV);
7520 } else {
7521      // Otherwise get the operands we need to create SCEVs for before creating
7522 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7523 // just use it.
7524 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7525 }
7526
7527 if (CreatedSCEV) {
7528 insertValueToMap(CurV, CreatedSCEV);
7529 } else {
7530      // Queue CurV for SCEV creation, followed by its operands, which need to
7531      // be constructed first.
7532 Stack.emplace_back(CurV, true);
7533 for (Value *Op : Ops)
7534 Stack.emplace_back(Op, false);
7535 }
7536 }
7537
7538 return getExistingSCEV(V);
7539}
7540
7541const SCEV *
7542ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7543 if (!isSCEVable(V->getType()))
7544 return getUnknown(V);
7545
7546 if (Instruction *I = dyn_cast<Instruction>(V)) {
7547 // Don't attempt to analyze instructions in blocks that aren't
7548 // reachable. Such instructions don't matter, and they aren't required
7549 // to obey basic rules for definitions dominating uses which this
7550 // analysis depends on.
7551 if (!DT.isReachableFromEntry(I->getParent()))
7552 return getUnknown(PoisonValue::get(V->getType()));
7553 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7554 return getConstant(CI);
7555 else if (isa<GlobalAlias>(V))
7556 return getUnknown(V);
7557 else if (!isa<ConstantExpr>(V))
7558 return getUnknown(V);
7559
7560 Operator *U = cast<Operator>(V);
7561 if (auto BO =
7562 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7563 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7564 switch (BO->Opcode) {
7565 case Instruction::Add:
7566 case Instruction::Mul: {
7567 // For additions and multiplications, traverse add/mul chains for which we
7568 // can potentially create a single SCEV, to reduce the number of
7569 // get{Add,Mul}Expr calls.
7570 do {
7571 if (BO->Op) {
7572 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7573 Ops.push_back(BO->Op);
7574 break;
7575 }
7576 }
7577 Ops.push_back(BO->RHS);
7578 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7579 dyn_cast<Instruction>(V));
7580 if (!NewBO ||
7581 (BO->Opcode == Instruction::Add &&
7582 (NewBO->Opcode != Instruction::Add &&
7583 NewBO->Opcode != Instruction::Sub)) ||
7584 (BO->Opcode == Instruction::Mul &&
7585 NewBO->Opcode != Instruction::Mul)) {
7586 Ops.push_back(BO->LHS);
7587 break;
7588 }
7589 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7590 // requires a SCEV for the LHS.
7591 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7592 auto *I = dyn_cast<Instruction>(BO->Op);
7593 if (I && programUndefinedIfPoison(I)) {
7594 Ops.push_back(BO->LHS);
7595 break;
7596 }
7597 }
7598 BO = NewBO;
7599 } while (true);
7600 return nullptr;
7601 }
7602 case Instruction::Sub:
7603 case Instruction::UDiv:
7604 case Instruction::URem:
7605 break;
7606 case Instruction::AShr:
7607 case Instruction::Shl:
7608 case Instruction::Xor:
7609 if (!IsConstArg)
7610 return nullptr;
7611 break;
7612 case Instruction::And:
7613 case Instruction::Or:
7614 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7615 return nullptr;
7616 break;
7617 case Instruction::LShr:
7618 return getUnknown(V);
7619 default:
7620 llvm_unreachable("Unhandled binop");
7621 break;
7622 }
7623
7624 Ops.push_back(BO->LHS);
7625 Ops.push_back(BO->RHS);
7626 return nullptr;
7627 }
7628
7629 switch (U->getOpcode()) {
7630 case Instruction::Trunc:
7631 case Instruction::ZExt:
7632 case Instruction::SExt:
7633 case Instruction::PtrToInt:
7634 Ops.push_back(U->getOperand(0));
7635 return nullptr;
7636
7637 case Instruction::BitCast:
7638 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7639 Ops.push_back(U->getOperand(0));
7640 return nullptr;
7641 }
7642 return getUnknown(V);
7643
7644 case Instruction::SDiv:
7645 case Instruction::SRem:
7646 Ops.push_back(U->getOperand(0));
7647 Ops.push_back(U->getOperand(1));
7648 return nullptr;
7649
7650 case Instruction::GetElementPtr:
7651 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7652 "GEP source element type must be sized");
7653 llvm::append_range(Ops, U->operands());
7654 return nullptr;
7655
7656 case Instruction::IntToPtr:
7657 return getUnknown(V);
7658
7659 case Instruction::PHI:
7660  // Keep constructing SCEVs for phis recursively for now.
7661 return nullptr;
7662
7663 case Instruction::Select: {
7664 // Check if U is a select that can be simplified to a SCEVUnknown.
7665 auto CanSimplifyToUnknown = [this, U]() {
7666 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7667 return false;
7668
7669 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7670 if (!ICI)
7671 return false;
7672 Value *LHS = ICI->getOperand(0);
7673 Value *RHS = ICI->getOperand(1);
7674 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7675 ICI->getPredicate() == CmpInst::ICMP_NE) {
7676 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7677 return true;
7678 } else if (getTypeSizeInBits(LHS->getType()) >
7679 getTypeSizeInBits(U->getType()))
7680 return true;
7681 return false;
7682 };
7683 if (CanSimplifyToUnknown())
7684 return getUnknown(U);
7685
7686 llvm::append_range(Ops, U->operands());
7687 return nullptr;
7688 break;
7689 }
7690 case Instruction::Call:
7691 case Instruction::Invoke:
7692 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7693 Ops.push_back(RV);
7694 return nullptr;
7695 }
7696
7697 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7698 switch (II->getIntrinsicID()) {
7699 case Intrinsic::abs:
7700 Ops.push_back(II->getArgOperand(0));
7701 return nullptr;
7702 case Intrinsic::umax:
7703 case Intrinsic::umin:
7704 case Intrinsic::smax:
7705 case Intrinsic::smin:
7706 case Intrinsic::usub_sat:
7707 case Intrinsic::uadd_sat:
7708 Ops.push_back(II->getArgOperand(0));
7709 Ops.push_back(II->getArgOperand(1));
7710 return nullptr;
7711 case Intrinsic::start_loop_iterations:
7712 case Intrinsic::annotation:
7713 case Intrinsic::ptr_annotation:
7714 Ops.push_back(II->getArgOperand(0));
7715 return nullptr;
7716 default:
7717 break;
7718 }
7719 }
7720 break;
7721 }
7722
7723 return nullptr;
7724}
7725
7726const SCEV *ScalarEvolution::createSCEV(Value *V) {
7727 if (!isSCEVable(V->getType()))
7728 return getUnknown(V);
7729
7730 if (Instruction *I = dyn_cast<Instruction>(V)) {
7731 // Don't attempt to analyze instructions in blocks that aren't
7732 // reachable. Such instructions don't matter, and they aren't required
7733 // to obey basic rules for definitions dominating uses which this
7734 // analysis depends on.
7735 if (!DT.isReachableFromEntry(I->getParent()))
7736 return getUnknown(PoisonValue::get(V->getType()));
7737 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7738 return getConstant(CI);
7739 else if (isa<GlobalAlias>(V))
7740 return getUnknown(V);
7741 else if (!isa<ConstantExpr>(V))
7742 return getUnknown(V);
7743
7744 const SCEV *LHS;
7745 const SCEV *RHS;
7746
7747 Operator *U = cast<Operator>(V);
7748 if (auto BO =
7749 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7750 switch (BO->Opcode) {
7751 case Instruction::Add: {
7752 // The simple thing to do would be to just call getSCEV on both operands
7753 // and call getAddExpr with the result. However if we're looking at a
7754 // bunch of things all added together, this can be quite inefficient,
7755 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7756 // Instead, gather up all the operands and make a single getAddExpr call.
7757 // LLVM IR canonical form means we need only traverse the left operands.
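      // Added example (editorial): for (((a + b) + c) + d) the loop below
      // collects the operands {a, b, c, d} and issues one getAddExpr call
      // instead of three nested ones.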
7758    SmallVector<const SCEV *, 4> AddOps;
7759    do {
7760 if (BO->Op) {
7761 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7762 AddOps.push_back(OpSCEV);
7763 break;
7764 }
7765
7766 // If a NUW or NSW flag can be applied to the SCEV for this
7767 // addition, then compute the SCEV for this addition by itself
7768 // with a separate call to getAddExpr. We need to do that
7769 // instead of pushing the operands of the addition onto AddOps,
7770 // since the flags are only known to apply to this particular
7771 // addition - they may not apply to other additions that can be
7772 // formed with operands from AddOps.
7773 const SCEV *RHS = getSCEV(BO->RHS);
7774 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7775 if (Flags != SCEV::FlagAnyWrap) {
7776 const SCEV *LHS = getSCEV(BO->LHS);
7777 if (BO->Opcode == Instruction::Sub)
7778 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7779 else
7780 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7781 break;
7782 }
7783 }
7784
7785 if (BO->Opcode == Instruction::Sub)
7786 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7787 else
7788 AddOps.push_back(getSCEV(BO->RHS));
7789
7790 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7791 dyn_cast<Instruction>(V));
7792 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7793 NewBO->Opcode != Instruction::Sub)) {
7794 AddOps.push_back(getSCEV(BO->LHS));
7795 break;
7796 }
7797 BO = NewBO;
7798 } while (true);
7799
7800 return getAddExpr(AddOps);
7801 }
7802
7803 case Instruction::Mul: {
7804    SmallVector<const SCEV *, 4> MulOps;
7805    do {
7806 if (BO->Op) {
7807 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7808 MulOps.push_back(OpSCEV);
7809 break;
7810 }
7811
7812 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7813 if (Flags != SCEV::FlagAnyWrap) {
7814 LHS = getSCEV(BO->LHS);
7815 RHS = getSCEV(BO->RHS);
7816 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7817 break;
7818 }
7819 }
7820
7821 MulOps.push_back(getSCEV(BO->RHS));
7822 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7823 dyn_cast<Instruction>(V));
7824 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7825 MulOps.push_back(getSCEV(BO->LHS));
7826 break;
7827 }
7828 BO = NewBO;
7829 } while (true);
7830
7831 return getMulExpr(MulOps);
7832 }
7833 case Instruction::UDiv:
7834 LHS = getSCEV(BO->LHS);
7835 RHS = getSCEV(BO->RHS);
7836 return getUDivExpr(LHS, RHS);
7837 case Instruction::URem:
7838 LHS = getSCEV(BO->LHS);
7839 RHS = getSCEV(BO->RHS);
7840 return getURemExpr(LHS, RHS);
7841 case Instruction::Sub: {
7842    SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7843    if (BO->Op)
7844 Flags = getNoWrapFlagsFromUB(BO->Op);
7845 LHS = getSCEV(BO->LHS);
7846 RHS = getSCEV(BO->RHS);
7847 return getMinusSCEV(LHS, RHS, Flags);
7848 }
7849 case Instruction::And:
7850 // For an expression like x&255 that merely masks off the high bits,
7851 // use zext(trunc(x)) as the SCEV expression.
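    // Added example (editorial): for i32 %x, the expression %x & 255 becomes
    // (zext i8 (trunc %x to i8) to i32), which regular SCEV folding understands.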
7852 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7853 if (CI->isZero())
7854 return getSCEV(BO->RHS);
7855 if (CI->isMinusOne())
7856 return getSCEV(BO->LHS);
7857 const APInt &A = CI->getValue();
7858
7859 // Instcombine's ShrinkDemandedConstant may strip bits out of
7860 // constants, obscuring what would otherwise be a low-bits mask.
7861 // Use computeKnownBits to compute what ShrinkDemandedConstant
7862 // knew about to reconstruct a low-bits mask value.
7863 unsigned LZ = A.countl_zero();
7864 unsigned TZ = A.countr_zero();
7865 unsigned BitWidth = A.getBitWidth();
7866 KnownBits Known(BitWidth);
7867 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7868
7869 APInt EffectiveMask =
7870 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7871 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7872 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7873 const SCEV *LHS = getSCEV(BO->LHS);
7874 const SCEV *ShiftedLHS = nullptr;
7875 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7876 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7877 // For an expression like (x * 8) & 8, simplify the multiply.
7878 unsigned MulZeros = OpC->getAPInt().countr_zero();
7879 unsigned GCD = std::min(MulZeros, TZ);
7880 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7881            SmallVector<const SCEV *, 4> MulOps;
7882            MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7883 append_range(MulOps, LHSMul->operands().drop_front());
7884 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7885 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7886 }
7887 }
7888 if (!ShiftedLHS)
7889 ShiftedLHS = getUDivExpr(LHS, MulCount);
7890 return getMulExpr(
7891          getZeroExtendExpr(
7892              getTruncateExpr(ShiftedLHS,
7893 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7894 BO->LHS->getType()),
7895 MulCount);
7896 }
7897 }
7898 // Binary `and` is a bit-wise `umin`.
7899 if (BO->LHS->getType()->isIntegerTy(1)) {
7900 LHS = getSCEV(BO->LHS);
7901 RHS = getSCEV(BO->RHS);
7902 return getUMinExpr(LHS, RHS);
7903 }
7904 break;
7905
7906 case Instruction::Or:
7907 // Binary `or` is a bit-wise `umax`.
7908 if (BO->LHS->getType()->isIntegerTy(1)) {
7909 LHS = getSCEV(BO->LHS);
7910 RHS = getSCEV(BO->RHS);
7911 return getUMaxExpr(LHS, RHS);
7912 }
7913 break;
7914
7915 case Instruction::Xor:
7916 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7917 // If the RHS of xor is -1, then this is a not operation.
7918 if (CI->isMinusOne())
7919 return getNotSCEV(getSCEV(BO->LHS));
7920
7921 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7922 // This is a variant of the check for xor with -1, and it handles
7923 // the case where instcombine has trimmed non-demanded bits out
7924 // of an xor with -1.
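      // Sketch with a hypothetical constant: for C = 7, "xor (and %x, 7), 7"
      // flips exactly the low three bits, i.e. it equals "and (not %x), 7";
      // the zero-extend check below recognizes that shape.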
7925 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7926 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7927 if (LBO->getOpcode() == Instruction::And &&
7928 LCI->getValue() == CI->getValue())
7929 if (const SCEVZeroExtendExpr *Z =
7930 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7931 Type *UTy = BO->LHS->getType();
7932 const SCEV *Z0 = Z->getOperand();
7933 Type *Z0Ty = Z0->getType();
7934 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7935
7936 // If C is a low-bits mask, the zero extend is serving to
7937 // mask off the high bits. Complement the operand and
7938 // re-apply the zext.
7939 if (CI->getValue().isMask(Z0TySize))
7940 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7941
7942 // If C is a single bit, it may be in the sign-bit position
7943 // before the zero-extend. In this case, represent the xor
7944 // using an add, which is equivalent, and re-apply the zext.
7945 APInt Trunc = CI->getValue().trunc(Z0TySize);
7946 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7947 Trunc.isSignMask())
7948 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7949 UTy);
7950 }
7951 }
7952 break;
7953
7954 case Instruction::Shl:
7955 // Turn shift left of a constant amount into a multiply.
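      // Illustrative mapping (hypothetical operands): "shl i32 %x, 3" becomes
      // the SCEV (8 * %x), i.e. a multiply by 1 << ShiftAmt, with the wrap
      // flags handled below.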
7956 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7957 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7958
7959 // If the shift count is not less than the bitwidth, the result of
7960 // the shift is undefined. Don't try to analyze it, because the
7961 // resolution chosen here may differ from the resolution chosen in
7962 // other parts of the compiler.
7963 if (SA->getValue().uge(BitWidth))
7964 break;
7965
7966 // We can safely preserve the nuw flag in all cases. It's also safe to
7967 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7968 // requires special handling. It can be preserved as long as we're not
7969 // left shifting by bitwidth - 1.
7970 auto Flags = SCEV::FlagAnyWrap;
7971 if (BO->Op) {
7972 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7973 if ((MulFlags & SCEV::FlagNSW) &&
7974 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7975            Flags = setFlags(Flags, SCEV::FlagNSW);
7976          if (MulFlags & SCEV::FlagNUW)
7977            Flags = setFlags(Flags, SCEV::FlagNUW);
7978        }
7979
7980 ConstantInt *X = ConstantInt::get(
7981 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7982 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7983 }
7984 break;
7985
7986 case Instruction::AShr:
7987 // AShr X, C, where C is a constant.
7988 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7989 if (!CI)
7990 break;
7991
7992 Type *OuterTy = BO->LHS->getType();
7993    unsigned BitWidth = getTypeSizeInBits(OuterTy);
7994    // If the shift count is not less than the bitwidth, the result of
7995 // the shift is undefined. Don't try to analyze it, because the
7996 // resolution chosen here may differ from the resolution chosen in
7997 // other parts of the compiler.
7998 if (CI->getValue().uge(BitWidth))
7999 break;
8000
8001 if (CI->isZero())
8002 return getSCEV(BO->LHS); // shift by zero --> noop
8003
8004 uint64_t AShrAmt = CI->getZExtValue();
8005 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8006
8007 Operator *L = dyn_cast<Operator>(BO->LHS);
8008 const SCEV *AddTruncateExpr = nullptr;
8009 ConstantInt *ShlAmtCI = nullptr;
8010 const SCEV *AddConstant = nullptr;
8011
8012 if (L && L->getOpcode() == Instruction::Add) {
8013 // X = Shl A, n
8014 // Y = Add X, c
8015 // Z = AShr Y, m
8016 // n, c and m are constants.
8017
8018 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8019 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8020 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8021 if (AddOperandCI) {
8022 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8023 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8024          // Since we truncate to TruncTy, the AddConstant should be of the
8025          // same type, so create a new constant with the same type as TruncTy.
8026          // Also, the Add constant should be shifted right by the AShr amount.
8027 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8028 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8029          // We model the expression as sext(add(trunc(A), c << n)); since the
8030          // sext(trunc) part is already handled below, we create an
8031          // AddExpr(TruncExp) which will be used later.
8032 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8033 }
8034 }
8035 } else if (L && L->getOpcode() == Instruction::Shl) {
8036 // X = Shl A, n
8037 // Y = AShr X, m
8038 // Both n and m are constant.
8039
8040 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8041 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8042 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8043 }
8044
8045 if (AddTruncateExpr && ShlAmtCI) {
8046      // We can merge the two given cases into a single SCEV statement;
8047      // in case n = m, the mul expression will be 2^0, so it gets resolved to
8048      // a simpler case. The following code handles the two cases:
8049 //
8050 // 1) For a two-shift sext-inreg, i.e. n = m,
8051 // use sext(trunc(x)) as the SCEV expression.
8052 //
8053 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8054 // expression. We already checked that ShlAmt < BitWidth, so
8055 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8056 // ShlAmt - AShrAmt < Amt.
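      // Illustrative example (hypothetical widths): for i32 with n = 24 and
      // m = 16, "ashr (shl %a, 24), 16" is modeled as
      // sext(i16 (trunc(%a) * 256)) back to i32, since 2^(n-m) = 256 fits in
      // the 16-bit TruncTy.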
8057 const APInt &ShlAmt = ShlAmtCI->getValue();
8058 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8059        APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8060                                        ShlAmtCI->getZExtValue() - AShrAmt);
8061 const SCEV *CompositeExpr =
8062 getMulExpr(AddTruncateExpr, getConstant(Mul));
8063 if (L->getOpcode() != Instruction::Shl)
8064 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8065
8066 return getSignExtendExpr(CompositeExpr, OuterTy);
8067 }
8068 }
8069 break;
8070 }
8071 }
8072
8073 switch (U->getOpcode()) {
8074 case Instruction::Trunc:
8075 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8076
8077 case Instruction::ZExt:
8078 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8079
8080 case Instruction::SExt:
8081 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8082 dyn_cast<Instruction>(V))) {
8083 // The NSW flag of a subtract does not always survive the conversion to
8084 // A + (-1)*B. By pushing sign extension onto its operands we are much
8085 // more likely to preserve NSW and allow later AddRec optimisations.
8086 //
8087 // NOTE: This is effectively duplicating this logic from getSignExtend:
8088 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8089 // but by that point the NSW information has potentially been lost.
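    // Rough example (hypothetical types): "sext i32 (sub nsw %a, %b) to i64"
    // becomes (sext %a to i64) - (sext %b to i64) computed with NSW, rather
    // than losing the flag in the A + (-1)*B rewrite.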
8090 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8091 Type *Ty = U->getType();
8092 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8093 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8094 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8095 }
8096 }
8097 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8098
8099 case Instruction::BitCast:
8100 // BitCasts are no-op casts so we just eliminate the cast.
8101 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8102 return getSCEV(U->getOperand(0));
8103 break;
8104
8105 case Instruction::PtrToInt: {
8106 // Pointer to integer cast is straight-forward, so do model it.
8107 const SCEV *Op = getSCEV(U->getOperand(0));
8108 Type *DstIntTy = U->getType();
8109 // But only if effective SCEV (integer) type is wide enough to represent
8110 // all possible pointer values.
8111 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8112 if (isa<SCEVCouldNotCompute>(IntOp))
8113 return getUnknown(V);
8114 return IntOp;
8115 }
8116 case Instruction::IntToPtr:
8117 // Just don't deal with inttoptr casts.
8118 return getUnknown(V);
8119
8120 case Instruction::SDiv:
8121 // If both operands are non-negative, this is just an udiv.
8122 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8123 isKnownNonNegative(getSCEV(U->getOperand(1))))
8124 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8125 break;
8126
8127 case Instruction::SRem:
8128 // If both operands are non-negative, this is just an urem.
8129 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8130 isKnownNonNegative(getSCEV(U->getOperand(1))))
8131 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8132 break;
8133
8134 case Instruction::GetElementPtr:
8135 return createNodeForGEP(cast<GEPOperator>(U));
8136
8137 case Instruction::PHI:
8138 return createNodeForPHI(cast<PHINode>(U));
8139
8140 case Instruction::Select:
8141 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8142 U->getOperand(2));
8143
8144 case Instruction::Call:
8145 case Instruction::Invoke:
8146 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8147 return getSCEV(RV);
8148
8149 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8150 switch (II->getIntrinsicID()) {
8151 case Intrinsic::abs:
8152 return getAbsExpr(
8153 getSCEV(II->getArgOperand(0)),
8154 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8155 case Intrinsic::umax:
8156 LHS = getSCEV(II->getArgOperand(0));
8157 RHS = getSCEV(II->getArgOperand(1));
8158 return getUMaxExpr(LHS, RHS);
8159 case Intrinsic::umin:
8160 LHS = getSCEV(II->getArgOperand(0));
8161 RHS = getSCEV(II->getArgOperand(1));
8162 return getUMinExpr(LHS, RHS);
8163 case Intrinsic::smax:
8164 LHS = getSCEV(II->getArgOperand(0));
8165 RHS = getSCEV(II->getArgOperand(1));
8166 return getSMaxExpr(LHS, RHS);
8167 case Intrinsic::smin:
8168 LHS = getSCEV(II->getArgOperand(0));
8169 RHS = getSCEV(II->getArgOperand(1));
8170 return getSMinExpr(LHS, RHS);
8171 case Intrinsic::usub_sat: {
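      // Modeling note (illustrative): usub.sat(X, Y) = X - umin(X, Y), and the
      // subtraction cannot wrap; e.g. usub.sat(3, 5) = 3 - umin(3, 5) = 0 and
      // usub.sat(5, 3) = 5 - umin(5, 3) = 2.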
8172 const SCEV *X = getSCEV(II->getArgOperand(0));
8173 const SCEV *Y = getSCEV(II->getArgOperand(1));
8174 const SCEV *ClampedY = getUMinExpr(X, Y);
8175 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8176 }
8177 case Intrinsic::uadd_sat: {
8178 const SCEV *X = getSCEV(II->getArgOperand(0));
8179 const SCEV *Y = getSCEV(II->getArgOperand(1));
8180 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8181 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8182 }
8183 case Intrinsic::start_loop_iterations:
8184 case Intrinsic::annotation:
8185 case Intrinsic::ptr_annotation:
8186        // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8187        // just equivalent to the first operand for SCEV purposes.
8188 return getSCEV(II->getArgOperand(0));
8189 case Intrinsic::vscale:
8190 return getVScale(II->getType());
8191 default:
8192 break;
8193 }
8194 }
8195 break;
8196 }
8197
8198 return getUnknown(V);
8199}
8200
8201//===----------------------------------------------------------------------===//
8202// Iteration Count Computation Code
8203//
8204
8205const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8206  if (isa<SCEVCouldNotCompute>(ExitCount))
8207 return getCouldNotCompute();
8208
8209 auto *ExitCountType = ExitCount->getType();
8210 assert(ExitCountType->isIntegerTy());
8211 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8212 1 + ExitCountType->getScalarSizeInBits());
8213 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8214}
8215
8216const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8217                                                       Type *EvalTy,
8218 const Loop *L) {
8219 if (isa<SCEVCouldNotCompute>(ExitCount))
8220 return getCouldNotCompute();
8221
8222 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8223 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8224
8225 auto CanAddOneWithoutOverflow = [&]() {
8226 ConstantRange ExitCountRange =
8227 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8228 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8229 return true;
8230
8231 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8232 getMinusOne(ExitCount->getType()));
8233 };
8234
8235 // If we need to zero extend the backedge count, check if we can add one to
8236 // it prior to zero extending without overflow. Provided this is safe, it
8237 // allows better simplification of the +1.
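  // Illustrative case (hypothetical widths): an i32 exit count known to be
  // != UINT32_MAX evaluated in i64 becomes zext(ExitCount + 1) rather than
  // zext(ExitCount) + 1, and the narrow +1 tends to fold into existing
  // expressions.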
8238 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8239 return getZeroExtendExpr(
8240 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8241
8242 // Get the total trip count from the count by adding 1. This may wrap.
8243 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8244}
8245
8246static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8247 if (!ExitCount)
8248 return 0;
8249
8250 ConstantInt *ExitConst = ExitCount->getValue();
8251
8252 // Guard against huge trip counts.
8253 if (ExitConst->getValue().getActiveBits() > 32)
8254 return 0;
8255
8256 // In case of integer overflow, this returns 0, which is correct.
8257 return ((unsigned)ExitConst->getZExtValue()) + 1;
8258}
8259
8260unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8261  auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8262 return getConstantTripCount(ExitCount);
8263}
8264
8265unsigned
8266ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8267                                           const BasicBlock *ExitingBlock) {
8268 assert(ExitingBlock && "Must pass a non-null exiting block!");
8269 assert(L->isLoopExiting(ExitingBlock) &&
8270 "Exiting block must actually branch out of the loop!");
8271 const SCEVConstant *ExitCount =
8272 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8273 return getConstantTripCount(ExitCount);
8274}
8275
8276unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8277    const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8278
8279 const auto *MaxExitCount =
8280 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8281                 : getConstantMaxBackedgeTakenCount(L);
8282  return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8283}
8284
8285unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8286  SmallVector<BasicBlock *, 8> ExitingBlocks;
8287 L->getExitingBlocks(ExitingBlocks);
8288
8289 std::optional<unsigned> Res;
8290 for (auto *ExitingBB : ExitingBlocks) {
8291 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8292 if (!Res)
8293 Res = Multiple;
8294 Res = std::gcd(*Res, Multiple);
8295 }
8296 return Res.value_or(1);
8297}
8298
8299unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8300                                                       const SCEV *ExitCount) {
8301 if (isa<SCEVCouldNotCompute>(ExitCount))
8302 return 1;
8303
8304 // Get the trip count
8305 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8306
8307 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8308 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8309 // the greatest power of 2 divisor less than 2^32.
8310 return Multiple.getActiveBits() > 32
8311 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8312 : (unsigned)Multiple.getZExtValue();
8313}
8314
8315/// Returns the largest constant divisor of the trip count of this loop as a
8316/// normal unsigned value, if possible. This means that the actual trip count is
8317/// always a multiple of the returned value (don't forget the trip count could
8318/// very well be zero as well!).
8319///
8320/// Returns 1 if the trip count is unknown or not guaranteed to be the
8321/// multiple of a constant (which is also the case if the trip count is simply
8322/// constant, use getSmallConstantTripCount for that case), Will also return 1
8323/// if the trip count is very large (>= 2^32).
8324///
8325/// As explained in the comments for getSmallConstantTripCount, this assumes
8326/// that control exits the loop via ExitingBlock.
8327unsigned
8328ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8329                                              const BasicBlock *ExitingBlock) {
8330 assert(ExitingBlock && "Must pass a non-null exiting block!");
8331 assert(L->isLoopExiting(ExitingBlock) &&
8332 "Exiting block must actually branch out of the loop!");
8333 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8334 return getSmallConstantTripMultiple(L, ExitCount);
8335}
8336
8337const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8338                                          const BasicBlock *ExitingBlock,
8339 ExitCountKind Kind) {
8340 switch (Kind) {
8341 case Exact:
8342 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8343 case SymbolicMaximum:
8344 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8345 case ConstantMaximum:
8346 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8347 };
8348 llvm_unreachable("Invalid ExitCountKind!");
8349}
8350
8351const SCEV *ScalarEvolution::getPredicatedExitCount(
8352    const Loop *L, const BasicBlock *ExitingBlock,
8353    SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8354  switch (Kind) {
8355 case Exact:
8356 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8357 Predicates);
8358 case SymbolicMaximum:
8359 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8360 Predicates);
8361 case ConstantMaximum:
8362 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8363 Predicates);
8364 };
8365 llvm_unreachable("Invalid ExitCountKind!");
8366}
8367
8368const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8369    const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8370  return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8371}
8372
8373const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8374                                                   ExitCountKind Kind) {
8375 switch (Kind) {
8376 case Exact:
8377 return getBackedgeTakenInfo(L).getExact(L, this);
8378 case ConstantMaximum:
8379 return getBackedgeTakenInfo(L).getConstantMax(this);
8380 case SymbolicMaximum:
8381 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8382 };
8383 llvm_unreachable("Invalid ExitCountKind!");
8384}
8385
8386const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8387    const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8388  return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8389}
8390
8391const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8392    const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8393  return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8394}
8395
8396bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8397  return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8398}
8399
8400/// Push PHI nodes in the header of the given loop onto the given Worklist.
8401static void PushLoopPHIs(const Loop *L,
8402                         SmallVectorImpl<Instruction *> &Worklist,
8403                         SmallPtrSetImpl<Instruction *> &Visited) {
8404  BasicBlock *Header = L->getHeader();
8405
8406 // Push all Loop-header PHIs onto the Worklist stack.
8407 for (PHINode &PN : Header->phis())
8408 if (Visited.insert(&PN).second)
8409 Worklist.push_back(&PN);
8410}
8411
8412ScalarEvolution::BackedgeTakenInfo &
8413ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8414 auto &BTI = getBackedgeTakenInfo(L);
8415 if (BTI.hasFullInfo())
8416 return BTI;
8417
8418 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8419
8420 if (!Pair.second)
8421 return Pair.first->second;
8422
8423 BackedgeTakenInfo Result =
8424 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8425
8426 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8427}
8428
8429ScalarEvolution::BackedgeTakenInfo &
8430ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8431 // Initially insert an invalid entry for this loop. If the insertion
8432 // succeeds, proceed to actually compute a backedge-taken count and
8433 // update the value. The temporary CouldNotCompute value tells SCEV
8434 // code elsewhere that it shouldn't attempt to request a new
8435 // backedge-taken count, which could result in infinite recursion.
8436 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8437 BackedgeTakenCounts.try_emplace(L);
8438 if (!Pair.second)
8439 return Pair.first->second;
8440
8441 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8442 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8443 // must be cleared in this scope.
8444 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8445
8446 // Now that we know more about the trip count for this loop, forget any
8447 // existing SCEV values for PHI nodes in this loop since they are only
8448 // conservative estimates made without the benefit of trip count
8449 // information. This invalidation is not necessary for correctness, and is
8450 // only done to produce more precise results.
8451 if (Result.hasAnyInfo()) {
8452 // Invalidate any expression using an addrec in this loop.
8453    SmallVector<const SCEV *, 8> ToForget;
8454    auto LoopUsersIt = LoopUsers.find(L);
8455 if (LoopUsersIt != LoopUsers.end())
8456 append_range(ToForget, LoopUsersIt->second);
8457 forgetMemoizedResults(ToForget);
8458
8459 // Invalidate constant-evolved loop header phis.
8460 for (PHINode &PN : L->getHeader()->phis())
8461 ConstantEvolutionLoopExitValue.erase(&PN);
8462 }
8463
8464 // Re-lookup the insert position, since the call to
8465 // computeBackedgeTakenCount above could result in a
8466  // recursive call to getBackedgeTakenInfo (on a different
8467 // loop), which would invalidate the iterator computed
8468 // earlier.
8469 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8470}
8471
8473 // This method is intended to forget all info about loops. It should
8474 // invalidate caches as if the following happened:
8475 // - The trip counts of all loops have changed arbitrarily
8476 // - Every llvm::Value has been updated in place to produce a different
8477 // result.
8478 BackedgeTakenCounts.clear();
8479 PredicatedBackedgeTakenCounts.clear();
8480 BECountUsers.clear();
8481 LoopPropertiesCache.clear();
8482 ConstantEvolutionLoopExitValue.clear();
8483 ValueExprMap.clear();
8484 ValuesAtScopes.clear();
8485 ValuesAtScopesUsers.clear();
8486 LoopDispositions.clear();
8487 BlockDispositions.clear();
8488 UnsignedRanges.clear();
8489 SignedRanges.clear();
8490 ExprValueMap.clear();
8491 HasRecMap.clear();
8492 ConstantMultipleCache.clear();
8493 PredicatedSCEVRewrites.clear();
8494 FoldCache.clear();
8495 FoldCacheUser.clear();
8496}
8497void ScalarEvolution::visitAndClearUsers(
8498    SmallVectorImpl<Instruction *> &Worklist,
8499    SmallPtrSetImpl<Instruction *> &Visited,
8500    SmallVectorImpl<const SCEV *> &ToForget) {
8501  while (!Worklist.empty()) {
8502 Instruction *I = Worklist.pop_back_val();
8503 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8504 continue;
8505
8506    ValueExprMapType::iterator It =
8507        ValueExprMap.find_as(static_cast<Value *>(I));
8508 if (It != ValueExprMap.end()) {
8509 eraseValueFromMap(It->first);
8510 ToForget.push_back(It->second);
8511 if (PHINode *PN = dyn_cast<PHINode>(I))
8512 ConstantEvolutionLoopExitValue.erase(PN);
8513 }
8514
8515 PushDefUseChildren(I, Worklist, Visited);
8516 }
8517}
8518
8519void ScalarEvolution::forgetLoop(const Loop *L) {
8520  SmallVector<const Loop *, 16> LoopWorklist(1, L);
8521  SmallVector<Instruction *, 32> Worklist;
8522  SmallPtrSet<Instruction *, 16> Visited;
8523  SmallVector<const SCEV *, 16> ToForget;
8524
8525 // Iterate over all the loops and sub-loops to drop SCEV information.
8526 while (!LoopWorklist.empty()) {
8527 auto *CurrL = LoopWorklist.pop_back_val();
8528
8529 // Drop any stored trip count value.
8530 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8531 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8532
8533 // Drop information about predicated SCEV rewrites for this loop.
8534 for (auto I = PredicatedSCEVRewrites.begin();
8535 I != PredicatedSCEVRewrites.end();) {
8536 std::pair<const SCEV *, const Loop *> Entry = I->first;
8537 if (Entry.second == CurrL)
8538 PredicatedSCEVRewrites.erase(I++);
8539 else
8540 ++I;
8541 }
8542
8543 auto LoopUsersItr = LoopUsers.find(CurrL);
8544 if (LoopUsersItr != LoopUsers.end())
8545 llvm::append_range(ToForget, LoopUsersItr->second);
8546
8547 // Drop information about expressions based on loop-header PHIs.
8548 PushLoopPHIs(CurrL, Worklist, Visited);
8549 visitAndClearUsers(Worklist, Visited, ToForget);
8550
8551 LoopPropertiesCache.erase(CurrL);
8552 // Forget all contained loops too, to avoid dangling entries in the
8553 // ValuesAtScopes map.
8554 LoopWorklist.append(CurrL->begin(), CurrL->end());
8555 }
8556 forgetMemoizedResults(ToForget);
8557}
8558
8559void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8560  forgetLoop(L->getOutermostLoop());
8561}
8562
8563void ScalarEvolution::forgetValue(Value *V) {
8564  Instruction *I = dyn_cast<Instruction>(V);
8565 if (!I) return;
8566
8567 // Drop information about expressions based on loop-header PHIs.
8568  SmallVector<Instruction *, 16> Worklist;
8569  SmallPtrSet<Instruction *, 8> Visited;
8570  SmallVector<const SCEV *, 8> ToForget;
8571  Worklist.push_back(I);
8572 Visited.insert(I);
8573 visitAndClearUsers(Worklist, Visited, ToForget);
8574
8575 forgetMemoizedResults(ToForget);
8576}
8577
8578void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8579  if (!isSCEVable(V->getType()))
8580 return;
8581
8582 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8583 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8584 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8585 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8586 if (const SCEV *S = getExistingSCEV(V)) {
8587 struct InvalidationRootCollector {
8588 Loop *L;
8589      SmallVector<const SCEV *, 8> Roots;
8590
8591 InvalidationRootCollector(Loop *L) : L(L) {}
8592
8593 bool follow(const SCEV *S) {
8594 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8595 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8596 if (L->contains(I))
8597 Roots.push_back(S);
8598 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8599 if (L->contains(AddRec->getLoop()))
8600 Roots.push_back(S);
8601 }
8602 return true;
8603 }
8604 bool isDone() const { return false; }
8605 };
8606
8607 InvalidationRootCollector C(L);
8608 visitAll(S, C);
8609 forgetMemoizedResults(C.Roots);
8610 }
8611
8612 // Also perform the normal invalidation.
8613 forgetValue(V);
8614}
8615
8616void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8617
8618void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8619  // Unless a specific value is passed to invalidation, completely clear both
8620 // caches.
8621 if (!V) {
8622 BlockDispositions.clear();
8623 LoopDispositions.clear();
8624 return;
8625 }
8626
8627 if (!isSCEVable(V->getType()))
8628 return;
8629
8630 const SCEV *S = getExistingSCEV(V);
8631 if (!S)
8632 return;
8633
8634 // Invalidate the block and loop dispositions cached for S. Dispositions of
8635 // S's users may change if S's disposition changes (i.e. a user may change to
8636 // loop-invariant, if S changes to loop invariant), so also invalidate
8637 // dispositions of S's users recursively.
8638 SmallVector<const SCEV *, 8> Worklist = {S};
8639  SmallPtrSet<const SCEV *, 8> Seen = {S};
8640  while (!Worklist.empty()) {
8641 const SCEV *Curr = Worklist.pop_back_val();
8642 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8643 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8644 if (!LoopDispoRemoved && !BlockDispoRemoved)
8645 continue;
8646 auto Users = SCEVUsers.find(Curr);
8647 if (Users != SCEVUsers.end())
8648 for (const auto *User : Users->second)
8649 if (Seen.insert(User).second)
8650 Worklist.push_back(User);
8651 }
8652}
8653
8654/// Get the exact loop backedge taken count considering all loop exits. A
8655/// computable result can only be returned for loops with all exiting blocks
8656/// dominating the latch. howFarToZero assumes that the limit of each loop test
8657/// is never skipped. This is a valid assumption as long as the loop exits via
8658/// that test. For precise results, it is the caller's responsibility to specify
8659/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8660const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8661 const Loop *L, ScalarEvolution *SE,
8662    SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8663  // If any exits were not computable, the loop is not computable.
8664 if (!isComplete() || ExitNotTaken.empty())
8665 return SE->getCouldNotCompute();
8666
8667 const BasicBlock *Latch = L->getLoopLatch();
8668 // All exiting blocks we have collected must dominate the only backedge.
8669 if (!Latch)
8670 return SE->getCouldNotCompute();
8671
8672 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8673 // count is simply a minimum out of all these calculated exit counts.
8674  SmallVector<const SCEV *, 2> Ops;
8675  for (const auto &ENT : ExitNotTaken) {
8676 const SCEV *BECount = ENT.ExactNotTaken;
8677 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8678 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8679 "We should only have known counts for exiting blocks that dominate "
8680 "latch!");
8681
8682 Ops.push_back(BECount);
8683
8684 if (Preds)
8685 append_range(*Preds, ENT.Predicates);
8686
8687 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8688 "Predicate should be always true!");
8689 }
8690
8691 // If an earlier exit exits on the first iteration (exit count zero), then
8692  // a later poison exit count should not propagate into the result. These are
8693 // exactly the semantics provided by umin_seq.
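  // For instance (illustrative only): umin_seq(0, poison) is 0, whereas a
  // plain umin of the two exit counts would propagate the poison.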
8694 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8695}
8696
8697const ScalarEvolution::ExitNotTakenInfo *
8698ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8699 const BasicBlock *ExitingBlock,
8700 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8701 for (const auto &ENT : ExitNotTaken)
8702 if (ENT.ExitingBlock == ExitingBlock) {
8703 if (ENT.hasAlwaysTruePredicate())
8704 return &ENT;
8705 else if (Predicates) {
8706 append_range(*Predicates, ENT.Predicates);
8707 return &ENT;
8708 }
8709 }
8710
8711 return nullptr;
8712}
8713
8714/// getConstantMax - Get the constant max backedge taken count for the loop.
8715const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8716 ScalarEvolution *SE,
8717 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8718 if (!getConstantMax())
8719 return SE->getCouldNotCompute();
8720
8721 for (const auto &ENT : ExitNotTaken)
8722 if (!ENT.hasAlwaysTruePredicate()) {
8723 if (!Predicates)
8724 return SE->getCouldNotCompute();
8725 append_range(*Predicates, ENT.Predicates);
8726 }
8727
8728 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8729 isa<SCEVConstant>(getConstantMax())) &&
8730 "No point in having a non-constant max backedge taken count!");
8731 return getConstantMax();
8732}
8733
8734const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8735 const Loop *L, ScalarEvolution *SE,
8736    SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8737  if (!SymbolicMax) {
8738 // Form an expression for the maximum exit count possible for this loop. We
8739 // merge the max and exact information to approximate a version of
8740 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8741 // constants.
8742    SmallVector<const SCEV *, 4> ExitCounts;
8743
8744 for (const auto &ENT : ExitNotTaken) {
8745 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8746 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8747 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8748 "We should only have known counts for exiting blocks that "
8749 "dominate latch!");
8750 ExitCounts.push_back(ExitCount);
8751 if (Predicates)
8752 append_range(*Predicates, ENT.Predicates);
8753
8754 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8755 "Predicate should be always true!");
8756 }
8757 }
8758 if (ExitCounts.empty())
8759 SymbolicMax = SE->getCouldNotCompute();
8760 else
8761 SymbolicMax =
8762 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8763 }
8764 return SymbolicMax;
8765}
8766
8767bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8768 ScalarEvolution *SE) const {
8769 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8770 return !ENT.hasAlwaysTruePredicate();
8771 };
8772 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8773}
8774
8775ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8776    : ExitLimit(E, E, E, false) {}
8777
8778ScalarEvolution::ExitLimit::ExitLimit(
8779    const SCEV *E, const SCEV *ConstantMaxNotTaken,
8780 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8781    ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists)
8782    : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8783 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8784 // If we prove the max count is zero, so is the symbolic bound. This happens
8785 // in practice due to differences in a) how context sensitive we've chosen
8786 // to be and b) how we reason about bounds implied by UB.
8787 if (ConstantMaxNotTaken->isZero()) {
8788    this->ExactNotTaken = E = ConstantMaxNotTaken;
8789    this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8790 }
8791
8792 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8793 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8794 "Exact is not allowed to be less precise than Constant Max");
8795 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8796 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8797 "Exact is not allowed to be less precise than Symbolic Max");
8798 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8799 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8800 "Symbolic Max is not allowed to be less precise than Constant Max");
8801 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8802 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8803 "No point in having a non-constant max backedge taken count!");
8804  SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8805  for (const auto PredList : PredLists)
8806 for (const auto *P : PredList) {
8807 if (SeenPreds.contains(P))
8808 continue;
8809 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8810 SeenPreds.insert(P);
8811 Predicates.push_back(P);
8812 }
8813 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8814 "Backedge count should be int");
8815 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8816          !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8817         "Max backedge count should be int");
8818}
8819
8820ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E,
8821                                      const SCEV *ConstantMaxNotTaken,
8822 const SCEV *SymbolicMaxNotTaken,
8823 bool MaxOrZero,
8824                                      ArrayRef<const SCEVPredicate *> PredList)
8825    : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8826 ArrayRef({PredList})) {}
8827
8828/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8829/// computable exit into a persistent ExitNotTakenInfo array.
8830ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8832 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8833 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8834 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8835
8836 ExitNotTaken.reserve(ExitCounts.size());
8837 std::transform(ExitCounts.begin(), ExitCounts.end(),
8838 std::back_inserter(ExitNotTaken),
8839 [&](const EdgeExitInfo &EEI) {
8840 BasicBlock *ExitBB = EEI.first;
8841 const ExitLimit &EL = EEI.second;
8842 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8843 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8844 EL.Predicates);
8845 });
8846 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8847 isa<SCEVConstant>(ConstantMax)) &&
8848 "No point in having a non-constant max backedge taken count!");
8849}
8850
8851/// Compute the number of times the backedge of the specified loop will execute.
8852ScalarEvolution::BackedgeTakenInfo
8853ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8854 bool AllowPredicates) {
8855 SmallVector<BasicBlock *, 8> ExitingBlocks;
8856 L->getExitingBlocks(ExitingBlocks);
8857
8858 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8859
8860  SmallVector<EdgeExitInfo, 4> ExitCounts;
8861  bool CouldComputeBECount = true;
8862 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8863 const SCEV *MustExitMaxBECount = nullptr;
8864 const SCEV *MayExitMaxBECount = nullptr;
8865 bool MustExitMaxOrZero = false;
8866 bool IsOnlyExit = ExitingBlocks.size() == 1;
8867
8868 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8869 // and compute maxBECount.
8870 // Do a union of all the predicates here.
8871 for (BasicBlock *ExitBB : ExitingBlocks) {
8872 // We canonicalize untaken exits to br (constant), ignore them so that
8873 // proving an exit untaken doesn't negatively impact our ability to reason
8874    // about the loop as a whole.
8875 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8876 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8877 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8878 if (ExitIfTrue == CI->isZero())
8879 continue;
8880 }
8881
8882 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8883
8884 assert((AllowPredicates || EL.Predicates.empty()) &&
8885 "Predicated exit limit when predicates are not allowed!");
8886
8887 // 1. For each exit that can be computed, add an entry to ExitCounts.
8888 // CouldComputeBECount is true only if all exits can be computed.
8889 if (EL.ExactNotTaken != getCouldNotCompute())
8890 ++NumExitCountsComputed;
8891 else
8892 // We couldn't compute an exact value for this exit, so
8893 // we won't be able to compute an exact value for the loop.
8894 CouldComputeBECount = false;
8895 // Remember exit count if either exact or symbolic is known. Because
8896 // Exact always implies symbolic, only check symbolic.
8897 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8898 ExitCounts.emplace_back(ExitBB, EL);
8899 else {
8900 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8901 "Exact is known but symbolic isn't?");
8902 ++NumExitCountsNotComputed;
8903 }
8904
8905 // 2. Derive the loop's MaxBECount from each exit's max number of
8906 // non-exiting iterations. Partition the loop exits into two kinds:
8907 // LoopMustExits and LoopMayExits.
8908 //
8909 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8910 // is a LoopMayExit. If any computable LoopMustExit is found, then
8911 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8912 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8913 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8914 // any
8915 // computable EL.ConstantMaxNotTaken.
8916 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8917 DT.dominates(ExitBB, Latch)) {
8918 if (!MustExitMaxBECount) {
8919 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8920 MustExitMaxOrZero = EL.MaxOrZero;
8921 } else {
8922 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8923 EL.ConstantMaxNotTaken);
8924 }
8925 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8926 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8927 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8928 else {
8929 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8930 EL.ConstantMaxNotTaken);
8931 }
8932 }
8933 }
8934 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8935 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8936 // The loop backedge will be taken the maximum or zero times if there's
8937 // a single exit that must be taken the maximum or zero times.
8938 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8939
8940 // Remember which SCEVs are used in exit limits for invalidation purposes.
8941 // We only care about non-constant SCEVs here, so we can ignore
8942 // EL.ConstantMaxNotTaken
8943 // and MaxBECount, which must be SCEVConstant.
8944 for (const auto &Pair : ExitCounts) {
8945 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8946 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8947 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8948 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8949 {L, AllowPredicates});
8950 }
8951 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8952 MaxBECount, MaxOrZero);
8953}
8954
8955ScalarEvolution::ExitLimit
8956ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8957 bool IsOnlyExit, bool AllowPredicates) {
8958 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8959 // If our exiting block does not dominate the latch, then its connection with
8960 // loop's exit limit may be far from trivial.
8961 const BasicBlock *Latch = L->getLoopLatch();
8962 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8963 return getCouldNotCompute();
8964
8965 Instruction *Term = ExitingBlock->getTerminator();
8966 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8967 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8968 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8969 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8970 "It should have one successor in loop and one exit block!");
8971 // Proceed to the next level to examine the exit condition expression.
8972 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8973 /*ControlsOnlyExit=*/IsOnlyExit,
8974 AllowPredicates);
8975 }
8976
8977 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8978 // For switch, make sure that there is a single exit from the loop.
8979 BasicBlock *Exit = nullptr;
8980 for (auto *SBB : successors(ExitingBlock))
8981 if (!L->contains(SBB)) {
8982 if (Exit) // Multiple exit successors.
8983 return getCouldNotCompute();
8984 Exit = SBB;
8985 }
8986 assert(Exit && "Exiting block must have at least one exit");
8987 return computeExitLimitFromSingleExitSwitch(
8988 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8989 }
8990
8991 return getCouldNotCompute();
8992}
8993
8994ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8995    const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8996 bool AllowPredicates) {
8997 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8998 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8999 ControlsOnlyExit, AllowPredicates);
9000}
9001
9002std::optional<ScalarEvolution::ExitLimit>
9003ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9004 bool ExitIfTrue, bool ControlsOnlyExit,
9005 bool AllowPredicates) {
9006 (void)this->L;
9007 (void)this->ExitIfTrue;
9008 (void)this->AllowPredicates;
9009
9010 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9011 this->AllowPredicates == AllowPredicates &&
9012 "Variance in assumed invariant key components!");
9013 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9014 if (Itr == TripCountMap.end())
9015 return std::nullopt;
9016 return Itr->second;
9017}
9018
9019void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9020 bool ExitIfTrue,
9021 bool ControlsOnlyExit,
9022 bool AllowPredicates,
9023 const ExitLimit &EL) {
9024 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9025 this->AllowPredicates == AllowPredicates &&
9026 "Variance in assumed invariant key components!");
9027
9028 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9029 assert(InsertResult.second && "Expected successful insertion!");
9030 (void)InsertResult;
9031 (void)ExitIfTrue;
9032}
9033
9034ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9035 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9036 bool ControlsOnlyExit, bool AllowPredicates) {
9037
9038 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9039 AllowPredicates))
9040 return *MaybeEL;
9041
9042 ExitLimit EL = computeExitLimitFromCondImpl(
9043 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9044 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9045 return EL;
9046}
9047
9048ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9049 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9050 bool ControlsOnlyExit, bool AllowPredicates) {
9051 // Handle BinOp conditions (And, Or).
9052 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9053 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9054 return *LimitFromBinOp;
9055
9056 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9057 // Proceed to the next level to examine the icmp.
9058 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9059 ExitLimit EL =
9060 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9061 if (EL.hasFullInfo() || !AllowPredicates)
9062 return EL;
9063
9064 // Try again, but use SCEV predicates this time.
9065 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9066 ControlsOnlyExit,
9067 /*AllowPredicates=*/true);
9068 }
9069
9070 // Check for a constant condition. These are normally stripped out by
9071 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9072 // preserve the CFG and is temporarily leaving constant conditions
9073 // in place.
9074 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9075 if (ExitIfTrue == !CI->getZExtValue())
9076 // The backedge is always taken.
9077 return getCouldNotCompute();
9078 // The backedge is never taken.
9079 return getZero(CI->getType());
9080 }
9081
9082 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9083 // with a constant step, we can form an equivalent icmp predicate and figure
9084 // out how many iterations will be taken before we exit.
9085 const WithOverflowInst *WO;
9086 const APInt *C;
9087 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9088 match(WO->getRHS(), m_APInt(C))) {
9089 ConstantRange NWR =
9090        ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9091                                             WO->getNoWrapKind());
9092 CmpInst::Predicate Pred;
9093 APInt NewRHSC, Offset;
9094 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9095 if (!ExitIfTrue)
9096 Pred = ICmpInst::getInversePredicate(Pred);
9097 auto *LHS = getSCEV(WO->getLHS());
9098 if (Offset != 0)
9099      LHS = getAddExpr(LHS, getConstant(Offset));
9100    auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9101 ControlsOnlyExit, AllowPredicates);
9102 if (EL.hasAnyInfo())
9103 return EL;
9104 }
9105
9106 // If it's not an integer or pointer comparison then compute it the hard way.
9107 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9108}
9109
9110std::optional<ScalarEvolution::ExitLimit>
9111ScalarEvolution::computeExitLimitFromCondFromBinOp(
9112 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9113 bool ControlsOnlyExit, bool AllowPredicates) {
9114 // Check if the controlling expression for this loop is an And or Or.
9115 Value *Op0, *Op1;
9116 bool IsAnd = false;
9117 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9118 IsAnd = true;
9119 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9120 IsAnd = false;
9121 else
9122 return std::nullopt;
9123
9124 // EitherMayExit is true in these two cases:
9125 // br (and Op0 Op1), loop, exit
9126 // br (or Op0 Op1), exit, loop
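  // Illustrative reading: for "br (and %c0, %c1), loop, exit" the loop keeps
  // running only while both conditions hold, so whichever exit triggers first
  // wins and the exact count below is the (sequential) umin of EL0 and EL1.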
9127 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9128 ExitLimit EL0 = computeExitLimitFromCondCached(
9129 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9130 AllowPredicates);
9131 ExitLimit EL1 = computeExitLimitFromCondCached(
9132 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9133 AllowPredicates);
9134
9135 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9136 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9137 if (isa<ConstantInt>(Op1))
9138 return Op1 == NeutralElement ? EL0 : EL1;
9139 if (isa<ConstantInt>(Op0))
9140 return Op0 == NeutralElement ? EL1 : EL0;
9141
9142 const SCEV *BECount = getCouldNotCompute();
9143 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9144 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9145 if (EitherMayExit) {
9146 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9147 // Both conditions must be same for the loop to continue executing.
9148 // Choose the less conservative count.
9149 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9150 EL1.ExactNotTaken != getCouldNotCompute()) {
9151 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9152 UseSequentialUMin);
9153 }
9154 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9155 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9156 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9157 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9158 else
9159 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9160 EL1.ConstantMaxNotTaken);
9161 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9162 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9163 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9164 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9165 else
9166 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9167 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9168 } else {
9169 // Both conditions must be same at the same time for the loop to exit.
9170 // For now, be conservative.
9171 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9172 BECount = EL0.ExactNotTaken;
9173 }
9174
9175 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9176 // to be more aggressive when computing BECount than when computing
9177 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9178 // and
9179 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9180 // EL1.ConstantMaxNotTaken to not.
9181 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9182 !isa<SCEVCouldNotCompute>(BECount))
9183 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9184 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9185 SymbolicMaxBECount =
9186 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9187 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9188 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9189}
9190
9191ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9192 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9193 bool AllowPredicates) {
9194 // If the condition was exit on true, convert the condition to exit on false
9195 CmpPredicate Pred;
9196 if (!ExitIfTrue)
9197 Pred = ExitCond->getCmpPredicate();
9198 else
9199 Pred = ExitCond->getInverseCmpPredicate();
9200 const ICmpInst::Predicate OriginalPred = Pred;
9201
9202 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9203 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9204
9205 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9206 AllowPredicates);
9207 if (EL.hasAnyInfo())
9208 return EL;
9209
9210 auto *ExhaustiveCount =
9211 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9212
9213 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9214 return ExhaustiveCount;
9215
9216 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9217 ExitCond->getOperand(1), L, OriginalPred);
9218}
9219ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9220 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9221 bool ControlsOnlyExit, bool AllowPredicates) {
9222
9223 // Try to evaluate any dependencies out of the loop.
9224 LHS = getSCEVAtScope(LHS, L);
9225 RHS = getSCEVAtScope(RHS, L);
9226
9227 // At this point, we would like to compute how many iterations of the
9228 // loop the predicate will return true for these inputs.
9229 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9230 // If there is a loop-invariant, force it into the RHS.
9231 std::swap(LHS, RHS);
9232    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9233  }
9234
9235 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9236                               loopIsFiniteByAssumption(L);
9237  // Simplify the operands before analyzing them.
9238 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9239
9240 // If we have a comparison of a chrec against a constant, try to use value
9241 // ranges to answer this query.
9242 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9243 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9244 if (AddRec->getLoop() == L) {
9245 // Form the constant range.
9246 ConstantRange CompRange =
9247 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9248
9249 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9250 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9251 }
9252
9253 // If this loop must exit based on this condition (or execute undefined
9254 // behaviour), see if we can improve wrap flags. This is essentially
9255 // a must execute style proof.
9256 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9257 // If we can prove the test sequence produced must repeat the same values
9258 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9259 // because if it did, we'd have an infinite (undefined) loop.
9260 // TODO: We can peel off any functions which are invertible *in L*. Loop
9261 // invariant terms are effectively constants for our purposes here.
9262 auto *InnerLHS = LHS;
9263 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9264 InnerLHS = ZExt->getOperand();
9265 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9266 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9267 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9268 /*OrNegative=*/true)) {
9269 auto Flags = AR->getNoWrapFlags();
9270 Flags = setFlags(Flags, SCEV::FlagNW);
9271      SmallVector<const SCEV *> Operands{AR->operands()};
9272      Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9273      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9274 }
9275
9276 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9277 // From no-self-wrap, this follows trivially from the fact that every
9278    // (un)signed-wrapped, but not self-wrapped, value must be less than the
9279    // last value before the (un)signed wrap. Since we know that last value
9280    // didn't cause an exit, neither will any smaller one.
9281 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9282 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9283 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9284 AR && AR->getLoop() == L && AR->isAffine() &&
9285 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9286 isKnownPositive(AR->getStepRecurrence(*this))) {
9287 auto Flags = AR->getNoWrapFlags();
9288 Flags = setFlags(Flags, WrapType);
9289        SmallVector<const SCEV *> Operands{AR->operands()};
9290        Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9291        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9292 }
9293 }
9294 }
9295
9296 switch (Pred) {
9297 case ICmpInst::ICMP_NE: { // while (X != Y)
9298 // Convert to: while (X-Y != 0)
9299 if (LHS->getType()->isPointerTy()) {
9300      LHS = getLosslessPtrToIntExpr(LHS);
9301      if (isa<SCEVCouldNotCompute>(LHS))
9302 return LHS;
9303 }
9304 if (RHS->getType()->isPointerTy()) {
9305      RHS = getLosslessPtrToIntExpr(RHS);
9306      if (isa<SCEVCouldNotCompute>(RHS))
9307 return RHS;
9308 }
9309 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9310 AllowPredicates);
9311 if (EL.hasAnyInfo())
9312 return EL;
9313 break;
9314 }
9315 case ICmpInst::ICMP_EQ: { // while (X == Y)
9316 // Convert to: while (X-Y == 0)
9317 if (LHS->getType()->isPointerTy()) {
9318      LHS = getLosslessPtrToIntExpr(LHS);
9319      if (isa<SCEVCouldNotCompute>(LHS))
9320 return LHS;
9321 }
9322 if (RHS->getType()->isPointerTy()) {
9323      RHS = getLosslessPtrToIntExpr(RHS);
9324      if (isa<SCEVCouldNotCompute>(RHS))
9325 return RHS;
9326 }
9327 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9328 if (EL.hasAnyInfo()) return EL;
9329 break;
9330 }
9331 case ICmpInst::ICMP_SLE:
9332 case ICmpInst::ICMP_ULE:
9333 // Since the loop is finite, an invariant RHS cannot include the boundary
9334 // value, otherwise it would loop forever.
9335 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9336 !isLoopInvariant(RHS, L)) {
9337 // Otherwise, perform the addition in a wider type, to avoid overflow.
9338 // If the LHS is an addrec with the appropriate nowrap flag, the
9339 // extension will be sunk into it and the exit count can be analyzed.
9340 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9341 if (!OldType)
9342 break;
9343 // Prefer doubling the bitwidth over adding a single bit to make it more
9344 // likely that we use a legal type.
9345 auto *NewType =
9346 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9347 if (ICmpInst::isSigned(Pred)) {
9348 LHS = getSignExtendExpr(LHS, NewType);
9349 RHS = getSignExtendExpr(RHS, NewType);
9350 } else {
9351 LHS = getZeroExtendExpr(LHS, NewType);
9352 RHS = getZeroExtendExpr(RHS, NewType);
9353 }
9354 }
9355 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9356 [[fallthrough]];
9357 case ICmpInst::ICMP_SLT:
9358 case ICmpInst::ICMP_ULT: { // while (X < Y)
9359 bool IsSigned = ICmpInst::isSigned(Pred);
9360 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9361 AllowPredicates);
9362 if (EL.hasAnyInfo())
9363 return EL;
9364 break;
9365 }
9366 case ICmpInst::ICMP_SGE:
9367 case ICmpInst::ICMP_UGE:
9368 // Since the loop is finite, an invariant RHS cannot include the boundary
9369 // value, otherwise it would loop forever.
9370 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9371 !isLoopInvariant(RHS, L))
9372 break;
9373 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9374 [[fallthrough]];
9375 case ICmpInst::ICMP_SGT:
9376 case ICmpInst::ICMP_UGT: { // while (X > Y)
9377 bool IsSigned = ICmpInst::isSigned(Pred);
9378 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9379 AllowPredicates);
9380 if (EL.hasAnyInfo())
9381 return EL;
9382 break;
9383 }
9384 default:
9385 break;
9386 }
9387
9388 return getCouldNotCompute();
9389}
9390
9391ScalarEvolution::ExitLimit
9392ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9393 SwitchInst *Switch,
9394 BasicBlock *ExitingBlock,
9395 bool ControlsOnlyExit) {
9396 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9397
9398 // Give up if the exit is the default dest of a switch.
9399 if (Switch->getDefaultDest() == ExitingBlock)
9400 return getCouldNotCompute();
9401
9402 assert(L->contains(Switch->getDefaultDest()) &&
9403 "Default case must not exit the loop!");
9404 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9405 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9406
9407 // while (X != Y) --> while (X-Y != 0)
9408 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9409 if (EL.hasAnyInfo())
9410 return EL;
9411
9412 return getCouldNotCompute();
9413}
9414
9415static ConstantInt *
9416EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9417                                ScalarEvolution &SE) {
9418 const SCEV *InVal = SE.getConstant(C);
9419 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9420 assert(isa<SCEVConstant>(Val) &&
9421 "Evaluation of SCEV at constant didn't fold correctly?");
9422 return cast<SCEVConstant>(Val)->getValue();
9423}
9424
9425ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9426 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9427 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9428 if (!RHS)
9429 return getCouldNotCompute();
9430
9431 const BasicBlock *Latch = L->getLoopLatch();
9432 if (!Latch)
9433 return getCouldNotCompute();
9434
9435 const BasicBlock *Predecessor = L->getLoopPredecessor();
9436 if (!Predecessor)
9437 return getCouldNotCompute();
9438
9439 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9440  // Return LHS in OutLHS and shift_op in OutOpCode.
9441 auto MatchPositiveShift =
9442 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9443
9444 using namespace PatternMatch;
9445
9446 ConstantInt *ShiftAmt;
9447 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9448 OutOpCode = Instruction::LShr;
9449 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9450 OutOpCode = Instruction::AShr;
9451 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9452 OutOpCode = Instruction::Shl;
9453 else
9454 return false;
9455
9456 return ShiftAmt->getValue().isStrictlyPositive();
9457 };
9458
9459 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9460 //
9461 // loop:
9462 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9463 // %iv.shifted = lshr i32 %iv, <positive constant>
9464 //
9465 // Return true on a successful match. Return the corresponding PHI node (%iv
9466 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9467 auto MatchShiftRecurrence =
9468 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9469 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9470
9471 {
9472    Instruction::BinaryOps OpC;
9473    Value *V;
9474
9475 // If we encounter a shift instruction, "peel off" the shift operation,
9476 // and remember that we did so. Later when we inspect %iv's backedge
9477 // value, we will make sure that the backedge value uses the same
9478 // operation.
9479 //
9480 // Note: the peeled shift operation does not have to be the same
9481 // instruction as the one feeding into the PHI's backedge value. We only
9482 // really care about it being the same *kind* of shift instruction --
9483 // that's all that is required for our later inferences to hold.
9484 if (MatchPositiveShift(LHS, V, OpC)) {
9485 PostShiftOpCode = OpC;
9486 LHS = V;
9487 }
9488 }
9489
9490 PNOut = dyn_cast<PHINode>(LHS);
9491 if (!PNOut || PNOut->getParent() != L->getHeader())
9492 return false;
9493
9494 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9495 Value *OpLHS;
9496
9497 return
9498 // The backedge value for the PHI node must be a shift by a positive
9499 // amount
9500 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9501
9502 // of the PHI node itself
9503 OpLHS == PNOut &&
9504
9505        // and the kind of shift should match the kind of shift we peeled
9506 // off, if any.
9507 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9508 };
9509
9510 PHINode *PN;
9511  Instruction::BinaryOps OpCode;
9512  if (!MatchShiftRecurrence(LHS, PN, OpCode))
9513 return getCouldNotCompute();
9514
9515 const DataLayout &DL = getDataLayout();
9516
9517 // The key rationale for this optimization is that for some kinds of shift
9518 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9519 // within a finite number of iterations. If the condition guarding the
9520 // backedge (in the sense that the backedge is taken if the condition is true)
9521 // is false for the value the shift recurrence stabilizes to, then we know
9522 // that the backedge is taken only a finite number of times.
9523
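  // Illustrative trace (hypothetical values): for the recurrence {13,lshr,1}
  // on i32 the successive values are 13, 6, 3, 1, 0, ..., so it stabilizes to
  // 0 within bitwidth(i32) = 32 iterations; if the exit condition is false at
  // that stable value, the backedge can only be taken a bounded number of
  // times.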
9524 ConstantInt *StableValue = nullptr;
9525 switch (OpCode) {
9526 default:
9527 llvm_unreachable("Impossible case!");
9528
9529 case Instruction::AShr: {
9530 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9531 // bitwidth(K) iterations.
9532 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9533 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9534 Predecessor->getTerminator(), &DT);
9535 auto *Ty = cast<IntegerType>(RHS->getType());
9536 if (Known.isNonNegative())
9537 StableValue = ConstantInt::get(Ty, 0);
9538 else if (Known.isNegative())
9539 StableValue = ConstantInt::get(Ty, -1, true);
9540 else
9541 return getCouldNotCompute();
9542
9543 break;
9544 }
9545 case Instruction::LShr:
9546 case Instruction::Shl:
9547 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9548 // stabilize to 0 in at most bitwidth(K) iterations.
9549 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9550 break;
9551 }
9552
9553 auto *Result =
9554 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9555 assert(Result->getType()->isIntegerTy(1) &&
9556 "Otherwise cannot be an operand to a branch instruction");
9557
9558 if (Result->isZeroValue()) {
9559 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9560 const SCEV *UpperBound =
9561        getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9562    return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9563 }
9564
9565 return getCouldNotCompute();
9566}
9567
9568/// Return true if we can constant fold an instruction of the specified type,
9569/// assuming that all operands were constants.
9570static bool CanConstantFold(const Instruction *I) {
9571 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9572 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9573 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9574 return true;
9575
9576 if (const CallInst *CI = dyn_cast<CallInst>(I))
9577 if (const Function *F = CI->getCalledFunction())
9578 return canConstantFoldCallTo(CI, F);
9579 return false;
9580}
9581
9582/// Determine whether this instruction can constant evolve within this loop
9583/// assuming its operands can all constant evolve.
9584static bool canConstantEvolve(Instruction *I, const Loop *L) {
9585 // An instruction outside of the loop can't be derived from a loop PHI.
9586 if (!L->contains(I)) return false;
9587
9588 if (isa<PHINode>(I)) {
9589 // We don't currently keep track of the control flow needed to evaluate
9590 // PHIs, so we cannot handle PHIs inside of loops.
9591 return L->getHeader() == I->getParent();
9592 }
9593
9594 // If we won't be able to constant fold this expression even if the operands
9595 // are constants, bail early.
9596 return CanConstantFold(I);
9597}
9598
9599/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9600/// recursing through each instruction operand until reaching a loop header phi.
9601static PHINode *
9602getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9603                               DenseMap<Instruction *, PHINode *> &PHIMap,
9604                               unsigned Depth) {
9605  if (Depth > MaxConstantEvolvingDepth)
9606    return nullptr;
9607
9608 // Otherwise, we can evaluate this instruction if all of its operands are
9609 // constant or derived from a PHI node themselves.
9610 PHINode *PHI = nullptr;
9611 for (Value *Op : UseInst->operands()) {
9612 if (isa<Constant>(Op)) continue;
9613
9614 Instruction *OpInst = dyn_cast<Instruction>(Op);
9615 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9616
9617 PHINode *P = dyn_cast<PHINode>(OpInst);
9618 if (!P)
9619 // If this operand is already visited, reuse the prior result.
9620 // We may have P != PHI if this is the deepest point at which the
9621 // inconsistent paths meet.
9622 P = PHIMap.lookup(OpInst);
9623 if (!P) {
9624 // Recurse and memoize the results, whether a phi is found or not.
9625 // This recursive call invalidates pointers into PHIMap.
9626 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9627 PHIMap[OpInst] = P;
9628 }
9629 if (!P)
9630 return nullptr; // Not evolving from PHI
9631 if (PHI && PHI != P)
9632 return nullptr; // Evolving from multiple different PHIs.
9633 PHI = P;
9634 }
9635  // This is an expression evolving from a constant PHI!
9636 return PHI;
9637}
9638
9639/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9640/// in the loop that V is derived from. We allow arbitrary operations along the
9641/// way, but the operands of an operation must either be constants or a value
9642/// derived from a constant PHI. If this expression does not fit with these
9643/// constraints, return null.
9644static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9645  Instruction *I = dyn_cast<Instruction>(V);
9646 if (!I || !canConstantEvolve(I, L)) return nullptr;
9647
9648 if (PHINode *PN = dyn_cast<PHINode>(I))
9649 return PN;
9650
9651 // Record non-constant instructions contained by the loop.
9652  DenseMap<Instruction *, PHINode *> PHIMap;
9653  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9654}
9655
9656/// EvaluateExpression - Given an expression that passes the
9657/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9658/// in the loop has the value PHIVal. If we can't fold this expression for some
9659/// reason, return null.
9660static Constant *EvaluateExpression(Value *V, const Loop *L,
9661                                    DenseMap<Instruction *, Constant *> &Vals,
9662                                    const DataLayout &DL,
9663 const TargetLibraryInfo *TLI) {
9664 // Convenient constant check, but redundant for recursive calls.
9665 if (Constant *C = dyn_cast<Constant>(V)) return C;
9666 Instruction *I = dyn_cast<Instruction>(V);
9667 if (!I) return nullptr;
9668
9669 if (Constant *C = Vals.lookup(I)) return C;
9670
9671 // An instruction inside the loop depends on a value outside the loop that we
9672 // weren't given a mapping for, or a value such as a call inside the loop.
9673 if (!canConstantEvolve(I, L)) return nullptr;
9674
9675 // An unmapped PHI can be due to a branch or another loop inside this loop,
9676 // or due to this not being the initial iteration through a loop where we
9677 // couldn't compute the evolution of this particular PHI last time.
9678 if (isa<PHINode>(I)) return nullptr;
9679
9680 std::vector<Constant*> Operands(I->getNumOperands());
9681
9682 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9683 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9684 if (!Operand) {
9685 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9686 if (!Operands[i]) return nullptr;
9687 continue;
9688 }
9689 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9690 Vals[Operand] = C;
9691 if (!C) return nullptr;
9692 Operands[i] = C;
9693 }
9694
9695 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9696 /*AllowNonDeterministic=*/false);
9697}
9698
9699
9700// If every incoming value to PN except the one for BB is a specific Constant,
9701// return that, else return nullptr.
9702static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9703  Constant *IncomingVal = nullptr;
9704
9705 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9706 if (PN->getIncomingBlock(i) == BB)
9707 continue;
9708
9709 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9710 if (!CurrentVal)
9711 return nullptr;
9712
9713 if (IncomingVal != CurrentVal) {
9714 if (IncomingVal)
9715 return nullptr;
9716 IncomingVal = CurrentVal;
9717 }
9718 }
9719
9720 return IncomingVal;
9721}
9722
9723/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9724/// in the header of its containing loop, we know the loop executes a
9725/// constant number of times, and the PHI node is just a recurrence
9726/// involving constants, fold it.
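///
/// For example (illustrative), given a header phi %x = phi [3, preheader],
/// [%x.next, latch] with %x.next = mul %x, %x and a backedge-taken count of
/// 2, the symbolic execution below produces 3, 9, 81 and returns the
/// constant 81 as the exit value.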
9727Constant *
9728ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9729 const APInt &BEs,
9730 const Loop *L) {
9731 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9732 if (!Inserted)
9733 return I->second;
9734
9735  if (BEs.ugt(MaxBruteForceIterations))
9736    return nullptr; // Not going to evaluate it.
9737
9738 Constant *&RetVal = I->second;
9739
9740  DenseMap<Instruction *, Constant *> CurrentIterVals;
9741  BasicBlock *Header = L->getHeader();
9742 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9743
9744 BasicBlock *Latch = L->getLoopLatch();
9745 if (!Latch)
9746 return nullptr;
9747
9748 for (PHINode &PHI : Header->phis()) {
9749 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9750 CurrentIterVals[&PHI] = StartCST;
9751 }
9752 if (!CurrentIterVals.count(PN))
9753 return RetVal = nullptr;
9754
9755 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9756
9757 // Execute the loop symbolically to determine the exit value.
9758 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9759 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9760
9761 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9762 unsigned IterationNum = 0;
9763 const DataLayout &DL = getDataLayout();
9764 for (; ; ++IterationNum) {
9765 if (IterationNum == NumIterations)
9766 return RetVal = CurrentIterVals[PN]; // Got exit value!
9767
9768 // Compute the value of the PHIs for the next iteration.
9769 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9770    DenseMap<Instruction *, Constant *> NextIterVals;
9771    Constant *NextPHI =
9772 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9773 if (!NextPHI)
9774 return nullptr; // Couldn't evaluate!
9775 NextIterVals[PN] = NextPHI;
9776
9777 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9778
9779 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9780 // cease to be able to evaluate one of them or if they stop evolving,
9781 // because that doesn't necessarily prevent us from computing PN.
9782    SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9783    for (const auto &I : CurrentIterVals) {
9784 PHINode *PHI = dyn_cast<PHINode>(I.first);
9785 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9786 PHIsToCompute.emplace_back(PHI, I.second);
9787 }
9788 // We use two distinct loops because EvaluateExpression may invalidate any
9789 // iterators into CurrentIterVals.
9790 for (const auto &I : PHIsToCompute) {
9791 PHINode *PHI = I.first;
9792 Constant *&NextPHI = NextIterVals[PHI];
9793 if (!NextPHI) { // Not already computed.
9794 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9795 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9796 }
9797 if (NextPHI != I.second)
9798 StoppedEvolving = false;
9799 }
9800
9801 // If all entries in CurrentIterVals == NextIterVals then we can stop
9802 // iterating, the loop can't continue to change.
9803 if (StoppedEvolving)
9804 return RetVal = CurrentIterVals[PN];
9805
9806 CurrentIterVals.swap(NextIterVals);
9807 }
9808}
9809
9810const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9811 Value *Cond,
9812 bool ExitWhen) {
9813  PHINode *PN = getConstantEvolvingPHI(Cond, L);
9814  if (!PN) return getCouldNotCompute();
9815
9816 // If the loop is canonicalized, the PHI will have exactly two entries.
9817 // That's the only form we support here.
9818 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9819
9820  DenseMap<Instruction *, Constant *> CurrentIterVals;
9821  BasicBlock *Header = L->getHeader();
9822 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9823
9824 BasicBlock *Latch = L->getLoopLatch();
9825 assert(Latch && "Should follow from NumIncomingValues == 2!");
9826
9827 for (PHINode &PHI : Header->phis()) {
9828 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9829 CurrentIterVals[&PHI] = StartCST;
9830 }
9831 if (!CurrentIterVals.count(PN))
9832 return getCouldNotCompute();
9833
9834  // Okay, we found a PHI node that defines the trip count of this loop. Execute
9835 // the loop symbolically to determine when the condition gets a value of
9836 // "ExitWhen".
9837 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9838 const DataLayout &DL = getDataLayout();
9839 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9840 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9841 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9842
9843 // Couldn't symbolically evaluate.
9844 if (!CondVal) return getCouldNotCompute();
9845
9846 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9847 ++NumBruteForceTripCountsComputed;
9848 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9849 }
9850
9851 // Update all the PHI nodes for the next iteration.
9852    DenseMap<Instruction *, Constant *> NextIterVals;
9853
9854 // Create a list of which PHIs we need to compute. We want to do this before
9855 // calling EvaluateExpression on them because that may invalidate iterators
9856 // into CurrentIterVals.
9857 SmallVector<PHINode *, 8> PHIsToCompute;
9858 for (const auto &I : CurrentIterVals) {
9859 PHINode *PHI = dyn_cast<PHINode>(I.first);
9860 if (!PHI || PHI->getParent() != Header) continue;
9861 PHIsToCompute.push_back(PHI);
9862 }
9863 for (PHINode *PHI : PHIsToCompute) {
9864 Constant *&NextPHI = NextIterVals[PHI];
9865 if (NextPHI) continue; // Already computed!
9866
9867 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9868 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9869 }
9870 CurrentIterVals.swap(NextIterVals);
9871 }
9872
9873 // Too many iterations were needed to evaluate.
9874 return getCouldNotCompute();
9875}
9876
9877const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9878  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9879      ValuesAtScopes[V];
9880 // Check to see if we've folded this expression at this loop before.
9881 for (auto &LS : Values)
9882 if (LS.first == L)
9883 return LS.second ? LS.second : V;
9884
9885 Values.emplace_back(L, nullptr);
9886
9887 // Otherwise compute it.
9888 const SCEV *C = computeSCEVAtScope(V, L);
9889 for (auto &LS : reverse(ValuesAtScopes[V]))
9890 if (LS.first == L) {
9891 LS.second = C;
9892 if (!isa<SCEVConstant>(C))
9893 ValuesAtScopesUsers[C].push_back({L, V});
9894 break;
9895 }
9896 return C;
9897}
9898
9899/// This builds up a Constant using the ConstantExpr interface. That way, we
9900/// will return Constants for objects which aren't represented by a
9901/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9902/// Returns NULL if the SCEV isn't representable as a Constant.
9903static Constant *BuildConstantFromSCEV(const SCEV *V) {
9904  switch (V->getSCEVType()) {
9905 case scCouldNotCompute:
9906 case scAddRecExpr:
9907 case scVScale:
9908 return nullptr;
9909 case scConstant:
9910 return cast<SCEVConstant>(V)->getValue();
9911 case scUnknown:
9912 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9913 case scPtrToInt: {
9914 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9915 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9916 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9917
9918 return nullptr;
9919 }
9920 case scTruncate: {
9921 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9922 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9923 return ConstantExpr::getTrunc(CastOp, ST->getType());
9924 return nullptr;
9925 }
9926 case scAddExpr: {
9927 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9928 Constant *C = nullptr;
9929 for (const SCEV *Op : SA->operands()) {
9930      Constant *OpC = BuildConstantFromSCEV(Op);
9931      if (!OpC)
9932 return nullptr;
9933 if (!C) {
9934 C = OpC;
9935 continue;
9936 }
9937 assert(!C->getType()->isPointerTy() &&
9938 "Can only have one pointer, and it must be last");
9939 if (OpC->getType()->isPointerTy()) {
9940 // The offsets have been converted to bytes. We can add bytes using
9941 // an i8 GEP.
9942        C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9943                                           OpC, C);
9944 } else {
9945 C = ConstantExpr::getAdd(C, OpC);
9946 }
9947 }
9948 return C;
9949 }
9950 case scMulExpr:
9951 case scSignExtend:
9952 case scZeroExtend:
9953 case scUDivExpr:
9954 case scSMaxExpr:
9955 case scUMaxExpr:
9956 case scSMinExpr:
9957 case scUMinExpr:
9958  case scSequentialUMinExpr:
9959    return nullptr;
9960 }
9961 llvm_unreachable("Unknown SCEV kind!");
9962}
9963
9964const SCEV *
9965ScalarEvolution::getWithOperands(const SCEV *S,
9966                                 SmallVectorImpl<const SCEV *> &NewOps) {
9967  switch (S->getSCEVType()) {
9968 case scTruncate:
9969 case scZeroExtend:
9970 case scSignExtend:
9971 case scPtrToInt:
9972 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9973 case scAddRecExpr: {
9974 auto *AddRec = cast<SCEVAddRecExpr>(S);
9975 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9976 }
9977 case scAddExpr:
9978 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9979 case scMulExpr:
9980 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9981 case scUDivExpr:
9982 return getUDivExpr(NewOps[0], NewOps[1]);
9983 case scUMaxExpr:
9984 case scSMaxExpr:
9985 case scUMinExpr:
9986 case scSMinExpr:
9987 return getMinMaxExpr(S->getSCEVType(), NewOps);
9988  case scSequentialUMinExpr:
9989    return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9990 case scConstant:
9991 case scVScale:
9992 case scUnknown:
9993 return S;
9994 case scCouldNotCompute:
9995 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9996 }
9997 llvm_unreachable("Unknown SCEV kind!");
9998}
9999
10000const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10001 switch (V->getSCEVType()) {
10002 case scConstant:
10003 case scVScale:
10004 return V;
10005 case scAddRecExpr: {
10006 // If this is a loop recurrence for a loop that does not contain L, then we
10007 // are dealing with the final value computed by the loop.
10008 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10009 // First, attempt to evaluate each operand.
10010 // Avoid performing the look-up in the common case where the specified
10011 // expression has no loop-variant portions.
10012 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10013 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10014 if (OpAtScope == AddRec->getOperand(i))
10015 continue;
10016
10017 // Okay, at least one of these operands is loop variant but might be
10018 // foldable. Build a new instance of the folded commutative expression.
10019      SmallVector<const SCEV *, 8> NewOps;
10020      NewOps.reserve(AddRec->getNumOperands());
10021 append_range(NewOps, AddRec->operands().take_front(i));
10022 NewOps.push_back(OpAtScope);
10023 for (++i; i != e; ++i)
10024 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10025
10026 const SCEV *FoldedRec = getAddRecExpr(
10027 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10028 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10029 // The addrec may be folded to a nonrecurrence, for example, if the
10030 // induction variable is multiplied by zero after constant folding. Go
10031 // ahead and return the folded value.
10032 if (!AddRec)
10033 return FoldedRec;
10034 break;
10035 }
10036
10037 // If the scope is outside the addrec's loop, evaluate it by using the
10038 // loop exit value of the addrec.
10039 if (!AddRec->getLoop()->contains(L)) {
10040 // To evaluate this recurrence, we need to know how many times the AddRec
10041 // loop iterates. Compute this now.
10042 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10043 if (BackedgeTakenCount == getCouldNotCompute())
10044 return AddRec;
10045
10046 // Then, evaluate the AddRec.
10047 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10048 }
10049
10050 return AddRec;
10051 }
10052 case scTruncate:
10053 case scZeroExtend:
10054 case scSignExtend:
10055 case scPtrToInt:
10056 case scAddExpr:
10057 case scMulExpr:
10058 case scUDivExpr:
10059 case scUMaxExpr:
10060 case scSMaxExpr:
10061 case scUMinExpr:
10062 case scSMinExpr:
10063 case scSequentialUMinExpr: {
10064 ArrayRef<const SCEV *> Ops = V->operands();
10065 // Avoid performing the look-up in the common case where the specified
10066 // expression has no loop-variant portions.
10067 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10068 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10069 if (OpAtScope != Ops[i]) {
10070 // Okay, at least one of these operands is loop variant but might be
10071 // foldable. Build a new instance of the folded commutative expression.
10072        SmallVector<const SCEV *, 8> NewOps;
10073        NewOps.reserve(Ops.size());
10074 append_range(NewOps, Ops.take_front(i));
10075 NewOps.push_back(OpAtScope);
10076
10077 for (++i; i != e; ++i) {
10078 OpAtScope = getSCEVAtScope(Ops[i], L);
10079 NewOps.push_back(OpAtScope);
10080 }
10081
10082 return getWithOperands(V, NewOps);
10083 }
10084 }
10085 // If we got here, all operands are loop invariant.
10086 return V;
10087 }
10088 case scUnknown: {
10089 // If this instruction is evolved from a constant-evolving PHI, compute the
10090 // exit value from the loop without using SCEVs.
10091 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10092 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10093 if (!I)
10094 return V; // This is some other type of SCEVUnknown, just return it.
10095
10096 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10097 const Loop *CurrLoop = this->LI[I->getParent()];
10098 // Looking for loop exit value.
10099 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10100 PN->getParent() == CurrLoop->getHeader()) {
10101 // Okay, there is no closed form solution for the PHI node. Check
10102 // to see if the loop that contains it has a known backedge-taken
10103 // count. If so, we may be able to force computation of the exit
10104 // value.
10105 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10106 // This trivial case can show up in some degenerate cases where
10107 // the incoming IR has not yet been fully simplified.
10108 if (BackedgeTakenCount->isZero()) {
10109 Value *InitValue = nullptr;
10110 bool MultipleInitValues = false;
10111 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10112 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10113 if (!InitValue)
10114 InitValue = PN->getIncomingValue(i);
10115 else if (InitValue != PN->getIncomingValue(i)) {
10116 MultipleInitValues = true;
10117 break;
10118 }
10119 }
10120 }
10121 if (!MultipleInitValues && InitValue)
10122 return getSCEV(InitValue);
10123 }
10124 // Do we have a loop invariant value flowing around the backedge
10125 // for a loop which must execute the backedge?
10126 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10127 isKnownNonZero(BackedgeTakenCount) &&
10128 PN->getNumIncomingValues() == 2) {
10129
10130 unsigned InLoopPred =
10131 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10132 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10133 if (CurrLoop->isLoopInvariant(BackedgeVal))
10134 return getSCEV(BackedgeVal);
10135 }
10136 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10137 // Okay, we know how many times the containing loop executes. If
10138 // this is a constant evolving PHI node, get the final value at
10139 // the specified iteration number.
10140 Constant *RV =
10141 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10142 if (RV)
10143 return getSCEV(RV);
10144 }
10145 }
10146 }
10147
10148 // Okay, this is an expression that we cannot symbolically evaluate
10149 // into a SCEV. Check to see if it's possible to symbolically evaluate
10150 // the arguments into constants, and if so, try to constant propagate the
10151 // result. This is particularly useful for computing loop exit values.
10152 if (!CanConstantFold(I))
10153 return V; // This is some other type of SCEVUnknown, just return it.
10154
10155    SmallVector<Constant *, 4> Operands;
10156    Operands.reserve(I->getNumOperands());
10157 bool MadeImprovement = false;
10158 for (Value *Op : I->operands()) {
10159 if (Constant *C = dyn_cast<Constant>(Op)) {
10160 Operands.push_back(C);
10161 continue;
10162 }
10163
10164 // If any of the operands is non-constant and if they are
10165 // non-integer and non-pointer, don't even try to analyze them
10166 // with scev techniques.
10167 if (!isSCEVable(Op->getType()))
10168 return V;
10169
10170 const SCEV *OrigV = getSCEV(Op);
10171 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10172 MadeImprovement |= OrigV != OpV;
10173
10174      Constant *C = BuildConstantFromSCEV(OpV);
10175      if (!C)
10176 return V;
10177 assert(C->getType() == Op->getType() && "Type mismatch");
10178 Operands.push_back(C);
10179 }
10180
10181 // Check to see if getSCEVAtScope actually made an improvement.
10182 if (!MadeImprovement)
10183 return V; // This is some other type of SCEVUnknown, just return it.
10184
10185 Constant *C = nullptr;
10186 const DataLayout &DL = getDataLayout();
10187 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10188 /*AllowNonDeterministic=*/false);
10189 if (!C)
10190 return V;
10191 return getSCEV(C);
10192 }
10193 case scCouldNotCompute:
10194 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10195 }
10196 llvm_unreachable("Unknown SCEV type!");
10197}
10198
10200 return getSCEVAtScope(getSCEV(V), L);
10201}
10202
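// Note: stripping zext/sext in stripInjectiveFunctions below is sound when
// searching for the iteration at which an expression becomes zero (see
// howFarToZero), because both casts are injective and map zero to zero, so
// (zext X) == 0 and (sext X) == 0 hold exactly when X == 0.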
10203const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10204 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10205 return stripInjectiveFunctions(ZExt->getOperand());
10206 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10207 return stripInjectiveFunctions(SExt->getOperand());
10208 return S;
10209}
10210
10211/// Finds the minimum unsigned root of the following equation:
10212///
10213/// A * X = B (mod N)
10214///
10215/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10216/// A and B isn't important.
10217///
10218/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
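///
/// Worked example (illustrative): with BW = 4 (N = 16), A = 4 and B = 8:
/// D = gcd(4, 16) = 4 and B is divisible by D; A/D = 1, whose multiplicative
/// inverse modulo N/D = 4 is 1, so the minimum unsigned root is
/// (1 * 8 mod 16) / 4 = 2, and indeed 4 * 2 == 8 (mod 16).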
10219static const SCEV *
10220SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10221                             SmallVectorImpl<const SCEVPredicate *> *Predicates,
10222
10223 ScalarEvolution &SE) {
10224 uint32_t BW = A.getBitWidth();
10225 assert(BW == SE.getTypeSizeInBits(B->getType()));
10226 assert(A != 0 && "A must be non-zero.");
10227
10228 // 1. D = gcd(A, N)
10229 //
10230 // The gcd of A and N may have only one prime factor: 2. The number of
10231 // trailing zeros in A is its multiplicity
10232 uint32_t Mult2 = A.countr_zero();
10233 // D = 2^Mult2
10234
10235 // 2. Check if B is divisible by D.
10236 //
10237 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10238 // is not less than multiplicity of this prime factor for D.
10239 if (SE.getMinTrailingZeros(B) < Mult2) {
10240 // Check if we can prove there's no remainder using URem.
10241 const SCEV *URem =
10242 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10243 const SCEV *Zero = SE.getZero(B->getType());
10244 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10245 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10246 if (!Predicates)
10247 return SE.getCouldNotCompute();
10248
10249 // Avoid adding a predicate that is known to be false.
10250 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10251 return SE.getCouldNotCompute();
10252 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10253 }
10254 }
10255
10256 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10257 // modulo (N / D).
10258 //
10259 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10260 // (N / D) in general. The inverse itself always fits into BW bits, though,
10261 // so we immediately truncate it.
10262 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10263 APInt I = AD.multiplicativeInverse().zext(BW);
10264
10265 // 4. Compute the minimum unsigned root of the equation:
10266 // I * (B / D) mod (N / D)
10267 // To simplify the computation, we factor out the divide by D:
10268 // (I * B mod N) / D
10269 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10270 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10271}
10272
10273/// For a given quadratic addrec, generate coefficients of the corresponding
10274/// quadratic equation, multiplied by a common value to ensure that they are
10275/// integers.
10276/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10277/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10278/// were multiplied by, and BitWidth is the bit width of the original addrec
10279/// coefficients.
10280/// This function returns std::nullopt if the addrec coefficients are not
10281/// compile-time constants.
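///
/// For example (illustrative), for the addrec {1,+,2,+,4} the accumulated
/// value after n iterations is 1 + 2n + 4*n(n-1)/2, so the returned tuple is
/// A = 4, B = 2*2 - 4 = 0, C = 2*1 = 2, M = 2 (the common multiplier), and
/// BitWidth is the width of the original coefficients.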
10282static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10283GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10284  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10285 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10286 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10287 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10288 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10289 << *AddRec << '\n');
10290
10291 // We currently can only solve this if the coefficients are constants.
10292 if (!LC || !MC || !NC) {
10293 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10294 return std::nullopt;
10295 }
10296
10297 APInt L = LC->getAPInt();
10298 APInt M = MC->getAPInt();
10299 APInt N = NC->getAPInt();
10300 assert(!N.isZero() && "This is not a quadratic addrec");
10301
10302 unsigned BitWidth = LC->getAPInt().getBitWidth();
10303 unsigned NewWidth = BitWidth + 1;
10304 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10305 << BitWidth << '\n');
10306 // The sign-extension (as opposed to a zero-extension) here matches the
10307 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10308 N = N.sext(NewWidth);
10309 M = M.sext(NewWidth);
10310 L = L.sext(NewWidth);
10311
10312 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10313 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10314 // L+M, L+2M+N, L+3M+3N, ...
10315 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10316 //
10317 // The equation Acc = 0 is then
10318 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10319 // In a quadratic form it becomes:
10320 // N n^2 + (2M-N) n + 2L = 0.
10321
10322 APInt A = N;
10323 APInt B = 2 * M - A;
10324 APInt C = 2 * L;
10325 APInt T = APInt(NewWidth, 2);
10326 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10327 << "x + " << C << ", coeff bw: " << NewWidth
10328 << ", multiplied by " << T << '\n');
10329 return std::make_tuple(A, B, C, T, BitWidth);
10330}
10331
10332/// Helper function to compare optional APInts:
10333/// (a) if X and Y both exist, return min(X, Y),
10334/// (b) if neither X nor Y exist, return std::nullopt,
10335/// (c) if exactly one of X and Y exists, return that value.
10336static std::optional<APInt> MinOptional(std::optional<APInt> X,
10337 std::optional<APInt> Y) {
10338 if (X && Y) {
10339 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10340 APInt XW = X->sext(W);
10341 APInt YW = Y->sext(W);
10342 return XW.slt(YW) ? *X : *Y;
10343 }
10344 if (!X && !Y)
10345 return std::nullopt;
10346 return X ? *X : *Y;
10347}
10348
10349/// Helper function to truncate an optional APInt to a given BitWidth.
10350/// When solving addrec-related equations, it is preferable to return a value
10351/// that has the same bit width as the original addrec's coefficients. If the
10352/// solution fits in the original bit width, truncate it (except for i1).
10353/// Returning a value of a different bit width may inhibit some optimizations.
10354///
10355/// In general, a solution to a quadratic equation generated from an addrec
10356/// may require BW+1 bits, where BW is the bit width of the addrec's
10357/// coefficients. The reason is that the coefficients of the quadratic
10358/// equation are BW+1 bits wide (to avoid truncation when converting from
10359/// the addrec to the equation).
10360static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10361 unsigned BitWidth) {
10362 if (!X)
10363 return std::nullopt;
10364 unsigned W = X->getBitWidth();
10365 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10366 return X->trunc(BitWidth);
10367 return X;
10368}
10369
10370/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10371/// iterations. The values L, M, N are assumed to be signed, and they
10372/// should all have the same bit widths.
10373/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10374/// where BW is the bit width of the addrec's coefficients.
10375/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10376/// returned as such, otherwise the bit width of the returned value may
10377/// be greater than BW.
10378///
10379/// This function returns std::nullopt if
10380/// (a) the addrec coefficients are not constant, or
10381/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10382/// like x^2 = 5, no integer solutions exist, in other cases an integer
10383/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10384static std::optional<APInt>
10385SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10386  APInt A, B, C, M;
10387 unsigned BitWidth;
10388 auto T = GetQuadraticEquation(AddRec);
10389 if (!T)
10390 return std::nullopt;
10391
10392 std::tie(A, B, C, M, BitWidth) = *T;
10393 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10394 std::optional<APInt> X =
10395      APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10396  if (!X)
10397 return std::nullopt;
10398
10399 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10400 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10401 if (!V->isZero())
10402 return std::nullopt;
10403
10404 return TruncIfPossible(X, BitWidth);
10405}
10406
10407/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10408/// iterations. The values M, N are assumed to be signed, and they
10409/// should all have the same bit widths.
10410/// Find the least n such that c(n) does not belong to the given range,
10411/// while c(n-1) does.
10412///
10413/// This function returns std::nullopt if
10414/// (a) the addrec coefficients are not constant, or
10415/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10416/// bounds of the range.
10417static std::optional<APInt>
10418SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10419                          const ConstantRange &Range, ScalarEvolution &SE) {
10420 assert(AddRec->getOperand(0)->isZero() &&
10421 "Starting value of addrec should be 0");
10422 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10423 << Range << ", addrec " << *AddRec << '\n');
10424 // This case is handled in getNumIterationsInRange. Here we can assume that
10425 // we start in the range.
10426 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10427 "Addrec's initial value should be in range");
10428
10429 APInt A, B, C, M;
10430 unsigned BitWidth;
10431 auto T = GetQuadraticEquation(AddRec);
10432 if (!T)
10433 return std::nullopt;
10434
10435 // Be careful about the return value: there can be two reasons for not
10436 // returning an actual number. First, if no solutions to the equations
10437 // were found, and second, if the solutions don't leave the given range.
10438 // The first case means that the actual solution is "unknown", the second
10439 // means that it's known, but not valid. If the solution is unknown, we
10440 // cannot make any conclusions.
10441 // Return a pair: the optional solution and a flag indicating if the
10442 // solution was found.
10443 auto SolveForBoundary =
10444 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10445 // Solve for signed overflow and unsigned overflow, pick the lower
10446 // solution.
10447 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10448 << Bound << " (before multiplying by " << M << ")\n");
10449 Bound *= M; // The quadratic equation multiplier.
10450
10451 std::optional<APInt> SO;
10452 if (BitWidth > 1) {
10453 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10454 "signed overflow\n");
10455      SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10456    }
10457 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10458 "unsigned overflow\n");
10459    std::optional<APInt> UO =
10460        APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10461
10462 auto LeavesRange = [&] (const APInt &X) {
10463 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10464 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10465 if (Range.contains(V0->getValue()))
10466 return false;
10467 // X should be at least 1, so X-1 is non-negative.
10468 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10469 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10470 if (Range.contains(V1->getValue()))
10471 return true;
10472 return false;
10473 };
10474
10475 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10476 // can be a solution, but the function failed to find it. We cannot treat it
10477 // as "no solution".
10478 if (!SO || !UO)
10479 return {std::nullopt, false};
10480
10481 // Check the smaller value first to see if it leaves the range.
10482 // At this point, both SO and UO must have values.
10483 std::optional<APInt> Min = MinOptional(SO, UO);
10484 if (LeavesRange(*Min))
10485 return { Min, true };
10486 std::optional<APInt> Max = Min == SO ? UO : SO;
10487 if (LeavesRange(*Max))
10488 return { Max, true };
10489
10490 // Solutions were found, but were eliminated, hence the "true".
10491 return {std::nullopt, true};
10492 };
10493
10494 std::tie(A, B, C, M, BitWidth) = *T;
10495 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10496 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10497 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10498 auto SL = SolveForBoundary(Lower);
10499 auto SU = SolveForBoundary(Upper);
10500  // If any of the solutions was unknown, no meaningful conclusions can
10501 // be made.
10502 if (!SL.second || !SU.second)
10503 return std::nullopt;
10504
10505 // Claim: The correct solution is not some value between Min and Max.
10506 //
10507 // Justification: Assuming that Min and Max are different values, one of
10508 // them is when the first signed overflow happens, the other is when the
10509 // first unsigned overflow happens. Crossing the range boundary is only
10510 // possible via an overflow (treating 0 as a special case of it, modeling
10511 // an overflow as crossing k*2^W for some k).
10512 //
10513 // The interesting case here is when Min was eliminated as an invalid
10514 // solution, but Max was not. The argument is that if there was another
10515 // overflow between Min and Max, it would also have been eliminated if
10516 // it was considered.
10517 //
10518 // For a given boundary, it is possible to have two overflows of the same
10519 // type (signed/unsigned) without having the other type in between: this
10520 // can happen when the vertex of the parabola is between the iterations
10521 // corresponding to the overflows. This is only possible when the two
10522 // overflows cross k*2^W for the same k. In such case, if the second one
10523 // left the range (and was the first one to do so), the first overflow
10524 // would have to enter the range, which would mean that either we had left
10525 // the range before or that we started outside of it. Both of these cases
10526 // are contradictions.
10527 //
10528 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10529 // solution is not some value between the Max for this boundary and the
10530 // Min of the other boundary.
10531 //
10532 // Justification: Assume that we had such Max_A and Min_B corresponding
10533 // to range boundaries A and B and such that Max_A < Min_B. If there was
10534 // a solution between Max_A and Min_B, it would have to be caused by an
10535 // overflow corresponding to either A or B. It cannot correspond to B,
10536 // since Min_B is the first occurrence of such an overflow. If it
10537 // corresponded to A, it would have to be either a signed or an unsigned
10538 // overflow that is larger than both eliminated overflows for A. But
10539 // between the eliminated overflows and this overflow, the values would
10540 // cover the entire value space, thus crossing the other boundary, which
10541 // is a contradiction.
10542
10543 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10544}
10545
10546ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10547 const Loop *L,
10548 bool ControlsOnlyExit,
10549 bool AllowPredicates) {
10550
10551 // This is only used for loops with a "x != y" exit test. The exit condition
10552 // is now expressed as a single expression, V = x-y. So the exit test is
10553 // effectively V != 0. We know and take advantage of the fact that this
10554  // expression is only used in a comparison by zero context.
10555  SmallVector<const SCEVPredicate *> Predicates;
10556
10557 // If the value is a constant
10558 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10559 // If the value is already zero, the branch will execute zero times.
10560 if (C->getValue()->isZero()) return C;
10561 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10562 }
10563
10564 const SCEVAddRecExpr *AddRec =
10565 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10566
10567 if (!AddRec && AllowPredicates)
10568 // Try to make this an AddRec using runtime tests, in the first X
10569 // iterations of this loop, where X is the SCEV expression found by the
10570 // algorithm below.
10571 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10572
10573 if (!AddRec || AddRec->getLoop() != L)
10574 return getCouldNotCompute();
10575
10576 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10577 // the quadratic equation to solve it.
10578 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10579 // We can only use this value if the chrec ends up with an exact zero
10580 // value at this index. When solving for "X*X != 5", for example, we
10581 // should not accept a root of 2.
10582 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10583 const auto *R = cast<SCEVConstant>(getConstant(*S));
10584 return ExitLimit(R, R, R, false, Predicates);
10585 }
10586 return getCouldNotCompute();
10587 }
10588
10589 // Otherwise we can only handle this if it is affine.
10590 if (!AddRec->isAffine())
10591 return getCouldNotCompute();
10592
10593 // If this is an affine expression, the execution count of this branch is
10594 // the minimum unsigned root of the following equation:
10595 //
10596 // Start + Step*N = 0 (mod 2^BW)
10597 //
10598 // equivalent to:
10599 //
10600 // Step*N = -Start (mod 2^BW)
10601 //
10602 // where BW is the common bit width of Start and Step.
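  // For instance (illustrative), with i8 operands, Start = 4 and Step = 2,
  // the equation 4 + 2*N == 0 (mod 256) yields N = 126: the value only
  // returns to zero after wrapping around 2^8.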
10603
10604 // Get the initial value for the loop.
10605 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10606 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10607
10608 if (!isLoopInvariant(Step, L))
10609 return getCouldNotCompute();
10610
10611 LoopGuards Guards = LoopGuards::collect(L, *this);
10612 // Specialize step for this loop so we get context sensitive facts below.
10613 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10614
10615 // For positive steps (counting up until unsigned overflow):
10616 // N = -Start/Step (as unsigned)
10617 // For negative steps (counting down to zero):
10618 // N = Start/-Step
10619 // First compute the unsigned distance from zero in the direction of Step.
10620 bool CountDown = isKnownNegative(StepWLG);
10621 if (!CountDown && !isKnownNonNegative(StepWLG))
10622 return getCouldNotCompute();
10623
10624 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10625 // Handle unitary steps, which cannot wraparound.
10626 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10627 // N = Distance (as unsigned)
10628
10629 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10630 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10631 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10632
10633 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10634 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10635 // case, and see if we can improve the bound.
10636 //
10637 // Explicitly handling this here is necessary because getUnsignedRange
10638 // isn't context-sensitive; it doesn't know that we only care about the
10639 // range inside the loop.
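    // E.g. (illustrative) when the distance is n - 1 for an i8 value n and
    // loop entry is guarded by "n != 0", the maximum backedge-taken count is
    // tightened from 255 to umax(n) - 1 = 254.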
10640 const SCEV *Zero = getZero(Distance->getType());
10641 const SCEV *One = getOne(Distance->getType());
10642 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10643 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10644 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10645 // as "unsigned_max(Distance + 1) - 1".
10646 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10647 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10648 }
10649 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10650 Predicates);
10651 }
10652
10653 // If the condition controls loop exit (the loop exits only if the expression
10654 // is true) and the addition is no-wrap we can use unsigned divide to
10655 // compute the backedge count. In this case, the step may not divide the
10656 // distance, but we don't care because if the condition is "missed" the loop
10657 // will have undefined behavior due to wrapping.
10658 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10659 loopHasNoAbnormalExits(AddRec->getLoop())) {
10660
10661 // If the stride is zero and the start is non-zero, the loop must be
10662 // infinite. In C++, most loops are finite by assumption, in which case the
10663 // step being zero implies UB must execute if the loop is entered.
10664 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10665 !isKnownNonZero(StepWLG))
10666 return getCouldNotCompute();
10667
10668 const SCEV *Exact =
10669 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10670 const SCEV *ConstantMax = getCouldNotCompute();
10671 if (Exact != getCouldNotCompute()) {
10672      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10673      ConstantMax =
10674          getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10675    }
10676 const SCEV *SymbolicMax =
10677 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10678 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10679 }
10680
10681 // Solve the general equation.
10682 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10683 if (!StepC || StepC->getValue()->isZero())
10684 return getCouldNotCompute();
10685  const SCEV *E = SolveLinEquationWithOverflow(
10686      StepC->getAPInt(), getNegativeSCEV(Start),
10687 AllowPredicates ? &Predicates : nullptr, *this);
10688
10689 const SCEV *M = E;
10690 if (E != getCouldNotCompute()) {
10691 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10692 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10693 }
10694 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10695 return ExitLimit(E, M, S, false, Predicates);
10696}
10697
10698ScalarEvolution::ExitLimit
10699ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10700 // Loops that look like: while (X == 0) are very strange indeed. We don't
10701 // handle them yet except for the trivial case. This could be expanded in the
10702 // future as needed.
10703
10704 // If the value is a constant, check to see if it is known to be non-zero
10705 // already. If so, the backedge will execute zero times.
10706 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10707 if (!C->getValue()->isZero())
10708 return getZero(C->getType());
10709 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10710 }
10711
10712 // We could implement others, but I really doubt anyone writes loops like
10713 // this, and if they did, they would already be constant folded.
10714 return getCouldNotCompute();
10715}
10716
10717std::pair<const BasicBlock *, const BasicBlock *>
10718ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10719 const {
10720 // If the block has a unique predecessor, then there is no path from the
10721 // predecessor to the block that does not go through the direct edge
10722 // from the predecessor to the block.
10723 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10724 return {Pred, BB};
10725
10726 // A loop's header is defined to be a block that dominates the loop.
10727 // If the header has a unique predecessor outside the loop, it must be
10728 // a block that has exactly one successor that can reach the loop.
10729 if (const Loop *L = LI.getLoopFor(BB))
10730 return {L->getLoopPredecessor(), L->getHeader()};
10731
10732 return {nullptr, BB};
10733}
10734
10735/// SCEV structural equivalence is usually sufficient for testing whether two
10736/// expressions are equal, however for the purposes of looking for a condition
10737/// guarding a loop, it can be useful to be a little more general, since a
10738/// front-end may have replicated the controlling expression.
10739static bool HasSameValue(const SCEV *A, const SCEV *B) {
10740 // Quick check to see if they are the same SCEV.
10741 if (A == B) return true;
10742
10743 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10744 // Not all instructions that are "identical" compute the same value. For
10745 // instance, two distinct alloca instructions allocating the same type are
10746 // identical and do not read memory; but compute distinct values.
10747 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10748 };
10749
10750 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10751 // two different instructions with the same value. Check for this case.
10752 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10753 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10754 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10755 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10756 if (ComputesEqualValues(AI, BI))
10757 return true;
10758
10759 // Otherwise assume they may have a different value.
10760 return false;
10761}
10762
10763static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10764 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10765 if (!Add || Add->getNumOperands() != 2)
10766 return false;
10767 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10768 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10769 LHS = Add->getOperand(1);
10770 RHS = ME->getOperand(1);
10771 return true;
10772 }
10773 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10774 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10775 LHS = Add->getOperand(0);
10776 RHS = ME->getOperand(1);
10777 return true;
10778 }
10779 return false;
10780}
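// For example (illustrative), the SCEV ((-1 * %a) + %b), i.e. %b - %a, is
// matched with LHS = %b and RHS = %a.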
10781
10782bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10783                                           const SCEV *&RHS, unsigned Depth) {
10784 bool Changed = false;
10785 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10786 // '0 != 0'.
10787 auto TrivialCase = [&](bool TriviallyTrue) {
10788    LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10789    Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10790 return true;
10791 };
10792 // If we hit the max recursion limit bail out.
10793 if (Depth >= 3)
10794 return false;
10795
10796 // Canonicalize a constant to the right side.
10797 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10798 // Check for both operands constant.
10799 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10800 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10801 return TrivialCase(false);
10802 return TrivialCase(true);
10803 }
10804 // Otherwise swap the operands to put the constant on the right.
10805 std::swap(LHS, RHS);
10806    Pred = ICmpInst::getSwappedPredicate(Pred);
10807    Changed = true;
10808 }
10809
10810 // If we're comparing an addrec with a value which is loop-invariant in the
10811 // addrec's loop, put the addrec on the left. Also make a dominance check,
10812 // as both operands could be addrecs loop-invariant in each other's loop.
10813 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10814 const Loop *L = AR->getLoop();
10815 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10816 std::swap(LHS, RHS);
10817      Pred = ICmpInst::getSwappedPredicate(Pred);
10818      Changed = true;
10819 }
10820 }
10821
10822 // If there's a constant operand, canonicalize comparisons with boundary
10823 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
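  // For example (illustrative): (X u>= 5) becomes (X u> 4) and (X s<= 7)
  // becomes (X s< 8), matching the UGE/ULE/SGE/SLE cases below.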
10824 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10825 const APInt &RA = RC->getAPInt();
10826
10827 bool SimplifiedByConstantRange = false;
10828
10829 if (!ICmpInst::isEquality(Pred)) {
10830      ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10831      if (ExactCR.isFullSet())
10832 return TrivialCase(true);
10833 if (ExactCR.isEmptySet())
10834 return TrivialCase(false);
10835
10836 APInt NewRHS;
10837 CmpInst::Predicate NewPred;
10838 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10839 ICmpInst::isEquality(NewPred)) {
10840 // We were able to convert an inequality to an equality.
10841 Pred = NewPred;
10842 RHS = getConstant(NewRHS);
10843 Changed = SimplifiedByConstantRange = true;
10844 }
10845 }
10846
10847 if (!SimplifiedByConstantRange) {
10848 switch (Pred) {
10849 default:
10850 break;
10851 case ICmpInst::ICMP_EQ:
10852 case ICmpInst::ICMP_NE:
10853 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10854 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10855 Changed = true;
10856 break;
10857
10858 // The "Should have been caught earlier!" messages refer to the fact
10859 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10860 // should have fired on the corresponding cases, and canonicalized the
10861      // check to a trivial case.
10862
10863 case ICmpInst::ICMP_UGE:
10864 assert(!RA.isMinValue() && "Should have been caught earlier!");
10865 Pred = ICmpInst::ICMP_UGT;
10866 RHS = getConstant(RA - 1);
10867 Changed = true;
10868 break;
10869 case ICmpInst::ICMP_ULE:
10870 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10871 Pred = ICmpInst::ICMP_ULT;
10872 RHS = getConstant(RA + 1);
10873 Changed = true;
10874 break;
10875 case ICmpInst::ICMP_SGE:
10876 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10877 Pred = ICmpInst::ICMP_SGT;
10878 RHS = getConstant(RA - 1);
10879 Changed = true;
10880 break;
10881 case ICmpInst::ICMP_SLE:
10882 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10883 Pred = ICmpInst::ICMP_SLT;
10884 RHS = getConstant(RA + 1);
10885 Changed = true;
10886 break;
10887 }
10888 }
10889 }
10890
10891 // Check for obvious equality.
10892 if (HasSameValue(LHS, RHS)) {
10893 if (ICmpInst::isTrueWhenEqual(Pred))
10894 return TrivialCase(true);
10895 if (ICmpInst::isFalseWhenEqual(Pred))
10896 return TrivialCase(false);
10897 }
10898
10899 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10900 // adding or subtracting 1 from one of the operands.
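// For example, "X s<= Y" becomes "X s< (Y + 1)" when Y + 1 cannot overflow,
// or "(X - 1) s< Y" when X - 1 cannot underflow.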
10901 switch (Pred) {
10902 case ICmpInst::ICMP_SLE:
10903 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10904 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10905 SCEV::FlagNSW);
10906 Pred = ICmpInst::ICMP_SLT;
10907 Changed = true;
10908 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10909 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10910 SCEV::FlagNSW);
10911 Pred = ICmpInst::ICMP_SLT;
10912 Changed = true;
10913 }
10914 break;
10915 case ICmpInst::ICMP_SGE:
10916 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10917 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10918 SCEV::FlagNSW);
10919 Pred = ICmpInst::ICMP_SGT;
10920 Changed = true;
10921 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10922 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10923 SCEV::FlagNSW);
10924 Pred = ICmpInst::ICMP_SGT;
10925 Changed = true;
10926 }
10927 break;
10928 case ICmpInst::ICMP_ULE:
10929 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10930 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10931 SCEV::FlagNUW);
10932 Pred = ICmpInst::ICMP_ULT;
10933 Changed = true;
10934 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10935 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10936 Pred = ICmpInst::ICMP_ULT;
10937 Changed = true;
10938 }
10939 break;
10940 case ICmpInst::ICMP_UGE:
10941 // If RHS is an op we can fold the -1, try that first.
10942 // Otherwise prefer LHS to preserve the nuw flag.
10943 if ((isa<SCEVConstant>(RHS) ||
10944 (isa<SCEVAddExpr, SCEVAddRecExpr>(RHS) &&
10945 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10946 !getUnsignedRangeMin(RHS).isMinValue()) {
10947 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10948 Pred = ICmpInst::ICMP_UGT;
10949 Changed = true;
10950 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10951 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10952 SCEV::FlagNUW);
10953 Pred = ICmpInst::ICMP_UGT;
10954 Changed = true;
10955 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
10956 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10957 Pred = ICmpInst::ICMP_UGT;
10958 Changed = true;
10959 }
10960 break;
10961 default:
10962 break;
10963 }
10964
10965 // TODO: More simplifications are possible here.
10966
10967 // Recursively simplify until we either hit a recursion limit or nothing
10968 // changes.
10969 if (Changed)
10970 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10971
10972 return Changed;
10973}
10974
10975 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10976 return getSignedRangeMax(S).isNegative();
10977}
10978
10981}
10982
10983 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10984 return !getSignedRangeMin(S).isNegative();
10985}
10986
10989}
10990
10991 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10992 // Query push down for cases where the unsigned range is
10993 // less than sufficient.
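// For example, a sign-extended value may have a conservatively computed
// unsigned range that wraps around zero even when its narrow operand is
// known to be non-zero, so recurse into the operand instead.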
10994 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10995 return isKnownNonZero(SExt->getOperand(0));
10996 return getUnsignedRangeMin(S) != 0;
10997}
10998
10999 bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11000 bool OrNegative) {
11001 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11002 if (auto *C = dyn_cast<SCEVConstant>(S))
11003 return C->getAPInt().isPowerOf2() ||
11004 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11005
11006 // The vscale_range indicates vscale is a power-of-two.
11007 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11008 };
11009
11010 if (NonRecursive(S))
11011 return true;
11012
11013 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11014 if (!Mul)
11015 return false;
11016 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11017}
11018
11019 bool ScalarEvolution::isKnownMultipleOf(
11020 const SCEV *S, uint64_t M,
11021 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11022 if (M == 0)
11023 return false;
11024 if (M == 1)
11025 return true;
11026
11027 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11028 // starts with a multiple of M and at every iteration step S only adds
11029 // multiples of M.
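// For example, {8,+,4} is a multiple of 4: it starts at 8 and every
// iteration adds 4, so its values 8, 12, 16, ... are all multiples of 4.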
11030 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11031 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11032 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11033
11034 // For a constant, check that "S % M == 0".
11035 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11036 APInt C = Cst->getAPInt();
11037 return C.urem(M) == 0;
11038 }
11039
11040 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11041
11042 // Basic tests have failed.
11043 // Check "S % M == 0" at compile time and record runtime Assumptions.
11044 auto *STy = dyn_cast<IntegerType>(S->getType());
11045 const SCEV *SmodM =
11046 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11047 const SCEV *Zero = getZero(STy);
11048
11049 // Check whether "S % M == 0" is known at compile time.
11050 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11051 return true;
11052
11053 // Check whether "S % M != 0" is known at compile time.
11054 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11055 return false;
11056
11057 const SCEVPredicate *P = getComparePredicate(ICmpInst::ICMP_EQ, SmodM, Zero);
11058
11059 // Detect redundant predicates.
11060 for (auto *A : Assumptions)
11061 if (A->implies(P, *this))
11062 return true;
11063
11064 // Only record non-redundant predicates.
11065 Assumptions.push_back(P);
11066 return true;
11067}
11068
11069std::pair<const SCEV *, const SCEV *>
11070 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11071 // Compute SCEV on entry of loop L.
11072 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11073 if (Start == getCouldNotCompute())
11074 return { Start, Start };
11075 // Compute post increment SCEV for loop L.
11076 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11077 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11078 return { Start, PostInc };
11079}
11080
11081 bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11082 const SCEV *RHS) {
11083 // First collect all loops.
11084 SmallPtrSet<const Loop *, 8> LoopsUsed;
11085 getUsedLoops(LHS, LoopsUsed);
11086 getUsedLoops(RHS, LoopsUsed);
11087
11088 if (LoopsUsed.empty())
11089 return false;
11090
11091 // Domination relationship must be a linear order on collected loops.
11092#ifndef NDEBUG
11093 for (const auto *L1 : LoopsUsed)
11094 for (const auto *L2 : LoopsUsed)
11095 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11096 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11097 "Domination relationship is not a linear order");
11098#endif
11099
11100 const Loop *MDL =
11101 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11102 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11103 });
11104
11105 // Get init and post increment value for LHS.
11106 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11107 // If LHS contains an unknown non-invariant SCEV, bail out.
11108 if (SplitLHS.first == getCouldNotCompute())
11109 return false;
11110 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11111 // Get init and post increment value for RHS.
11112 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11113 // If RHS contains an unknown non-invariant SCEV, bail out.
11114 if (SplitRHS.first == getCouldNotCompute())
11115 return false;
11116 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11117 // It is possible that init SCEV contains an invariant load but it does
11118 // not dominate MDL and is not available at MDL loop entry, so we should
11119 // check it here.
11120 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11121 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11122 return false;
11123
11124 // The backedge guard check seems to be faster than the entry one, so in some
11125 // cases it can speed up the whole estimation by short-circuiting.
11126 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11127 SplitRHS.second) &&
11128 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11129}
11130
11131 bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11132 const SCEV *RHS) {
11133 // Canonicalize the inputs first.
11134 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11135
11136 if (isKnownViaInduction(Pred, LHS, RHS))
11137 return true;
11138
11139 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11140 return true;
11141
11142 // Otherwise see what can be done with some simple reasoning.
11143 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11144}
11145
11146 std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11147 const SCEV *LHS,
11148 const SCEV *RHS) {
11149 if (isKnownPredicate(Pred, LHS, RHS))
11150 return true;
11151 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11152 return false;
11153 return std::nullopt;
11154}
11155
11156 bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11157 const SCEV *RHS,
11158 const Instruction *CtxI) {
11159 // TODO: Analyze guards and assumes from Context's block.
11160 return isKnownPredicate(Pred, LHS, RHS) ||
11161 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11162 }
11163
11164std::optional<bool>
11165 ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11166 const SCEV *RHS, const Instruction *CtxI) {
11167 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11168 if (KnownWithoutContext)
11169 return KnownWithoutContext;
11170
11171 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11172 return true;
11173 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
11174 ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11175 return false;
11176 return std::nullopt;
11177}
11178
11179 bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11180 const SCEVAddRecExpr *LHS,
11181 const SCEV *RHS) {
11182 const Loop *L = LHS->getLoop();
11183 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11184 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11185}
11186
11187std::optional<ScalarEvolution::MonotonicPredicateType>
11188 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11189 ICmpInst::Predicate Pred) {
11190 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11191
11192#ifndef NDEBUG
11193 // Verify an invariant: inverting the predicate should turn a monotonically
11194 // increasing change to a monotonically decreasing one, and vice versa.
11195 if (Result) {
11196 auto ResultSwapped =
11197 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11198
11199 assert(*ResultSwapped != *Result &&
11200 "monotonicity should flip as we flip the predicate");
11201 }
11202#endif
11203
11204 return Result;
11205}
11206
11207std::optional<ScalarEvolution::MonotonicPredicateType>
11208ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11209 ICmpInst::Predicate Pred) {
11210 // A zero step value for LHS means the induction variable is essentially a
11211 // loop invariant value. We don't really depend on the predicate actually
11212 // flipping from false to true (for increasing predicates, and the other way
11213 // around for decreasing predicates), all we care about is that *if* the
11214 // predicate changes then it only changes from false to true.
11215 //
11216 // A zero step value in itself is not very useful, but there may be places
11217 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11218 // as general as possible.
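// For example, {X,+,0}<nsw> compared against a loop-invariant bound yields
// the same result on every iteration, so it trivially satisfies the "only
// ever flips from false to true" requirement described above.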
11219
11220 // Only handle LE/LT/GE/GT predicates.
11221 if (!ICmpInst::isRelational(Pred))
11222 return std::nullopt;
11223
11224 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11225 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11226 "Should be greater or less!");
11227
11228 // Check that AR does not wrap.
11229 if (ICmpInst::isUnsigned(Pred)) {
11230 if (!LHS->hasNoUnsignedWrap())
11231 return std::nullopt;
11232 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11233 }
11234 assert(ICmpInst::isSigned(Pred) &&
11235 "Relational predicate is either signed or unsigned!");
11236 if (!LHS->hasNoSignedWrap())
11237 return std::nullopt;
11238
11239 const SCEV *Step = LHS->getStepRecurrence(*this);
11240
11241 if (isKnownNonNegative(Step))
11242 return IsGreater ? ScalarEvolution::MonotonicallyIncreasing
11243 : ScalarEvolution::MonotonicallyDecreasing;
11244 if (isKnownNonPositive(Step))
11245 return !IsGreater ? ScalarEvolution::MonotonicallyIncreasing
11246 : ScalarEvolution::MonotonicallyDecreasing;
11247 return std::nullopt;
11248}
11249
11250std::optional<ScalarEvolution::LoopInvariantPredicate>
11251 ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11252 const SCEV *RHS, const Loop *L,
11253 const Instruction *CtxI) {
11254 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11255 if (!isLoopInvariant(RHS, L)) {
11256 if (!isLoopInvariant(LHS, L))
11257 return std::nullopt;
11258
11259 std::swap(LHS, RHS);
11260 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11261 }
11262
11263 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11264 if (!ArLHS || ArLHS->getLoop() != L)
11265 return std::nullopt;
11266
11267 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11268 if (!MonotonicType)
11269 return std::nullopt;
11270 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11271 // true as the loop iterates, and the backedge is control dependent on
11272 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11273 //
11274 // * if the predicate was false in the first iteration then the predicate
11275 // is never evaluated again, since the loop exits without taking the
11276 // backedge.
11277 // * if the predicate was true in the first iteration then it will
11278 // continue to be true for all future iterations since it is
11279 // monotonically increasing.
11280 //
11281 // For both the above possibilities, we can replace the loop varying
11282 // predicate with its value on the first iteration of the loop (which is
11283 // loop invariant).
11284 //
11285 // A similar reasoning applies for a monotonically decreasing predicate, by
11286 // replacing true with false and false with true in the above two bullets.
11287 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11288 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11289
11290 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11291 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11292 RHS);
11293
11294 if (!CtxI)
11295 return std::nullopt;
11296 // Try to prove via context.
11297 // TODO: Support other cases.
11298 switch (Pred) {
11299 default:
11300 break;
11301 case ICmpInst::ICMP_ULE:
11302 case ICmpInst::ICMP_ULT: {
11303 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11304 // Given preconditions
11305 // (1) ArLHS does not cross the border of positive and negative parts of
11306 // range because of:
11307 // - Positive step; (TODO: lift this limitation)
11308 // - nuw - does not cross zero boundary;
11309 // - nsw - does not cross SINT_MAX boundary;
11310 // (2) ArLHS <s RHS
11311 // (3) RHS >=s 0
11312 // we can replace the loop variant ArLHS <u RHS condition with loop
11313 // invariant Start(ArLHS) <u RHS.
11314 //
11315 // Because of (1) there are two options:
11316 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11317 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11318 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11319 // Because of (2) ArLHS <u RHS is trivially true.
11320 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11321 // We can strengthen this to Start(ArLHS) <u RHS.
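// For example: with ArLHS = {2,+,1}<nuw><nsw> and RHS = 40 (both i8), ArLHS
// stays non-negative, so "ArLHS <s 40" and "ArLHS <u 40" agree; if the signed
// form is known at CtxI, the loop-varying unsigned check can be replaced by
// the invariant "2 <u 40".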
11322 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11323 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11324 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11325 isKnownNonNegative(RHS) &&
11326 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11327 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11328 RHS);
11329 }
11330 }
11331
11332 return std::nullopt;
11333}
11334
11335std::optional<ScalarEvolution::LoopInvariantPredicate>
11337 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11338 const Instruction *CtxI, const SCEV *MaxIter) {
11339 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11340 Pred, LHS, RHS, L, CtxI, MaxIter))
11341 return LIP;
11342 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11343 // Number of iterations expressed as UMIN isn't always great for expressing
11344 // the value on the last iteration. If the straightforward approach didn't
11345 // work, try the following trick: if a predicate is invariant for X, it
11346 // is also invariant for umin(X, ...). So try to find something that works
11347 // among subexpressions of MaxIter expressed as umin.
11348 for (auto *Op : UMin->operands())
11349 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11350 Pred, LHS, RHS, L, CtxI, Op))
11351 return LIP;
11352 return std::nullopt;
11353}
11354
11355std::optional<ScalarEvolution::LoopInvariantPredicate>
11357 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11358 const Instruction *CtxI, const SCEV *MaxIter) {
11359 // Try to prove the following set of facts:
11360 // - The predicate is monotonic in the iteration space.
11361 // - If the check does not fail on the 1st iteration:
11362 // - No overflow will happen during first MaxIter iterations;
11363 // - It will not fail on the MaxIter'th iteration.
11364 // If the check does fail on the 1st iteration, we leave the loop and no
11365 // other checks matter.
11366
11367 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11368 if (!isLoopInvariant(RHS, L)) {
11369 if (!isLoopInvariant(LHS, L))
11370 return std::nullopt;
11371
11372 std::swap(LHS, RHS);
11373 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11374 }
11375
11376 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11377 if (!AR || AR->getLoop() != L)
11378 return std::nullopt;
11379
11380 // The predicate must be relational (i.e. <, <=, >=, >).
11381 if (!ICmpInst::isRelational(Pred))
11382 return std::nullopt;
11383
11384 // TODO: Support steps other than +/- 1.
11385 const SCEV *Step = AR->getStepRecurrence(*this);
11386 auto *One = getOne(Step->getType());
11387 auto *MinusOne = getNegativeSCEV(One);
11388 if (Step != One && Step != MinusOne)
11389 return std::nullopt;
11390
11391 // Type mismatch here means that MaxIter is potentially larger than max
11392 // unsigned value in the start type, which means we cannot prove no wrap for the
11393 // indvar.
11394 if (AR->getType() != MaxIter->getType())
11395 return std::nullopt;
11396
11397 // Value of IV on suggested last iteration.
11398 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11399 // Does it still meet the requirement?
11400 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11401 return std::nullopt;
11402 // Because step is +/- 1 and MaxIter has the same type as Start (i.e. it does
11403 // not exceed max unsigned value of this type), this effectively proves
11404 // that there is no wrap during the iteration. To prove that there is no
11405 // signed/unsigned wrap, we need to check that
11406 // Start <= Last for step = 1 or Start >= Last for step = -1.
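// For example, for {Start,+,1} and MaxIter = N of the same type, Last is
// Start + N; additionally requiring Start <= Last (in the predicate's
// signedness) ensures the induction variable never wraps on the way from
// Start to Last.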
11407 ICmpInst::Predicate NoOverflowPred =
11408 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11409 if (Step == MinusOne)
11410 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11411 const SCEV *Start = AR->getStart();
11412 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11413 return std::nullopt;
11414
11415 // Everything is fine.
11416 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11417}
11418
11419bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11420 const SCEV *LHS,
11421 const SCEV *RHS) {
11422 if (HasSameValue(LHS, RHS))
11423 return ICmpInst::isTrueWhenEqual(Pred);
11424
11425 auto CheckRange = [&](bool IsSigned) {
11426 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11427 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11428 return RangeLHS.icmp(Pred, RangeRHS);
11429 };
11430
11431 // The check at the top of the function catches the case where the values are
11432 // known to be equal.
11433 if (Pred == CmpInst::ICMP_EQ)
11434 return false;
11435
11436 if (Pred == CmpInst::ICMP_NE) {
11437 if (CheckRange(true) || CheckRange(false))
11438 return true;
11439 auto *Diff = getMinusSCEV(LHS, RHS);
11440 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11441 }
11442
11443 return CheckRange(CmpInst::isSigned(Pred));
11444}
11445
11446bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11447 const SCEV *LHS,
11448 const SCEV *RHS) {
11449 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11450 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11451 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11452 // OutC1 and OutC2.
11453 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11454 APInt &OutC1, APInt &OutC2,
11455 SCEV::NoWrapFlags ExpectedFlags) {
11456 const SCEV *XNonConstOp, *XConstOp;
11457 const SCEV *YNonConstOp, *YConstOp;
11458 SCEV::NoWrapFlags XFlagsPresent;
11459 SCEV::NoWrapFlags YFlagsPresent;
11460
11461 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11462 XConstOp = getZero(X->getType());
11463 XNonConstOp = X;
11464 XFlagsPresent = ExpectedFlags;
11465 }
11466 if (!isa<SCEVConstant>(XConstOp))
11467 return false;
11468
11469 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11470 YConstOp = getZero(Y->getType());
11471 YNonConstOp = Y;
11472 YFlagsPresent = ExpectedFlags;
11473 }
11474
11475 if (YNonConstOp != XNonConstOp)
11476 return false;
11477
11478 if (!isa<SCEVConstant>(YConstOp))
11479 return false;
11480
11481 // When matching ADDs with NUW flags (and unsigned predicates), only the
11482 // second ADD (with the larger constant) requires NUW.
11483 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11484 return false;
11485 if (ExpectedFlags != SCEV::FlagNUW &&
11486 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11487 return false;
11488 }
11489
11490 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11491 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11492
11493 return true;
11494 };
11495
11496 APInt C1;
11497 APInt C2;
11498
11499 switch (Pred) {
11500 default:
11501 break;
11502
11503 case ICmpInst::ICMP_SGE:
11504 std::swap(LHS, RHS);
11505 [[fallthrough]];
11506 case ICmpInst::ICMP_SLE:
11507 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11508 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11509 return true;
11510
11511 break;
11512
11513 case ICmpInst::ICMP_SGT:
11514 std::swap(LHS, RHS);
11515 [[fallthrough]];
11516 case ICmpInst::ICMP_SLT:
11517 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11518 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11519 return true;
11520
11521 break;
11522
11523 case ICmpInst::ICMP_UGE:
11524 std::swap(LHS, RHS);
11525 [[fallthrough]];
11526 case ICmpInst::ICMP_ULE:
11527 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11528 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11529 return true;
11530
11531 break;
11532
11533 case ICmpInst::ICMP_UGT:
11534 std::swap(LHS, RHS);
11535 [[fallthrough]];
11536 case ICmpInst::ICMP_ULT:
11537 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11538 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11539 return true;
11540 break;
11541 }
11542
11543 return false;
11544}
11545
11546bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11547 const SCEV *LHS,
11548 const SCEV *RHS) {
11549 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11550 return false;
11551
11552 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11553 // the stack can result in exponential time complexity.
11554 SaveAndRestore Restore(ProvingSplitPredicate, true);
11555
11556 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11557 //
11558 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11559 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11560 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11561 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11562 // use isKnownPredicate later if needed.
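// For example, to prove I u< L with L known non-negative, it is enough to
// prove I s>= 0 and I s< L, and each of those is often easier to establish
// than the unsigned comparison itself.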
11563 return isKnownNonNegative(RHS) &&
11564 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11565 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11566 }
11567
11568bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11569 const SCEV *LHS, const SCEV *RHS) {
11570 // No need to even try if we know the module has no guards.
11571 if (!HasGuards)
11572 return false;
11573
11574 return any_of(*BB, [&](const Instruction &I) {
11575 using namespace llvm::PatternMatch;
11576
11577 Value *Condition;
11578 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11579 m_Value(Condition))) &&
11580 isImpliedCond(Pred, LHS, RHS, Condition, false);
11581 });
11582}
11583
11584/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11585/// protected by a conditional between LHS and RHS. This is used to
11586 /// eliminate casts.
11587 bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11588 CmpPredicate Pred,
11589 const SCEV *LHS,
11590 const SCEV *RHS) {
11591 // Interpret a null as meaning no loop, where there is obviously no guard
11592 // (interprocedural conditions notwithstanding). Do not bother about
11593 // unreachable loops.
11594 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11595 return true;
11596
11597 if (VerifyIR)
11598 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11599 "This cannot be done on broken IR!");
11600
11601
11602 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11603 return true;
11604
11605 BasicBlock *Latch = L->getLoopLatch();
11606 if (!Latch)
11607 return false;
11608
11609 BranchInst *LoopContinuePredicate =
11610 dyn_cast<BranchInst>(Latch->getTerminator());
11611 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11612 isImpliedCond(Pred, LHS, RHS,
11613 LoopContinuePredicate->getCondition(),
11614 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11615 return true;
11616
11617 // We don't want more than one activation of the following loops on the stack
11618 // -- that can lead to O(n!) time complexity.
11619 if (WalkingBEDominatingConds)
11620 return false;
11621
11622 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11623
11624 // See if we can exploit a trip count to prove the predicate.
11625 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11626 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11627 if (LatchBECount != getCouldNotCompute()) {
11628 // We know that Latch branches back to the loop header exactly
11629 // LatchBECount times. This means the backedge condition at Latch is
11630 // equivalent to "{0,+,1} u< LatchBECount".
11631 Type *Ty = LatchBECount->getType();
11632 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11633 const SCEV *LoopCounter =
11634 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11635 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11636 LatchBECount))
11637 return true;
11638 }
11639
11640 // Check conditions due to any @llvm.assume intrinsics.
11641 for (auto &AssumeVH : AC.assumptions()) {
11642 if (!AssumeVH)
11643 continue;
11644 auto *CI = cast<CallInst>(AssumeVH);
11645 if (!DT.dominates(CI, Latch->getTerminator()))
11646 continue;
11647
11648 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11649 return true;
11650 }
11651
11652 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11653 return true;
11654
11655 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11656 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11657 assert(DTN && "should reach the loop header before reaching the root!");
11658
11659 BasicBlock *BB = DTN->getBlock();
11660 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11661 return true;
11662
11663 BasicBlock *PBB = BB->getSinglePredecessor();
11664 if (!PBB)
11665 continue;
11666
11667 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11668 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11669 continue;
11670
11671 Value *Condition = ContinuePredicate->getCondition();
11672
11673 // If we have an edge `E` within the loop body that dominates the only
11674 // latch, the condition guarding `E` also guards the backedge. This
11675 // reasoning works only for loops with a single latch.
11676
11677 BasicBlockEdge DominatingEdge(PBB, BB);
11678 if (DominatingEdge.isSingleEdge()) {
11679 // We're constructively (and conservatively) enumerating edges within the
11680 // loop body that dominate the latch. The dominator tree better agree
11681 // with us on this:
11682 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11683
11684 if (isImpliedCond(Pred, LHS, RHS, Condition,
11685 BB != ContinuePredicate->getSuccessor(0)))
11686 return true;
11687 }
11688 }
11689
11690 return false;
11691}
11692
11693 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11694 CmpPredicate Pred,
11695 const SCEV *LHS,
11696 const SCEV *RHS) {
11697 // Do not bother proving facts for unreachable code.
11698 if (!DT.isReachableFromEntry(BB))
11699 return true;
11700 if (VerifyIR)
11701 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11702 "This cannot be done on broken IR!");
11703
11704 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11705 // the facts (a >= b && a != b) separately. A typical situation is when the
11706 // non-strict comparison is known from ranges and non-equality is known from
11707 // dominating predicates. If we are proving strict comparison, we always try
11708 // to prove non-equality and non-strict comparison separately.
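// For example, "a u> b" follows from "a u>= b" (perhaps known from ranges)
// together with "a != b" (perhaps known from a dominating condition), even
// if neither fact alone implies the strict comparison.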
11709 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11710 const bool ProvingStrictComparison =
11711 Pred != NonStrictPredicate.dropSameSign();
11712 bool ProvedNonStrictComparison = false;
11713 bool ProvedNonEquality = false;
11714
11715 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11716 if (!ProvedNonStrictComparison)
11717 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11718 if (!ProvedNonEquality)
11719 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11720 if (ProvedNonStrictComparison && ProvedNonEquality)
11721 return true;
11722 return false;
11723 };
11724
11725 if (ProvingStrictComparison) {
11726 auto ProofFn = [&](CmpPredicate P) {
11727 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11728 };
11729 if (SplitAndProve(ProofFn))
11730 return true;
11731 }
11732
11733 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11734 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11735 const Instruction *CtxI = &BB->front();
11736 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11737 return true;
11738 if (ProvingStrictComparison) {
11739 auto ProofFn = [&](CmpPredicate P) {
11740 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11741 };
11742 if (SplitAndProve(ProofFn))
11743 return true;
11744 }
11745 return false;
11746 };
11747
11748 // Starting at the block's predecessor, climb up the predecessor chain as long
11749 // as we can find predecessors that have unique successors leading to the
11750 // original block.
11751 const Loop *ContainingLoop = LI.getLoopFor(BB);
11752 const BasicBlock *PredBB;
11753 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11754 PredBB = ContainingLoop->getLoopPredecessor();
11755 else
11756 PredBB = BB->getSinglePredecessor();
11757 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11758 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11759 const BranchInst *BlockEntryPredicate =
11760 dyn_cast<BranchInst>(Pair.first->getTerminator());
11761 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11762 continue;
11763
11764 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11765 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11766 return true;
11767 }
11768
11769 // Check conditions due to any @llvm.assume intrinsics.
11770 for (auto &AssumeVH : AC.assumptions()) {
11771 if (!AssumeVH)
11772 continue;
11773 auto *CI = cast<CallInst>(AssumeVH);
11774 if (!DT.dominates(CI, BB))
11775 continue;
11776
11777 if (ProveViaCond(CI->getArgOperand(0), false))
11778 return true;
11779 }
11780
11781 // Check conditions due to any @llvm.experimental.guard intrinsics.
11782 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11783 F.getParent(), Intrinsic::experimental_guard);
11784 if (GuardDecl)
11785 for (const auto *GU : GuardDecl->users())
11786 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11787 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11788 if (ProveViaCond(Guard->getArgOperand(0), false))
11789 return true;
11790 return false;
11791}
11792
11793 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11794 const SCEV *LHS,
11795 const SCEV *RHS) {
11796 // Interpret a null as meaning no loop, where there is obviously no guard
11797 // (interprocedural conditions notwithstanding).
11798 if (!L)
11799 return false;
11800
11801 // Both LHS and RHS must be available at loop entry.
11803 "LHS is not available at Loop Entry");
11805 "RHS is not available at Loop Entry");
11806
11807 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11808 return true;
11809
11810 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11811}
11812
11813bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11814 const SCEV *RHS,
11815 const Value *FoundCondValue, bool Inverse,
11816 const Instruction *CtxI) {
11817 // A false condition implies anything. Do not bother analyzing it further.
11818 if (FoundCondValue ==
11819 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11820 return true;
11821
11822 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11823 return false;
11824
11825 auto ClearOnExit =
11826 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11827
11828 // Recursively handle And and Or conditions.
11829 const Value *Op0, *Op1;
11830 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11831 if (!Inverse)
11832 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11833 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11834 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11835 if (Inverse)
11836 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11837 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11838 }
11839
11840 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11841 if (!ICI) return false;
11842
11843 // Now that we have found a conditional branch that dominates the loop or
11844 // controls the loop latch, check to see if it is the comparison we are looking for.
11845 CmpPredicate FoundPred;
11846 if (Inverse)
11847 FoundPred = ICI->getInverseCmpPredicate();
11848 else
11849 FoundPred = ICI->getCmpPredicate();
11850
11851 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11852 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11853
11854 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11855}
11856
11857bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11858 const SCEV *RHS, CmpPredicate FoundPred,
11859 const SCEV *FoundLHS, const SCEV *FoundRHS,
11860 const Instruction *CtxI) {
11861 // Balance the types.
11862 if (getTypeSizeInBits(LHS->getType()) <
11863 getTypeSizeInBits(FoundLHS->getType())) {
11864 // For unsigned and equality predicates, try to prove that both found
11865 // operands fit into narrow unsigned range. If so, try to prove facts in
11866 // narrow types.
11867 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11868 !FoundRHS->getType()->isPointerTy()) {
11869 auto *NarrowType = LHS->getType();
11870 auto *WideType = FoundLHS->getType();
11871 auto BitWidth = getTypeSizeInBits(NarrowType);
11872 const SCEV *MaxValue = getZeroExtendExpr(
11873 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11874 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11875 MaxValue) &&
11876 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11877 MaxValue)) {
11878 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11879 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11880 // We cannot preserve samesign after truncation.
11881 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11882 TruncFoundLHS, TruncFoundRHS, CtxI))
11883 return true;
11884 }
11885 }
11886
11887 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11888 return false;
11889 if (CmpInst::isSigned(Pred)) {
11890 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11891 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11892 } else {
11893 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11894 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11895 }
11896 } else if (getTypeSizeInBits(LHS->getType()) >
11897 getTypeSizeInBits(FoundLHS->getType())) {
11898 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11899 return false;
11900 if (CmpInst::isSigned(FoundPred)) {
11901 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11902 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11903 } else {
11904 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11905 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11906 }
11907 }
11908 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11909 FoundRHS, CtxI);
11910}
11911
11912bool ScalarEvolution::isImpliedCondBalancedTypes(
11913 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11914 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11915 assert(getTypeSizeInBits(LHS->getType()) ==
11916 getTypeSizeInBits(FoundLHS->getType()) &&
11917 "Types should be balanced!");
11918 // Canonicalize the query to match the way instcombine will have
11919 // canonicalized the comparison.
11920 if (SimplifyICmpOperands(Pred, LHS, RHS))
11921 if (LHS == RHS)
11922 return CmpInst::isTrueWhenEqual(Pred);
11923 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11924 if (FoundLHS == FoundRHS)
11925 return CmpInst::isFalseWhenEqual(FoundPred);
11926
11927 // Check to see if we can make the LHS or RHS match.
11928 if (LHS == FoundRHS || RHS == FoundLHS) {
11929 if (isa<SCEVConstant>(RHS)) {
11930 std::swap(FoundLHS, FoundRHS);
11931 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11932 } else {
11933 std::swap(LHS, RHS);
11934 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11935 }
11936 }
11937
11938 // Check whether the found predicate is the same as the desired predicate.
11939 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11940 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11941
11942 // Check whether swapping the found predicate makes it the same as the
11943 // desired predicate.
11944 if (auto P = CmpPredicate::getMatching(
11945 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
11946 // We can write the implication
11947 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11948 // using one of the following ways:
11949 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11950 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11951 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11952 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11953 // Forms 1. and 2. require swapping the operands of one condition. Don't
11954 // do this if it would break canonical constant/addrec ordering.
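// Forms 3. and 4. rely on the fact that, for relational predicates,
// A Pred B is equivalent to ~B Pred ~A: bitwise-not (~X == -1 - X) reverses
// both the unsigned and the signed order.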
11955 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11956 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
11957 LHS, FoundLHS, FoundRHS, CtxI);
11958 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11959 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11960
11961 // There's no clear preference between forms 3. and 4., try both. Avoid
11962 // forming getNotSCEV of pointer values as the resulting subtract is
11963 // not legal.
11964 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11965 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
11966 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
11967 FoundRHS, CtxI))
11968 return true;
11969
11970 if (!FoundLHS->getType()->isPointerTy() &&
11971 !FoundRHS->getType()->isPointerTy() &&
11972 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
11973 getNotSCEV(FoundRHS), CtxI))
11974 return true;
11975
11976 return false;
11977 }
11978
11979 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11980 CmpInst::Predicate P2) {
11981 assert(P1 != P2 && "Handled earlier!");
11982 return CmpInst::isRelational(P2) &&
11983 CmpInst::getFlippedSignednessPredicate(P1) == P2;
11984 };
11985 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11986 // Unsigned comparison is the same as signed comparison when both the
11987 // operands are non-negative or negative.
11988 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11989 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11990 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11991 // Create local copies that we can freely swap and canonicalize our
11992 // conditions to "le/lt".
11993 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11994 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11995 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11996 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11997 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
11998 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
11999 std::swap(CanonicalLHS, CanonicalRHS);
12000 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12001 }
12002 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12003 "Must be!");
12004 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12005 ICmpInst::isLE(CanonicalFoundPred)) &&
12006 "Must be!");
12007 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12008 // Use implication:
12009 // x <u y && y >=s 0 --> x <s y.
12010 // If we can prove the left part, the right part is also proven.
12011 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12012 CanonicalRHS, CanonicalFoundLHS,
12013 CanonicalFoundRHS);
12014 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12015 // Use implication:
12016 // x <s y && y <s 0 --> x <u y.
12017 // If we can prove the left part, the right part is also proven.
12018 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12019 CanonicalRHS, CanonicalFoundLHS,
12020 CanonicalFoundRHS);
12021 }
12022
12023 // Check if we can make progress by sharpening ranges.
12024 if (FoundPred == ICmpInst::ICMP_NE &&
12025 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12026
12027 const SCEVConstant *C = nullptr;
12028 const SCEV *V = nullptr;
12029
12030 if (isa<SCEVConstant>(FoundLHS)) {
12031 C = cast<SCEVConstant>(FoundLHS);
12032 V = FoundRHS;
12033 } else {
12034 C = cast<SCEVConstant>(FoundRHS);
12035 V = FoundLHS;
12036 }
12037
12038 // The guarding predicate tells us that C != V. If the known range
12039 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12040 // range we consider has to correspond to same signedness as the
12041 // predicate we're interested in folding.
12042
12043 APInt Min = ICmpInst::isSigned(Pred) ?
12044 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12045
12046 if (Min == C->getAPInt()) {
12047 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12048 // This is true even if (Min + 1) wraps around -- in case of
12049 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
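// For example, with an unsigned comparison and C == 0: the known range gives
// V u>= 0 trivially, and the guarding V != 0 sharpens it to V u>= 1.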
12050
12051 APInt SharperMin = Min + 1;
12052
12053 switch (Pred) {
12054 case ICmpInst::ICMP_SGE:
12055 case ICmpInst::ICMP_UGE:
12056 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12057 // RHS, we're done.
12058 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12059 CtxI))
12060 return true;
12061 [[fallthrough]];
12062
12063 case ICmpInst::ICMP_SGT:
12064 case ICmpInst::ICMP_UGT:
12065 // We know from the range information that (V `Pred` Min ||
12066 // V == Min). We know from the guarding condition that !(V
12067 // == Min). This gives us
12068 //
12069 // V `Pred` Min || V == Min && !(V == Min)
12070 // => V `Pred` Min
12071 //
12072 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12073
12074 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12075 return true;
12076 break;
12077
12078 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12079 case ICmpInst::ICMP_SLE:
12080 case ICmpInst::ICMP_ULE:
12081 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12082 LHS, V, getConstant(SharperMin), CtxI))
12083 return true;
12084 [[fallthrough]];
12085
12086 case ICmpInst::ICMP_SLT:
12087 case ICmpInst::ICMP_ULT:
12088 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12089 LHS, V, getConstant(Min), CtxI))
12090 return true;
12091 break;
12092
12093 default:
12094 // No change
12095 break;
12096 }
12097 }
12098 }
12099
12100 // Check whether the actual condition is beyond sufficient.
12101 if (FoundPred == ICmpInst::ICMP_EQ)
12102 if (ICmpInst::isTrueWhenEqual(Pred))
12103 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12104 return true;
12105 if (Pred == ICmpInst::ICMP_NE)
12106 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12107 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12108 return true;
12109
12110 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12111 return true;
12112
12113 // Otherwise assume the worst.
12114 return false;
12115}
12116
12117bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12118 const SCEV *&L, const SCEV *&R,
12119 SCEV::NoWrapFlags &Flags) {
12120 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12121 if (!AE || AE->getNumOperands() != 2)
12122 return false;
12123
12124 L = AE->getOperand(0);
12125 R = AE->getOperand(1);
12126 Flags = AE->getNoWrapFlags();
12127 return true;
12128}
12129
12130std::optional<APInt>
12131 ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12132 // We avoid subtracting expressions here because this function is usually
12133 // fairly deep in the call stack (i.e. is called many times).
12134
12135 unsigned BW = getTypeSizeInBits(More->getType());
12136 APInt Diff(BW, 0);
12137 APInt DiffMul(BW, 1);
12138 // Try various simplifications to reduce the difference to a constant. Limit
12139 // the number of allowed simplifications to keep compile-time low.
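// For example, More = (%x + 13) and Less = (%x + 5) reduce to the constant
// difference 8 without ever materializing a subtraction SCEV.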
12140 for (unsigned I = 0; I < 8; ++I) {
12141 if (More == Less)
12142 return Diff;
12143
12144 // Reduce addrecs with identical steps to their start value.
12145 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12146 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12147 const auto *MAR = cast<SCEVAddRecExpr>(More);
12148
12149 if (LAR->getLoop() != MAR->getLoop())
12150 return std::nullopt;
12151
12152 // We look at affine expressions only; not for correctness but to keep
12153 // getStepRecurrence cheap.
12154 if (!LAR->isAffine() || !MAR->isAffine())
12155 return std::nullopt;
12156
12157 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12158 return std::nullopt;
12159
12160 Less = LAR->getStart();
12161 More = MAR->getStart();
12162 continue;
12163 }
12164
12165 // Try to match a common constant multiply.
12166 auto MatchConstMul =
12167 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12168 auto *M = dyn_cast<SCEVMulExpr>(S);
12169 if (!M || M->getNumOperands() != 2 ||
12170 !isa<SCEVConstant>(M->getOperand(0)))
12171 return std::nullopt;
12172 return {
12173 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12174 };
12175 if (auto MatchedMore = MatchConstMul(More)) {
12176 if (auto MatchedLess = MatchConstMul(Less)) {
12177 if (MatchedMore->second == MatchedLess->second) {
12178 More = MatchedMore->first;
12179 Less = MatchedLess->first;
12180 DiffMul *= MatchedMore->second;
12181 continue;
12182 }
12183 }
12184 }
12185
12186 // Try to cancel out common factors in two add expressions.
12187 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12188 auto Add = [&](const SCEV *S, int Mul) {
12189 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12190 if (Mul == 1) {
12191 Diff += C->getAPInt() * DiffMul;
12192 } else {
12193 assert(Mul == -1);
12194 Diff -= C->getAPInt() * DiffMul;
12195 }
12196 } else
12197 Multiplicity[S] += Mul;
12198 };
12199 auto Decompose = [&](const SCEV *S, int Mul) {
12200 if (isa<SCEVAddExpr>(S)) {
12201 for (const SCEV *Op : S->operands())
12202 Add(Op, Mul);
12203 } else
12204 Add(S, Mul);
12205 };
12206 Decompose(More, 1);
12207 Decompose(Less, -1);
12208
12209 // Check whether all the non-constants cancel out, or reduce to new
12210 // More/Less values.
12211 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12212 for (const auto &[S, Mul] : Multiplicity) {
12213 if (Mul == 0)
12214 continue;
12215 if (Mul == 1) {
12216 if (NewMore)
12217 return std::nullopt;
12218 NewMore = S;
12219 } else if (Mul == -1) {
12220 if (NewLess)
12221 return std::nullopt;
12222 NewLess = S;
12223 } else
12224 return std::nullopt;
12225 }
12226
12227 // Values stayed the same, no point in trying further.
12228 if (NewMore == More || NewLess == Less)
12229 return std::nullopt;
12230
12231 More = NewMore;
12232 Less = NewLess;
12233
12234 // Reduced to constant.
12235 if (!More && !Less)
12236 return Diff;
12237
12238 // Left with variable on only one side, bail out.
12239 if (!More || !Less)
12240 return std::nullopt;
12241 }
12242
12243 // Did not reduce to constant.
12244 return std::nullopt;
12245}
12246
12247bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12248 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12249 const SCEV *FoundRHS, const Instruction *CtxI) {
12250 // Try to recognize the following pattern:
12251 //
12252 // FoundRHS = ...
12253 // ...
12254 // loop:
12255 // FoundLHS = {Start,+,W}
12256 // context_bb: // Basic block from the same loop
12257 // known(Pred, FoundLHS, FoundRHS)
12258 //
12259 // If some predicate is known in the context of a loop, it is also known on
12260 // each iteration of this loop, including the first iteration. Therefore, in
12261 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12262 // prove the original pred using this fact.
12263 if (!CtxI)
12264 return false;
12265 const BasicBlock *ContextBB = CtxI->getParent();
12266 // Make sure AR varies in the context block.
12267 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12268 const Loop *L = AR->getLoop();
12269 // Make sure that context belongs to the loop and executes on 1st iteration
12270 // (if it ever executes at all).
12271 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12272 return false;
12273 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12274 return false;
12275 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12276 }
12277
12278 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12279 const Loop *L = AR->getLoop();
12280 // Make sure that context belongs to the loop and executes on 1st iteration
12281 // (if it ever executes at all).
12282 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12283 return false;
12284 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12285 return false;
12286 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12287 }
12288
12289 return false;
12290}
12291
12292bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12293 const SCEV *LHS,
12294 const SCEV *RHS,
12295 const SCEV *FoundLHS,
12296 const SCEV *FoundRHS) {
12297 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12298 return false;
12299
12300 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12301 if (!AddRecLHS)
12302 return false;
12303
12304 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12305 if (!AddRecFoundLHS)
12306 return false;
12307
12308 // We'd like to let SCEV reason about control dependencies, so we constrain
12309 // both the inequalities to be about add recurrences on the same loop. This
12310 // way we can use isLoopEntryGuardedByCond later.
12311
12312 const Loop *L = AddRecFoundLHS->getLoop();
12313 if (L != AddRecLHS->getLoop())
12314 return false;
12315
12316 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12317 //
12318 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12319 // ... (2)
12320 //
12321 // Informal proof for (2), assuming (1) [*]:
12322 //
12323 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12324 //
12325 // Then
12326 //
12327 // FoundLHS s< FoundRHS s< INT_MIN - C
12328 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12329 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12330 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12331 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12332 // <=> FoundLHS + C s< FoundRHS + C
12333 //
12334 // [*]: (1) can be proved by ruling out overflow.
12335 //
12336 // [**]: This can be proved by analyzing all the four possibilities:
12337 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12338 // (A s>= 0, B s>= 0).
12339 //
12340 // Note:
12341 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12342 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12343 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12344 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12345 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12346 // C)".
12347
12348 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12349 if (!LDiff)
12350 return false;
12351 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12352 if (!RDiff || *LDiff != *RDiff)
12353 return false;
12354
12355 if (LDiff->isMinValue())
12356 return true;
12357
12358 APInt FoundRHSLimit;
12359
12360 if (Pred == CmpInst::ICMP_ULT) {
12361 FoundRHSLimit = -(*RDiff);
12362 } else {
12363 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12364 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12365 }
12366
12367 // Try to prove (1) or (2), as needed.
12368 return isAvailableAtLoopEntry(FoundRHS, L) &&
12369 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12370 getConstant(FoundRHSLimit));
12371}
12372
12373bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12374 const SCEV *RHS, const SCEV *FoundLHS,
12375 const SCEV *FoundRHS, unsigned Depth) {
12376 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12377
12378 auto ClearOnExit = make_scope_exit([&]() {
12379 if (LPhi) {
12380 bool Erased = PendingMerges.erase(LPhi);
12381 assert(Erased && "Failed to erase LPhi!");
12382 (void)Erased;
12383 }
12384 if (RPhi) {
12385 bool Erased = PendingMerges.erase(RPhi);
12386 assert(Erased && "Failed to erase RPhi!");
12387 (void)Erased;
12388 }
12389 });
12390
12391 // Find the respective Phis and check that they are not already pending.
12392 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12393 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12394 if (!PendingMerges.insert(Phi).second)
12395 return false;
12396 LPhi = Phi;
12397 }
12398 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12399 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12400 // If we detect a loop of Phi nodes being processed by this method, for
12401 // example:
12402 //
12403 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12404 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12405 //
12406 // we don't want to deal with a case that complex, so we return the
12407 // conservative answer false.
12408 if (!PendingMerges.insert(Phi).second)
12409 return false;
12410 RPhi = Phi;
12411 }
12412
12413 // If none of LHS, RHS is a Phi, nothing to do here.
12414 if (!LPhi && !RPhi)
12415 return false;
12416
12417 // If there is a SCEVUnknown Phi we are interested in, make it left.
12418 if (!LPhi) {
12419 std::swap(LHS, RHS);
12420 std::swap(FoundLHS, FoundRHS);
12421 std::swap(LPhi, RPhi);
12422 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12423 }
12424
12425 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12426 const BasicBlock *LBB = LPhi->getParent();
12427 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12428
12429 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12430 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12431 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12432 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12433 };
12434
12435 if (RPhi && RPhi->getParent() == LBB) {
12436 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12437 // If we compare two Phis from the same block, and for each entry block
12438 // the predicate is true for incoming values from this block, then the
12439 // predicate is also true for the Phis.
12440 for (const BasicBlock *IncBB : predecessors(LBB)) {
12441 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12442 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12443 if (!ProvedEasily(L, R))
12444 return false;
12445 }
12446 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12447 // Case two: RHS is an AddRec over the loop whose header is LBB, i.e. the
12448 // loop has both an AddRec and an Unknown PHI in its header. For such a loop
12449 // we can compare the incoming values of the AddRec from above the loop and
12450 // from the latch with the respective incoming values of LPhi.
12451 // TODO: Generalize to handle loops with many inputs in a header.
12452 if (LPhi->getNumIncomingValues() != 2) return false;
12453
12454 auto *RLoop = RAR->getLoop();
12455 auto *Predecessor = RLoop->getLoopPredecessor();
12456 assert(Predecessor && "Loop with AddRec with no predecessor?");
12457 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12458 if (!ProvedEasily(L1, RAR->getStart()))
12459 return false;
12460 auto *Latch = RLoop->getLoopLatch();
12461 assert(Latch && "Loop with AddRec with no latch?");
12462 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12463 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12464 return false;
12465 } else {
12466 // In all other cases go over the inputs of LHS and compare each of them to
12467 // RHS; the predicate is true for (LHS, RHS) if it is true for all such pairs.
12468 // At this point RHS is either a non-Phi, or it is a Phi from some block
12469 // different from LBB.
12470 for (const BasicBlock *IncBB : predecessors(LBB)) {
12471 // Check that RHS is available in this block.
12472 if (!dominates(RHS, IncBB))
12473 return false;
12474 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12475 // Make sure L does not refer to a value from a potentially previous
12476 // iteration of a loop.
12477 if (!properlyDominates(L, LBB))
12478 return false;
12479 // Addrecs are considered to properly dominate their loop, so are missed
12480 // by the previous check. Discard any values that have computable
12481 // evolution in this loop.
12482 if (auto *Loop = LI.getLoopFor(LBB))
12483 if (hasComputableLoopEvolution(L, Loop))
12484 return false;
12485 if (!ProvedEasily(L, RHS))
12486 return false;
12487 }
12488 }
12489 return true;
12490}
12491
12492bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12493 const SCEV *LHS,
12494 const SCEV *RHS,
12495 const SCEV *FoundLHS,
12496 const SCEV *FoundRHS) {
12497 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12498 // sure that we are dealing with same LHS.
12499 if (RHS == FoundRHS) {
12500 std::swap(LHS, RHS);
12501 std::swap(FoundLHS, FoundRHS);
12502 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12503 }
12504 if (LHS != FoundLHS)
12505 return false;
12506
12507 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12508 if (!SUFoundRHS)
12509 return false;
12510
12511 Value *Shiftee, *ShiftValue;
12512
12513 using namespace PatternMatch;
12514 if (match(SUFoundRHS->getValue(),
12515 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12516 auto *ShifteeS = getSCEV(Shiftee);
12517 // Prove one of the following:
12518 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12519 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12520 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12521 // ---> LHS <s RHS
12522 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12523 // ---> LHS <=s RHS
12524 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12525 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12526 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12527 if (isKnownNonNegative(ShifteeS))
12528 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12529 }
12530
12531 return false;
12532}
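
// A minimal standalone sketch of the first unsigned rule above, checked with
// plain 8-bit values instead of SCEV expressions; the helper name is
// hypothetical. If LHS <u (Shiftee >> ShiftValue) and Shiftee <=u RHS, then
// LHS <u RHS, because a logical right shift never increases an unsigned value.
#include <cassert>

static void checkShiftImplicationExhaustively() {
  for (unsigned Shiftee = 0; Shiftee < 256; ++Shiftee)
    for (unsigned ShiftValue = 0; ShiftValue < 8; ++ShiftValue)
      for (unsigned LHS = 0; LHS < 256; ++LHS)
        for (unsigned RHS = 0; RHS < 256; ++RHS)
          if (LHS < (Shiftee >> ShiftValue) && Shiftee <= RHS)
            assert(LHS < RHS); // the implied fact
}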
12533
12534bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12535 const SCEV *RHS,
12536 const SCEV *FoundLHS,
12537 const SCEV *FoundRHS,
12538 const Instruction *CtxI) {
12539 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12540 FoundRHS) ||
12541 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12542 FoundRHS) ||
12543 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12544 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12545 CtxI) ||
12546 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12547}
12548
12549/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12550template <typename MinMaxExprType>
12551static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12552 const SCEV *Candidate) {
12553 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12554 if (!MinMaxExpr)
12555 return false;
12556
12557 return is_contained(MinMaxExpr->operands(), Candidate);
12558}
12559
12560 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12561 CmpPredicate Pred, const SCEV *LHS,
12562 const SCEV *RHS) {
12563 // If both sides are affine addrecs for the same loop, with equal
12564 // steps, and we know the recurrences don't wrap, then we only
12565 // need to check the predicate on the starting values.
12566
12567 if (!ICmpInst::isRelational(Pred))
12568 return false;
12569
12570 const SCEV *LStart, *RStart, *Step;
12571 const Loop *L;
12572 if (!match(LHS,
12573 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12574 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12575 m_SpecificLoop(L))))
12576 return false;
12577 const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS);
12578 const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS);
12579 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12580 SCEV::FlagNSW : SCEV::FlagNUW;
12581 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12582 return false;
12583
12584 return SE.isKnownPredicate(Pred, LStart, RStart);
12585}
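
// A minimal standalone sketch of the reasoning above, with plain integers in
// place of no-wrap affine AddRecs; the helper name is hypothetical. Two
// sequences over the same loop with the same step keep their relative order,
// so comparing the start values decides the predicate on every iteration.
#include <cassert>

static void checkSameStepSequences() {
  const int LStart = 3, RStart = 10, Step = 4;
  int L = LStart, R = RStart;
  for (int It = 0; It < 100; ++It) {
    assert(L < R); // LStart < RStart implies L < R on every iteration
    L += Step;     // no overflow in this range, mirroring the no-wrap check
    R += Step;
  }
}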
12586
12587/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12588/// expression?
12589 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12590 const SCEV *LHS, const SCEV *RHS) {
12591 switch (Pred) {
12592 default:
12593 return false;
12594
12595 case ICmpInst::ICMP_SGE:
12596 std::swap(LHS, RHS);
12597 [[fallthrough]];
12598 case ICmpInst::ICMP_SLE:
12599 return
12600 // min(A, ...) <= A
12601 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12602 // A <= max(A, ...)
12603 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12604
12605 case ICmpInst::ICMP_UGE:
12606 std::swap(LHS, RHS);
12607 [[fallthrough]];
12608 case ICmpInst::ICMP_ULE:
12609 return
12610 // min(A, ...) <= A
12611 // FIXME: what about umin_seq?
12612 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12613 // A <= max(A, ...)
12614 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12615 }
12616
12617 llvm_unreachable("covered switch fell through?!");
12618}
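
// A minimal standalone sketch of the two facts used above, with std::min and
// std::max standing in for the SCEV min/max expressions; the helper name is
// hypothetical.
#include <algorithm>
#include <cassert>

static void checkMinMaxFacts(int A, int B) {
  assert(std::min(A, B) <= A); // min(A, ...) <= A
  assert(A <= std::max(A, B)); // A <= max(A, ...)
}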
12619
12620bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12621 const SCEV *RHS,
12622 const SCEV *FoundLHS,
12623 const SCEV *FoundRHS,
12624 unsigned Depth) {
12627 "LHS and RHS have different sizes?");
12628 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12629 getTypeSizeInBits(FoundRHS->getType()) &&
12630 "FoundLHS and FoundRHS have different sizes?");
12631 // We want to avoid hurting the compile time with analysis of too big trees.
12632 if (Depth > MaxSCEVOperationsImplicationDepth)
12633 return false;
12634
12635 // We only want to work with GT comparison so far.
12636 if (ICmpInst::isLT(Pred)) {
12637 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12638 std::swap(LHS, RHS);
12639 std::swap(FoundLHS, FoundRHS);
12640 }
12641
12643
12644 // For unsigned, try to reduce it to corresponding signed comparison.
12645 if (P == ICmpInst::ICMP_UGT)
12646 // We can replace unsigned predicate with its signed counterpart if all
12647 // involved values are non-negative.
12648 // TODO: We could have better support for unsigned.
12649 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12650 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12651 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12652 // use this fact to prove that LHS and RHS are non-negative.
12653 const SCEV *MinusOne = getMinusOne(LHS->getType());
12654 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12655 FoundRHS) &&
12656 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12657 FoundRHS))
12658 P = ICmpInst::ICMP_SGT;
12659 }
12660
12661 if (P != ICmpInst::ICMP_SGT)
12662 return false;
12663
12664 auto GetOpFromSExt = [&](const SCEV *S) {
12665 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12666 return Ext->getOperand();
12667 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12668 // the constant in some cases.
12669 return S;
12670 };
12671
12672 // Acquire values from extensions.
12673 auto *OrigLHS = LHS;
12674 auto *OrigFoundLHS = FoundLHS;
12675 LHS = GetOpFromSExt(LHS);
12676 FoundLHS = GetOpFromSExt(FoundLHS);
12677
12678 // Whether the SGT predicate can be proved trivially or using the found context.
12679 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12680 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12681 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12682 FoundRHS, Depth + 1);
12683 };
12684
12685 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12686 // We want to avoid creation of any new non-constant SCEV. Since we are
12687 // going to compare the operands to RHS, we should be certain that we don't
12688 // need any size extensions for this. So let's decline all cases when the
12689 // sizes of types of LHS and RHS do not match.
12690 // TODO: Maybe try to get RHS from sext to catch more cases?
12691 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12692 return false;
12693
12694 // Should not overflow.
12695 if (!LHSAddExpr->hasNoSignedWrap())
12696 return false;
12697
12698 auto *LL = LHSAddExpr->getOperand(0);
12699 auto *LR = LHSAddExpr->getOperand(1);
12700 auto *MinusOne = getMinusOne(RHS->getType());
12701
12702 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12703 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12704 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12705 };
12706 // Try to prove the following rule:
12707 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12708 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12709 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12710 return true;
12711 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12712 Value *LL, *LR;
12713 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12714
12715 using namespace llvm::PatternMatch;
12716
12717 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12718 // Rules for division.
12719 // We are going to perform some comparisons with Denominator and its
12720 // derivative expressions. In the general case, creating a SCEV for it may
12721 // lead to a complex analysis of the entire graph, and in particular it
12722 // can request trip count recalculation for the same loop. That would get
12723 // cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12724 // this, we only want to create SCEVs that are constants in this section.
12725 // So we bail if Denominator is not a constant.
12726 if (!isa<ConstantInt>(LR))
12727 return false;
12728
12729 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12730
12731 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12732 // then a SCEV for the numerator already exists and matches with FoundLHS.
12733 auto *Numerator = getExistingSCEV(LL);
12734 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12735 return false;
12736
12737 // Make sure that the numerator matches with FoundLHS and the denominator
12738 // is positive.
12739 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12740 return false;
12741
12742 auto *DTy = Denominator->getType();
12743 auto *FRHSTy = FoundRHS->getType();
12744 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12745 // One of the types is a pointer and the other one is not. We cannot extend
12746 // them properly to a wider type, so let us just reject this case.
12747 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12748 // to avoid this check.
12749 return false;
12750
12751 // Given that:
12752 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12753 auto *WTy = getWiderType(DTy, FRHSTy);
12754 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12755 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12756
12757 // Try to prove the following rule:
12758 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12759 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12760 // divide it by a Denominator < 4, the result is at least 1.
12761 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12762 if (isKnownNonPositive(RHS) &&
12763 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12764 return true;
12765
12766 // Try to prove the following rule:
12767 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12768 // For example, given that FoundLHS > -3, FoundLHS is at least -2. If we
12769 // divide it by a Denominator > 2, then:
12770 // 1. If FoundLHS is negative, then the result is 0.
12771 // 2. If FoundLHS is non-negative, then the result is non-negative.
12772 // Either way, the result is non-negative.
12773 auto *MinusOne = getMinusOne(WTy);
12774 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12775 if (isKnownNegative(RHS) &&
12776 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12777 return true;
12778 }
12779 }
12780
12781 // If our expression contained SCEVUnknown Phis, and we split it down and now
12782 // need to prove something for them, try to prove the predicate for every
12783 // possible incoming values of those Phis.
12784 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12785 return true;
12786
12787 return false;
12788}
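
// A minimal standalone sketch of the first division rule above, checked
// exhaustively over small signed values; the helper name is hypothetical.
// Given FoundLHS > FoundRHS, LHS = FoundLHS / Denominator and Denominator > 0,
// the rule "(FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS)" holds
// because FoundRHS >= Denominator - 1 forces FoundLHS >= Denominator, so the
// quotient is at least 1.
#include <cassert>

static void checkDivisionRule() {
  for (int FoundLHS = -50; FoundLHS <= 50; ++FoundLHS)
    for (int FoundRHS = -50; FoundRHS < FoundLHS; ++FoundRHS)
      for (int Denominator = 1; Denominator <= 10; ++Denominator)
        for (int RHS = -50; RHS <= 0; ++RHS)
          if (FoundRHS > Denominator - 2) {
            int LHS = FoundLHS / Denominator; // truncating division, as in sdiv
            assert(LHS > RHS);
          }
}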
12789
12790static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12791 const SCEV *RHS) {
12792 // zext x u<= sext x, sext x s<= zext x
12793 const SCEV *Op;
12794 switch (Pred) {
12795 case ICmpInst::ICMP_SGE:
12796 std::swap(LHS, RHS);
12797 [[fallthrough]];
12798 case ICmpInst::ICMP_SLE: {
12799 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12800 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12801 match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12802 }
12803 case ICmpInst::ICMP_UGE:
12804 std::swap(LHS, RHS);
12805 [[fallthrough]];
12806 case ICmpInst::ICMP_ULE: {
12807 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12808 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12809 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12810 }
12811 default:
12812 return false;
12813 };
12814 llvm_unreachable("unhandled case");
12815}
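
// A minimal standalone sketch of the idiom above for 8-bit values widened to
// 16 bits; the helper name is hypothetical. For every x, sext(x) s<= zext(x)
// and zext(x) u<= sext(x): the two extensions agree for non-negative x, and
// for negative x the zext is a small positive number while the sext is
// negative in the signed view but has all high bits set in the unsigned view.
#include <cassert>
#include <cstdint>

static void checkExtendIdiom() {
  for (int V = 0; V < 256; ++V) {
    int8_t X = static_cast<int8_t>(V);
    int16_t SExt = X;                        // sign extension
    uint16_t ZExt = static_cast<uint8_t>(X); // zero extension
    assert(SExt <= static_cast<int16_t>(ZExt));  // sext x s<= zext x
    assert(ZExt <= static_cast<uint16_t>(SExt)); // zext x u<= sext x
  }
}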
12816
12817bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12818 const SCEV *LHS,
12819 const SCEV *RHS) {
12820 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12821 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12822 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12823 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12824 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12825}
12826
12827bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12828 const SCEV *LHS,
12829 const SCEV *RHS,
12830 const SCEV *FoundLHS,
12831 const SCEV *FoundRHS) {
12832 switch (Pred) {
12833 default:
12834 llvm_unreachable("Unexpected CmpPredicate value!");
12835 case ICmpInst::ICMP_EQ:
12836 case ICmpInst::ICMP_NE:
12837 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12838 return true;
12839 break;
12840 case ICmpInst::ICMP_SLT:
12841 case ICmpInst::ICMP_SLE:
12842 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12843 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12844 return true;
12845 break;
12846 case ICmpInst::ICMP_SGT:
12847 case ICmpInst::ICMP_SGE:
12848 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12849 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12850 return true;
12851 break;
12852 case ICmpInst::ICMP_ULT:
12853 case ICmpInst::ICMP_ULE:
12854 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12855 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12856 return true;
12857 break;
12858 case ICmpInst::ICMP_UGT:
12859 case ICmpInst::ICMP_UGE:
12860 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12861 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12862 return true;
12863 break;
12864 }
12865
12866 // Maybe it can be proved via operations?
12867 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12868 return true;
12869
12870 return false;
12871}
12872
12873bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12874 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12875 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12876 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12877 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12878 // reduce the compile time impact of this optimization.
12879 return false;
12880
12881 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12882 if (!Addend)
12883 return false;
12884
12885 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12886
12887 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12888 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12889 ConstantRange FoundLHSRange =
12890 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12891
12892 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12893 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12894
12895 // We can also compute the range of values for `LHS` that satisfy the
12896 // consequent, "`LHS` `Pred` `RHS`":
12897 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12898 // The antecedent implies the consequent if every value of `LHS` that
12899 // satisfies the antecedent also satisfies the consequent.
12900 return LHSRange.icmp(Pred, ConstRHS);
12901}
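
// A minimal standalone sketch of the range argument above with concrete
// numbers; the helper name is hypothetical. From the antecedent
// "FoundLHS u< 8" and the constant difference LHS = FoundLHS + 5, LHS lies in
// [5, 12], so the consequent "LHS u< 16" holds for every admissible value.
#include <cassert>
#include <cstdint>

static void checkRangeImplication() {
  for (uint32_t FoundLHS = 0; FoundLHS < 8; ++FoundLHS) { // antecedent
    uint32_t LHS = FoundLHS + 5;                          // Addend = 5
    assert(LHS < 16);                                     // consequent
  }
}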
12902
12903bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12904 bool IsSigned) {
12905 assert(isKnownPositive(Stride) && "Positive stride expected!");
12906
12907 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12908 const SCEV *One = getOne(Stride->getType());
12909
12910 if (IsSigned) {
12911 APInt MaxRHS = getSignedRangeMax(RHS);
12912 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12913 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12914
12915 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12916 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12917 }
12918
12919 APInt MaxRHS = getUnsignedRangeMax(RHS);
12920 APInt MaxValue = APInt::getMaxValue(BitWidth);
12921 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12922
12923 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12924 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12925}
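
// A minimal standalone sketch of the unsigned test above with 8-bit numbers;
// the helper and constant names are hypothetical. With MaxRHS = 250 and
// Stride = 10, an IV value of 249 still satisfies "IV <u RHS" but 249 + 10
// wraps past 255, so the check (UMaxValue - (Stride - 1)) u< MaxRHS must
// report a possible overflow.
#include <cassert>
#include <cstdint>

static void checkOverflowOnLTExample() {
  const uint32_t UMaxValue = 255; // unsigned max for 8 bits
  const uint32_t MaxRHS = 250;
  const uint32_t Stride = 10;
  const bool CanOverflow = (UMaxValue - (Stride - 1)) < MaxRHS;
  assert(CanOverflow); // 246 u< 250, matching the wrap of 249 + 10
}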
12926
12927bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12928 bool IsSigned) {
12929
12930 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12931 const SCEV *One = getOne(Stride->getType());
12932
12933 if (IsSigned) {
12934 APInt MinRHS = getSignedRangeMin(RHS);
12935 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12936 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12937
12938 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12939 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12940 }
12941
12942 APInt MinRHS = getUnsignedRangeMin(RHS);
12943 APInt MinValue = APInt::getMinValue(BitWidth);
12944 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12945
12946 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12947 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12948}
12949
12950 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12951 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12952 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12953 // expression fixes the case of N=0.
12954 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12955 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12956 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12957}
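
// A minimal standalone sketch of the rewrite above, compared against a plain
// ceiling division for small values; the helper name is hypothetical. For
// N != 0 the expression is 1 + floor((N - 1) / D) == ceil(N / D), and the
// umin term makes the N == 0 case come out as 0 without underflow.
#include <cassert>
#include <cstdint>

static void checkUDivCeilRewrite() {
  for (uint64_t N = 0; N < 200; ++N)
    for (uint64_t D = 1; D < 20; ++D) {
      const uint64_t MinNOne = N < 1 ? N : 1;                 // umin(N, 1)
      const uint64_t Rewritten = MinNOne + (N - MinNOne) / D; // formula above
      const uint64_t Reference = (N + D - 1) / D;             // ceil(N / D)
      assert(Rewritten == Reference);
    }
}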
12958
12959const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12960 const SCEV *Stride,
12961 const SCEV *End,
12962 unsigned BitWidth,
12963 bool IsSigned) {
12964 // The logic in this function assumes we can represent a positive stride.
12965 // If we can't, the backedge-taken count must be zero.
12966 if (IsSigned && BitWidth == 1)
12967 return getZero(Stride->getType());
12968
12969 // The code below has only been closely audited for negative strides in the
12970 // unsigned comparison case; it may be correct for the signed comparison, but
12971 // that needs to be established.
12972 if (IsSigned && isKnownNegative(Stride))
12973 return getCouldNotCompute();
12974
12975 // Calculate the maximum backedge count based on the range of values
12976 // permitted by Start, End, and Stride.
12977 APInt MinStart =
12978 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12979
12980 APInt MinStride =
12981 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12982
12983 // We assume either the stride is positive, or the backedge-taken count
12984 // is zero. So force StrideForMaxBECount to be at least one.
12985 APInt One(BitWidth, 1);
12986 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12987 : APIntOps::umax(One, MinStride);
12988
12989 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12990 : APInt::getMaxValue(BitWidth);
12991 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12992
12993 // Although End can be a MAX expression we estimate MaxEnd considering only
12994 // the case End = RHS of the loop termination condition. This is safe because
12995 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12996 // taken count.
12997 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12998 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12999
13000 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13001 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13002 : APIntOps::umax(MaxEnd, MinStart);
13003
13004 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13005 getConstant(StrideForMaxBECount) /* Step */);
13006}
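
// A minimal standalone sketch of the bound above with small unsigned numbers;
// the helper and constant names are hypothetical. With MinStart = 2,
// MaxEnd = 103 and a minimum stride of 5, the bound is
// ceil((103 - 2) / 5) = 21, and no "i = Start; i < End; i += Stride" loop
// consistent with those ranges takes its backedge more often.
#include <cassert>

static void checkMaxBECountForLTExample() {
  const unsigned MinStart = 2, MaxEnd = 103, MinStride = 5;
  const unsigned Bound = (MaxEnd - MinStart + MinStride - 1) / MinStride; // 21
  for (unsigned Start = MinStart; Start <= 110; ++Start)
    for (unsigned End = 0; End <= MaxEnd; ++End)
      for (unsigned Stride = MinStride; Stride <= 10; ++Stride) {
        unsigned Taken = 0;
        for (unsigned I = Start; I < End; I += Stride)
          ++Taken; // backedge taken once per body execution
        assert(Taken <= Bound);
      }
}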
13007
13008 ScalarEvolution::ExitLimit
13009 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13010 const Loop *L, bool IsSigned,
13011 bool ControlsOnlyExit, bool AllowPredicates) {
13012 SmallVector<const SCEVPredicate *> Predicates;
13013
13014 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13015 bool PredicatedIV = false;
13016 if (!IV) {
13017 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13018 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13019 if (AR && AR->getLoop() == L && AR->isAffine()) {
13020 auto canProveNUW = [&]() {
13021 // We can use the comparison to infer no-wrap flags only if it fully
13022 // controls the loop exit.
13023 if (!ControlsOnlyExit)
13024 return false;
13025
13026 if (!isLoopInvariant(RHS, L))
13027 return false;
13028
13029 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13030 // We need the sequence defined by AR to strictly increase in the
13031 // unsigned integer domain for the logic below to hold.
13032 return false;
13033
13034 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13035 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13036 // If RHS <=u Limit, then there must exist a value V in the sequence
13037 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13038 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13039 // overflow occurs. This limit also implies that a signed comparison
13040 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13041 // the high bits on both sides must be zero.
13042 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13043 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13044 Limit = Limit.zext(OuterBitWidth);
13045 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13046 };
13047 auto Flags = AR->getNoWrapFlags();
13048 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13049 Flags = setFlags(Flags, SCEV::FlagNUW);
13050
13051 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13052 if (AR->hasNoUnsignedWrap()) {
13053 // Emulate what getZeroExtendExpr would have done during construction
13054 // if we'd been able to infer the fact just above at that time.
13055 const SCEV *Step = AR->getStepRecurrence(*this);
13056 Type *Ty = ZExt->getType();
13057 auto *S = getAddRecExpr(
13058 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
13059 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13060 IV = dyn_cast<SCEVAddRecExpr>(S);
13061 }
13062 }
13063 }
13064 }
13065
13066
13067 if (!IV && AllowPredicates) {
13068 // Try to make this an AddRec using runtime tests, in the first X
13069 // iterations of this loop, where X is the SCEV expression found by the
13070 // algorithm below.
13071 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13072 PredicatedIV = true;
13073 }
13074
13075 // Avoid weird loops
13076 if (!IV || IV->getLoop() != L || !IV->isAffine())
13077 return getCouldNotCompute();
13078
13079 // A precondition of this method is that the condition being analyzed
13080 // reaches an exiting branch which dominates the latch. Given that, we can
13081 // assume that an increment which violates the nowrap specification and
13082 // produces poison must cause undefined behavior when the resulting poison
13083 // value is branched upon and thus we can conclude that the backedge is
13084 // taken no more often than would be required to produce that poison value.
13085 // Note that a well defined loop can exit on the iteration which violates
13086 // the nowrap specification if there is another exit (either explicit or
13087 // implicit/exceptional) which causes the loop to exit before the
13088 // exiting instruction we're analyzing would trigger UB.
13089 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13090 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13091 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13092
13093 const SCEV *Stride = IV->getStepRecurrence(*this);
13094
13095 bool PositiveStride = isKnownPositive(Stride);
13096
13097 // Avoid negative or zero stride values.
13098 if (!PositiveStride) {
13099 // We can compute the correct backedge taken count for loops with unknown
13100 // strides if we can prove that the loop is not an infinite loop with side
13101 // effects. Here's the loop structure we are trying to handle -
13102 //
13103 // i = start
13104 // do {
13105 // A[i] = i;
13106 // i += s;
13107 // } while (i < end);
13108 //
13109 // The backedge taken count for such loops is evaluated as -
13110 // (max(end, start + stride) - start - 1) /u stride
13111 //
13112 // The additional preconditions that we need to check to prove correctness
13113 // of the above formula are as follows -
13114 //
13115 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13116 // NoWrap flag).
13117 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13118 // no side effects within the loop)
13119 // c) loop has a single static exit (with no abnormal exits)
13120 //
13121 // Precondition a) implies that if the stride is negative, this is a single
13122 // trip loop. The backedge taken count formula reduces to zero in this case.
13123 //
13124 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13125 // then a zero stride means the backedge can't be taken without executing
13126 // undefined behavior.
13127 //
13128 // The positive stride case is the same as isKnownPositive(Stride) returning
13129 // true (original behavior of the function).
13130 //
13131 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13132 !loopHasNoAbnormalExits(L))
13133 return getCouldNotCompute();
13134
13135 if (!isKnownNonZero(Stride)) {
13136 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13137 // if it might eventually be greater than start and if so, on which
13138 // iteration. We can't even produce a useful upper bound.
13139 if (!isLoopInvariant(RHS, L))
13140 return getCouldNotCompute();
13141
13142 // We allow a potentially zero stride, but we need to divide by stride
13143 // below. Since the loop can't be infinite and this check must control
13144 // the sole exit, we can infer the exit must be taken on the first
13145 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13146 // we know the numerator in the divides below must be zero, so we can
13147 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13148 // and produce the right result.
13149 // FIXME: Handle the case where Stride is poison?
13150 auto wouldZeroStrideBeUB = [&]() {
13151 // Proof by contradiction. Suppose the stride were zero. If we can
13152 // prove that the backedge *is* taken on the first iteration, then since
13153 // we know this condition controls the sole exit, we must have an
13154 // infinite loop. We can't have a (well defined) infinite loop per
13155 // check just above.
13156 // Note: The (Start - Stride) term is used to get the start' term from
13157 // (start' + stride,+,stride). Remember that we only care about the
13158 // result of this expression when stride == 0 at runtime.
13159 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13160 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13161 };
13162 if (!wouldZeroStrideBeUB()) {
13163 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13164 }
13165 }
13166 } else if (!NoWrap) {
13167 // Avoid proven overflow cases: this will ensure that the backedge taken
13168 // count will not generate any unsigned overflow.
13169 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13170 return getCouldNotCompute();
13171 }
13172
13173 // On all paths just preceding, we established the following invariant:
13174 // IV can be assumed not to overflow up to and including the exiting
13175 // iteration. We proved this in one of two ways:
13176 // 1) We can show overflow doesn't occur before the exiting iteration
13177 // 1a) canIVOverflowOnLT, and b) step of one
13178 // 2) We can show that if overflow occurs, the loop must execute UB
13179 // before any possible exit.
13180 // Note that we have not yet proved RHS invariant (in general).
13181
13182 const SCEV *Start = IV->getStart();
13183
13184 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13185 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13186 // Use integer-typed versions for actual computation; we can't subtract
13187 // pointers in general.
13188 const SCEV *OrigStart = Start;
13189 const SCEV *OrigRHS = RHS;
13190 if (Start->getType()->isPointerTy()) {
13191 Start = getLosslessPtrToIntExpr(Start);
13192 if (isa<SCEVCouldNotCompute>(Start))
13193 return Start;
13194 }
13195 if (RHS->getType()->isPointerTy()) {
13196 RHS = getLosslessPtrToIntExpr(RHS);
13197 if (isa<SCEVCouldNotCompute>(RHS))
13198 return RHS;
13199 }
13200
13201 const SCEV *End = nullptr, *BECount = nullptr,
13202 *BECountIfBackedgeTaken = nullptr;
13203 if (!isLoopInvariant(RHS, L)) {
13204 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13205 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13206 RHSAddRec->getNoWrapFlags()) {
13207 // The structure of loop we are trying to calculate backedge count of:
13208 //
13209 // left = left_start
13210 // right = right_start
13211 //
13212 // while(left < right){
13213 // ... do something here ...
13214 // left += s1; // stride of left is s1 (s1 > 0)
13215 // right += s2; // stride of right is s2 (s2 < 0)
13216 // }
13217 //
13218
13219 const SCEV *RHSStart = RHSAddRec->getStart();
13220 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13221
13222 // If Stride - RHSStride is positive and does not overflow, we can write
13223 // backedge count as ->
13224 // ceil((End - Start) /u (Stride - RHSStride))
13225 // Where, End = max(RHSStart, Start)
13226
13227 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13228 if (isKnownNegative(RHSStride) &&
13229 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13230 RHSStride)) {
13231
13232 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13233 if (isKnownPositive(Denominator)) {
13234 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13235 : getUMaxExpr(RHSStart, Start);
13236
13237 // We can do this because End >= Start, as End = max(RHSStart, Start)
13238 const SCEV *Delta = getMinusSCEV(End, Start);
13239
13240 BECount = getUDivCeilSCEV(Delta, Denominator);
13241 BECountIfBackedgeTaken =
13242 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13243 }
13244 }
13245 }
13246 if (BECount == nullptr) {
13247 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13248 // given the start, stride and max value for the end bound of the
13249 // loop (RHS), and the fact that IV does not overflow (which is
13250 // checked above).
13251 const SCEV *MaxBECount = computeMaxBECountForLT(
13252 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13253 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13254 MaxBECount, false /*MaxOrZero*/, Predicates);
13255 }
13256 } else {
13257 // We use the expression (max(End,Start)-Start)/Stride to describe the
13258 // backedge count, because if the backedge is taken at least once
13259 // max(End,Start) is End and so the result is as above, and if not
13260 // max(End,Start) is Start so we get a backedge count of zero.
13261 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13262 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13263 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13264 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13265 // Can we prove max(RHS,Start) > Start - Stride?
13266 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13267 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13268 // In this case, we can use a refined formula for computing backedge
13269 // taken count. The general formula remains:
13270 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13271 // We want to use the alternate formula:
13272 // "((End - 1) - (Start - Stride)) /u Stride"
13273 // Let's do a quick case analysis to show these are equivalent under
13274 // our precondition that max(RHS,Start) > Start - Stride.
13275 // * For RHS <= Start, the backedge-taken count must be zero.
13276 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13277 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13278 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13279 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13280 // reducing this to the stride of 1 case.
13281 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13282 // Stride".
13283 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13284 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13285 // "((RHS - (Start - Stride) - 1) /u Stride".
13286 // Our preconditions trivially imply no overflow in that form.
13287 const SCEV *MinusOne = getMinusOne(Stride->getType());
13288 const SCEV *Numerator =
13289 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13290 BECount = getUDivExpr(Numerator, Stride);
13291 }
13292
13293 if (!BECount) {
13294 auto canProveRHSGreaterThanEqualStart = [&]() {
13295 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13296 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13297 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13298
13299 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13300 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13301 return true;
13302
13303 // (RHS > Start - 1) implies RHS >= Start.
13304 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13305 // "Start - 1" doesn't overflow.
13306 // * For signed comparison, if Start - 1 does overflow, it's equal
13307 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13308 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13309 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13310 //
13311 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13312 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13313 auto *StartMinusOne =
13314 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13315 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13316 };
13317
13318 // If we know that RHS >= Start in the context of loop, then we know
13319 // that max(RHS, Start) = RHS at this point.
13320 if (canProveRHSGreaterThanEqualStart()) {
13321 End = RHS;
13322 } else {
13323 // If RHS < Start, the backedge will be taken zero times. So in
13324 // general, we can write the backedge-taken count as:
13325 //
13326 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13327 //
13328 // We convert it to the following to make it more convenient for SCEV:
13329 //
13330 // ceil(max(RHS, Start) - Start) / Stride
13331 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13332
13333 // See what would happen if we assume the backedge is taken. This is
13334 // used to compute MaxBECount.
13335 BECountIfBackedgeTaken =
13336 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13337 }
13338
13339 // At this point, we know:
13340 //
13341 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13342 // 2. The index variable doesn't overflow.
13343 //
13344 // Therefore, we know N exists such that
13345 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13346 // doesn't overflow.
13347 //
13348 // Using this information, try to prove whether the addition in
13349 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13350 const SCEV *One = getOne(Stride->getType());
13351 bool MayAddOverflow = [&] {
13352 if (isKnownToBeAPowerOfTwo(Stride)) {
13353 // Suppose Stride is a power of two, and Start/End are unsigned
13354 // integers. Let UMAX be the largest representable unsigned
13355 // integer.
13356 //
13357 // By the preconditions of this function, we know
13358 // "(Start + Stride * N) >= End", and this doesn't overflow.
13359 // As a formula:
13360 //
13361 // End <= (Start + Stride * N) <= UMAX
13362 //
13363 // Subtracting Start from all the terms:
13364 //
13365 // End - Start <= Stride * N <= UMAX - Start
13366 //
13367 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13368 //
13369 // End - Start <= Stride * N <= UMAX
13370 //
13371 // Stride * N is a multiple of Stride. Therefore,
13372 //
13373 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13374 //
13375 // Since Stride is a power of two, UMAX + 1 is divisible by
13376 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13377 // write:
13378 //
13379 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13380 //
13381 // Dropping the middle term:
13382 //
13383 // End - Start <= UMAX - (Stride - 1)
13384 //
13385 // Adding Stride - 1 to both sides:
13386 //
13387 // (End - Start) + (Stride - 1) <= UMAX
13388 //
13389 // In other words, the addition doesn't have unsigned overflow.
13390 //
13391 // A similar proof works if we treat Start/End as signed values.
13392 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13393 // to use signed max instead of unsigned max. Note that we're
13394 // trying to prove a lack of unsigned overflow in either case.
13395 return false;
13396 }
13397 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13398 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13399 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13400 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13401 // 1 <s End.
13402 //
13403 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13404 // End.
13405 return false;
13406 }
13407 return true;
13408 }();
13409
13410 const SCEV *Delta = getMinusSCEV(End, Start);
13411 if (!MayAddOverflow) {
13412 // floor((D + (S - 1)) / S)
13413 // We prefer this formulation if it's legal because it's fewer
13414 // operations.
13415 BECount =
13416 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13417 } else {
13418 BECount = getUDivCeilSCEV(Delta, Stride);
13419 }
13420 }
13421 }
13422
13423 const SCEV *ConstantMaxBECount;
13424 bool MaxOrZero = false;
13425 if (isa<SCEVConstant>(BECount)) {
13426 ConstantMaxBECount = BECount;
13427 } else if (BECountIfBackedgeTaken &&
13428 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13429 // If we know exactly how many times the backedge will be taken if it's
13430 // taken at least once, then the backedge count will either be that or
13431 // zero.
13432 ConstantMaxBECount = BECountIfBackedgeTaken;
13433 MaxOrZero = true;
13434 } else {
13435 ConstantMaxBECount = computeMaxBECountForLT(
13436 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13437 }
13438
13439 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13440 !isa<SCEVCouldNotCompute>(BECount))
13441 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13442
13443 const SCEV *SymbolicMaxBECount =
13444 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13445 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13446 Predicates);
13447}
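
// A minimal standalone sketch of the two-moving-bounds case handled above;
// the helper name is hypothetical. With Start = 0, RHSStart = 20, Stride = 3
// and RHSStride = -2, the formula gives
// ceil((max(RHSStart, Start) - Start) /u (Stride - RHSStride)) =
// ceil(20 / 5) = 4 backedges, which matches a direct simulation.
#include <cassert>

static void checkConvergingBoundsCount() {
  int Left = 0, Right = 20; // Start and RHSStart
  unsigned Taken = 0;
  while (Left < Right) {
    Left += 3;  // Stride
    Right -= 2; // RHSStride is -2
    ++Taken;    // the latch branches back to the header after each body
  }
  assert(Taken == (20 - 0 + 5 - 1) / 5); // ceil(20 / 5) == 4
}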
13448
13449ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13450 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13451 bool ControlsOnlyExit, bool AllowPredicates) {
13452 SmallVector<const SCEVPredicate *> Predicates;
13453 // We handle only IV > Invariant
13454 if (!isLoopInvariant(RHS, L))
13455 return getCouldNotCompute();
13456
13457 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13458 if (!IV && AllowPredicates)
13459 // Try to make this an AddRec using runtime tests, in the first X
13460 // iterations of this loop, where X is the SCEV expression found by the
13461 // algorithm below.
13462 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13463
13464 // Avoid weird loops
13465 if (!IV || IV->getLoop() != L || !IV->isAffine())
13466 return getCouldNotCompute();
13467
13468 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13469 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13470 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13471
13472 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13473
13474 // Avoid negative or zero stride values
13475 if (!isKnownPositive(Stride))
13476 return getCouldNotCompute();
13477
13478 // Avoid proven overflow cases: this will ensure that the backedge taken count
13479 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13480 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13481 // behaviors like the case of C language.
13482 if (!Stride->isOne() && !NoWrap)
13483 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13484 return getCouldNotCompute();
13485
13486 const SCEV *Start = IV->getStart();
13487 const SCEV *End = RHS;
13488 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13489 // If we know that Start >= RHS in the context of loop, then we know that
13490 // min(RHS, Start) = RHS at this point.
13491 if (isLoopEntryGuardedByCond(
13492 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13493 End = RHS;
13494 else
13495 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13496 }
13497
13498 if (Start->getType()->isPointerTy()) {
13499 Start = getLosslessPtrToIntExpr(Start);
13500 if (isa<SCEVCouldNotCompute>(Start))
13501 return Start;
13502 }
13503 if (End->getType()->isPointerTy()) {
13504 End = getLosslessPtrToIntExpr(End);
13505 if (isa<SCEVCouldNotCompute>(End))
13506 return End;
13507 }
13508
13509 // Compute ((Start - End) + (Stride - 1)) / Stride.
13510 // FIXME: This can overflow. Holding off on fixing this for now;
13511 // howManyGreaterThans will hopefully be gone soon.
13512 const SCEV *One = getOne(Stride->getType());
13513 const SCEV *BECount = getUDivExpr(
13514 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13515
13516 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13517 : getUnsignedRangeMax(Start);
13518
13519 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13520 : getUnsignedRangeMin(Stride);
13521
13522 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13523 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13524 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13525
13526 // Although End can be a MIN expression we estimate MinEnd considering only
13527 // the case End = RHS. This is safe because in the other case (Start - End)
13528 // is zero, leading to a zero maximum backedge taken count.
13529 APInt MinEnd =
13530 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13531 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13532
13533 const SCEV *ConstantMaxBECount =
13534 isa<SCEVConstant>(BECount)
13535 ? BECount
13536 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13537 getConstant(MinStride));
13538
13539 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13540 ConstantMaxBECount = BECount;
13541 const SCEV *SymbolicMaxBECount =
13542 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13543
13544 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13545 Predicates);
13546}
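
// A minimal standalone sketch of the count ((Start - End) + (Stride - 1)) / Stride
// computed above, for a decrementing IV; the helper name is hypothetical.
// With Start = 17, End = 5 and a step of -4 (so Stride = 4 after negation),
// the loop "for (i = 17; i > 5; i -= 4)" takes its backedge ceil(12 / 4) = 3
// times.
#include <cassert>

static void checkGreaterThanCount() {
  const int Start = 17, End = 5, Stride = 4;
  unsigned Taken = 0;
  for (int I = Start; I > End; I -= Stride)
    ++Taken;
  assert(Taken == unsigned((Start - End + Stride - 1) / Stride)); // 3
}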
13547
13548 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13549 ScalarEvolution &SE) const {
13550 if (Range.isFullSet()) // Infinite loop.
13551 return SE.getCouldNotCompute();
13552
13553 // If the start is a non-zero constant, shift the range to simplify things.
13554 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13555 if (!SC->getValue()->isZero()) {
13556 SmallVector<const SCEV *, 4> Operands(operands());
13557 Operands[0] = SE.getZero(SC->getType());
13558 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13559 getNoWrapFlags(FlagNW));
13560 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13561 return ShiftedAddRec->getNumIterationsInRange(
13562 Range.subtract(SC->getAPInt()), SE);
13563 // This is strange and shouldn't happen.
13564 return SE.getCouldNotCompute();
13565 }
13566
13567 // The only time we can solve this is when we have all constant indices.
13568 // Otherwise, we cannot determine the overflow conditions.
13569 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13570 return SE.getCouldNotCompute();
13571
13572 // Okay at this point we know that all elements of the chrec are constants and
13573 // that the start element is zero.
13574
13575 // First check to see if the range contains zero. If not, the first
13576 // iteration exits.
13577 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13578 if (!Range.contains(APInt(BitWidth, 0)))
13579 return SE.getZero(getType());
13580
13581 if (isAffine()) {
13582 // If this is an affine expression then we have this situation:
13583 // Solve {0,+,A} in Range === Ax in Range
13584
13585 // We know that zero is in the range. If A is positive then we know that
13586 // the upper value of the range must be the first possible exit value.
13587 // If A is negative then the lower of the range is the last possible loop
13588 // value. Also note that we already checked for a full range.
13589 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13590 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13591
13592 // The exit value should be (End+A)/A.
13593 APInt ExitVal = (End + A).udiv(A);
13594 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13595
13596 // Evaluate at the exit value. If we really did fall out of the valid
13597 // range, then we computed our trip count, otherwise wrap around or other
13598 // things must have happened.
13599 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13600 if (Range.contains(Val->getValue()))
13601 return SE.getCouldNotCompute(); // Something strange happened
13602
13603 // Ensure that the previous value is in the range.
13604 assert(Range.contains(
13605 EvaluateConstantChrecAtConstant(this,
13606 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13607 "Linear scev computation is off in a bad way!");
13608 return SE.getConstant(ExitValue);
13609 }
13610
13611 if (isQuadratic()) {
13612 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13613 return SE.getConstant(*S);
13614 }
13615
13616 return SE.getCouldNotCompute();
13617}
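
// A minimal standalone sketch of the affine case above for {0,+,3} and the
// half-open range [0, 10); the helper name is hypothetical. A = 3,
// End = 10 - 1 = 9, ExitVal = (End + A) /u A = 4: iterations 0..3 produce
// 0, 3, 6, 9 (all inside the range) and iteration 4 produces 12, the first
// value outside it.
#include <cassert>

static void checkAffineRangeExitExample() {
  const int A = 3;
  const int RangeLo = 0, RangeHi = 10; // [0, 10)
  const int ExitVal = ((RangeHi - 1) + A) / A;
  for (int It = 0; It < ExitVal; ++It)
    assert(A * It >= RangeLo && A * It < RangeHi);
  assert(!(A * ExitVal >= RangeLo && A * ExitVal < RangeHi));
}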
13618
13619const SCEVAddRecExpr *
13620 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13621 assert(getNumOperands() > 1 && "AddRec with zero step?");
13622 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13623 // but in this case we cannot guarantee that the value returned will be an
13624 // AddRec because SCEV does not have a fixed point where it stops
13625 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13626 // may happen if we reach arithmetic depth limit while simplifying. So we
13627 // construct the returned value explicitly.
13628 SmallVector<const SCEV *, 3> Ops;
13629 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13630 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13631 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13632 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13633 // We know that the last operand is not a constant zero (otherwise it would
13634 // have been popped out earlier). This guarantees us that if the result has
13635 // the same last operand, then it will also not be popped out, meaning that
13636 // the returned value will be an AddRec.
13637 const SCEV *Last = getOperand(getNumOperands() - 1);
13638 assert(!Last->isZero() && "Recurrence with zero step?");
13639 Ops.push_back(Last);
13640 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13641 SCEV::FlagAnyWrap));
13642}
13643
13644// Return true when S contains at least an undef value.
13645 bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13646 return SCEVExprContains(S, [](const SCEV *S) {
13647 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13648 return isa<UndefValue>(SU->getValue());
13649 return false;
13650 });
13651}
13652
13653// Return true when S contains a value that is a nullptr.
13654 bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13655 return SCEVExprContains(S, [](const SCEV *S) {
13656 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13657 return SU->getValue() == nullptr;
13658 return false;
13659 });
13660}
13661
13662/// Return the size of an element read or written by Inst.
13663 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13664 Type *Ty;
13665 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13666 Ty = Store->getValueOperand()->getType();
13667 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13668 Ty = Load->getType();
13669 else
13670 return nullptr;
13671
13673 return getSizeOfExpr(ETy, Ty);
13674}
13675
13676//===----------------------------------------------------------------------===//
13677// SCEVCallbackVH Class Implementation
13678//===----------------------------------------------------------------------===//
13679
13680void ScalarEvolution::SCEVCallbackVH::deleted() {
13681 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13682 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13683 SE->ConstantEvolutionLoopExitValue.erase(PN);
13684 SE->eraseValueFromMap(getValPtr());
13685 // this now dangles!
13686}
13687
13688void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13689 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13690
13691 // Forget all the expressions associated with users of the old value,
13692 // so that future queries will recompute the expressions using the new
13693 // value.
13694 SE->forgetValue(getValPtr());
13695 // this now dangles!
13696}
13697
13698ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13699 : CallbackVH(V), SE(se) {}
13700
13701//===----------------------------------------------------------------------===//
13702// ScalarEvolution Class Implementation
13703//===----------------------------------------------------------------------===//
13704
13705 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13706 AssumptionCache &AC, DominatorTree &DT,
13707 LoopInfo &LI)
13708 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13709 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13710 LoopDispositions(64), BlockDispositions(64) {
13711 // To use guards for proving predicates, we need to scan every instruction in
13712 // relevant basic blocks, and not just terminators. Doing this is a waste of
13713 // time if the IR does not actually contain any calls to
13714 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13715 //
13716 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13717 // to _add_ guards to the module when there weren't any before, and wants
13718 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13719 // efficient in lieu of being smart in that rather obscure case.
13720
13721 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13722 F.getParent(), Intrinsic::experimental_guard);
13723 HasGuards = GuardDecl && !GuardDecl->use_empty();
13724}
13725
13726 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13727 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13728 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13729 ValueExprMap(std::move(Arg.ValueExprMap)),
13730 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13731 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13732 PendingMerges(std::move(Arg.PendingMerges)),
13733 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13734 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13735 PredicatedBackedgeTakenCounts(
13736 std::move(Arg.PredicatedBackedgeTakenCounts)),
13737 BECountUsers(std::move(Arg.BECountUsers)),
13738 ConstantEvolutionLoopExitValue(
13739 std::move(Arg.ConstantEvolutionLoopExitValue)),
13740 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13741 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13742 LoopDispositions(std::move(Arg.LoopDispositions)),
13743 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13744 BlockDispositions(std::move(Arg.BlockDispositions)),
13745 SCEVUsers(std::move(Arg.SCEVUsers)),
13746 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13747 SignedRanges(std::move(Arg.SignedRanges)),
13748 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13749 UniquePreds(std::move(Arg.UniquePreds)),
13750 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13751 LoopUsers(std::move(Arg.LoopUsers)),
13752 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13753 FirstUnknown(Arg.FirstUnknown) {
13754 Arg.FirstUnknown = nullptr;
13755}
13756
13757 ScalarEvolution::~ScalarEvolution() {
13758 // Iterate through all the SCEVUnknown instances and call their
13759 // destructors, so that they release their references to their values.
13760 for (SCEVUnknown *U = FirstUnknown; U;) {
13761 SCEVUnknown *Tmp = U;
13762 U = U->Next;
13763 Tmp->~SCEVUnknown();
13764 }
13765 FirstUnknown = nullptr;
13766
13767 ExprValueMap.clear();
13768 ValueExprMap.clear();
13769 HasRecMap.clear();
13770 BackedgeTakenCounts.clear();
13771 PredicatedBackedgeTakenCounts.clear();
13772
13773 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13774 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13775 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13776 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13777 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13778}
13779
13780 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13781 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13782}
13783
13784/// When printing a top-level SCEV for trip counts, it's helpful to include
13785/// a type for constants which are otherwise hard to disambiguate.
13786static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13787 if (isa<SCEVConstant>(S))
13788 OS << *S->getType() << " ";
13789 OS << *S;
13790}
13791
13792 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13793 const Loop *L) {
13794 // Print all inner loops first
13795 for (Loop *I : *L)
13796 PrintLoopInfo(OS, SE, I);
13797
13798 OS << "Loop ";
13799 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13800 OS << ": ";
13801
13802 SmallVector<BasicBlock *, 8> ExitingBlocks;
13803 L->getExitingBlocks(ExitingBlocks);
13804 if (ExitingBlocks.size() != 1)
13805 OS << "<multiple exits> ";
13806
13807 auto *BTC = SE->getBackedgeTakenCount(L);
13808 if (!isa<SCEVCouldNotCompute>(BTC)) {
13809 OS << "backedge-taken count is ";
13810 PrintSCEVWithTypeHint(OS, BTC);
13811 } else
13812 OS << "Unpredictable backedge-taken count.";
13813 OS << "\n";
13814
13815 if (ExitingBlocks.size() > 1)
13816 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13817 OS << " exit count for " << ExitingBlock->getName() << ": ";
13818 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13819 PrintSCEVWithTypeHint(OS, EC);
13820 if (isa<SCEVCouldNotCompute>(EC)) {
13821 // Retry with predicates.
13822 SmallVector<const SCEVPredicate *, 4> Predicates;
13823 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13824 if (!isa<SCEVCouldNotCompute>(EC)) {
13825 OS << "\n predicated exit count for " << ExitingBlock->getName()
13826 << ": ";
13827 PrintSCEVWithTypeHint(OS, EC);
13828 OS << "\n Predicates:\n";
13829 for (const auto *P : Predicates)
13830 P->print(OS, 4);
13831 }
13832 }
13833 OS << "\n";
13834 }
13835
13836 OS << "Loop ";
13837 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13838 OS << ": ";
13839
13840 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13841 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13842 OS << "constant max backedge-taken count is ";
13843 PrintSCEVWithTypeHint(OS, ConstantBTC);
13844 if (SE->isBackedgeTakenCountMaxOrZero(L))
13845 OS << ", actual taken count either this or zero.";
13846 } else {
13847 OS << "Unpredictable constant max backedge-taken count. ";
13848 }
13849
13850 OS << "\n"
13851 "Loop ";
13852 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13853 OS << ": ";
13854
13855 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13856 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13857 OS << "symbolic max backedge-taken count is ";
13858 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13859 if (SE->isBackedgeTakenCountMaxOrZero(L))
13860 OS << ", actual taken count either this or zero.";
13861 } else {
13862 OS << "Unpredictable symbolic max backedge-taken count. ";
13863 }
13864 OS << "\n";
13865
13866 if (ExitingBlocks.size() > 1)
13867 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13868 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13869 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13870 ScalarEvolution::SymbolicMaximum);
13871 PrintSCEVWithTypeHint(OS, ExitBTC);
13872 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13873 // Retry with predicates.
13874 SmallVector<const SCEVPredicate *, 4> Predicates;
13875 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13876 ScalarEvolution::SymbolicMaximum);
13877 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13878 OS << "\n predicated symbolic max exit count for "
13879 << ExitingBlock->getName() << ": ";
13880 PrintSCEVWithTypeHint(OS, ExitBTC);
13881 OS << "\n Predicates:\n";
13882 for (const auto *P : Predicates)
13883 P->print(OS, 4);
13884 }
13885 }
13886 OS << "\n";
13887 }
13888
13889 SmallVector<const SCEVPredicate *, 4> Preds;
13890 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13891 if (PBT != BTC) {
13892 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13893 OS << "Loop ";
13894 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13895 OS << ": ";
13896 if (!isa<SCEVCouldNotCompute>(PBT)) {
13897 OS << "Predicated backedge-taken count is ";
13898 PrintSCEVWithTypeHint(OS, PBT);
13899 } else
13900 OS << "Unpredictable predicated backedge-taken count.";
13901 OS << "\n";
13902 OS << " Predicates:\n";
13903 for (const auto *P : Preds)
13904 P->print(OS, 4);
13905 }
13906 Preds.clear();
13907
13908 auto *PredConstantMax =
13909 SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13910 if (PredConstantMax != ConstantBTC) {
13911 assert(!Preds.empty() &&
13912 "different predicated constant max BTC but no predicates");
13913 OS << "Loop ";
13914 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13915 OS << ": ";
13916 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13917 OS << "Predicated constant max backedge-taken count is ";
13918 PrintSCEVWithTypeHint(OS, PredConstantMax);
13919 } else
13920 OS << "Unpredictable predicated constant max backedge-taken count.";
13921 OS << "\n";
13922 OS << " Predicates:\n";
13923 for (const auto *P : Preds)
13924 P->print(OS, 4);
13925 }
13926 Preds.clear();
13927
13928 auto *PredSymbolicMax =
13929 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13930 if (SymbolicBTC != PredSymbolicMax) {
13931 assert(!Preds.empty() &&
13932 "Different predicated symbolic max BTC, but no predicates");
13933 OS << "Loop ";
13934 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13935 OS << ": ";
13936 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13937 OS << "Predicated symbolic max backedge-taken count is ";
13938 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13939 } else
13940 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13941 OS << "\n";
13942 OS << " Predicates:\n";
13943 for (const auto *P : Preds)
13944 P->print(OS, 4);
13945 }
13946
13947 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13948 OS << "Loop ";
13949 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13950 OS << ": ";
13951 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13952 }
13953}
13954
13955namespace llvm {
13956raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13957 switch (LD) {
13958 case ScalarEvolution::LoopVariant:
13959 OS << "Variant";
13960 break;
13961 case ScalarEvolution::LoopInvariant:
13962 OS << "Invariant";
13963 break;
13964 case ScalarEvolution::LoopComputable:
13965 OS << "Computable";
13966 break;
13967 }
13968 return OS;
13969}
13970
13971raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13972 switch (BD) {
13973 case ScalarEvolution::DoesNotDominateBlock:
13974 OS << "DoesNotDominate";
13975 break;
13976 case ScalarEvolution::DominatesBlock:
13977 OS << "Dominates";
13978 break;
13979 case ScalarEvolution::ProperlyDominatesBlock:
13980 OS << "ProperlyDominates";
13981 break;
13982 }
13983 return OS;
13984}
13985} // namespace llvm
13986
13987void ScalarEvolution::print(raw_ostream &OS) const {
13988 // ScalarEvolution's implementation of the print method is to print
13989 // out SCEV values of all instructions that are interesting. Doing
13990 // this potentially causes it to create new SCEV objects though,
13991 // which technically conflicts with the const qualifier. This isn't
13992 // observable from outside the class though, so casting away the
13993 // const isn't dangerous.
13994 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13995
13996 if (ClassifyExpressions) {
13997 OS << "Classifying expressions for: ";
13998 F.printAsOperand(OS, /*PrintType=*/false);
13999 OS << "\n";
14000 for (Instruction &I : instructions(F))
14001 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14002 OS << I << '\n';
14003 OS << " --> ";
14004 const SCEV *SV = SE.getSCEV(&I);
14005 SV->print(OS);
14006 if (!isa<SCEVCouldNotCompute>(SV)) {
14007 OS << " U: ";
14008 SE.getUnsignedRange(SV).print(OS);
14009 OS << " S: ";
14010 SE.getSignedRange(SV).print(OS);
14011 }
14012
14013 const Loop *L = LI.getLoopFor(I.getParent());
14014
14015 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14016 if (AtUse != SV) {
14017 OS << " --> ";
14018 AtUse->print(OS);
14019 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14020 OS << " U: ";
14021 SE.getUnsignedRange(AtUse).print(OS);
14022 OS << " S: ";
14023 SE.getSignedRange(AtUse).print(OS);
14024 }
14025 }
14026
14027 if (L) {
14028 OS << "\t\t" "Exits: ";
14029 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14030 if (!SE.isLoopInvariant(ExitValue, L)) {
14031 OS << "<<Unknown>>";
14032 } else {
14033 OS << *ExitValue;
14034 }
14035
14036 bool First = true;
14037 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14038 if (First) {
14039 OS << "\t\t" "LoopDispositions: { ";
14040 First = false;
14041 } else {
14042 OS << ", ";
14043 }
14044
14045 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14046 OS << ": " << SE.getLoopDisposition(SV, Iter);
14047 }
14048
14049 for (const auto *InnerL : depth_first(L)) {
14050 if (InnerL == L)
14051 continue;
14052 if (First) {
14053 OS << "\t\t" "LoopDispositions: { ";
14054 First = false;
14055 } else {
14056 OS << ", ";
14057 }
14058
14059 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14060 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14061 }
14062
14063 OS << " }";
14064 }
14065
14066 OS << "\n";
14067 }
14068 }
14069
14070 OS << "Determining loop execution counts for: ";
14071 F.printAsOperand(OS, /*PrintType=*/false);
14072 OS << "\n";
14073 for (Loop *I : LI)
14074 PrintLoopInfo(OS, &SE, I);
14075}
14076
14077ScalarEvolution::LoopDisposition
14078ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14079 auto &Values = LoopDispositions[S];
14080 for (auto &V : Values) {
14081 if (V.getPointer() == L)
14082 return V.getInt();
14083 }
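 // Not cached yet: record a conservative placeholder entry, compute the real
 // disposition, then update that entry. The map is looked up again below
 // because the recursive computation may have grown (and reallocated) it.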
14084 Values.emplace_back(L, LoopVariant);
14085 LoopDisposition D = computeLoopDisposition(S, L);
14086 auto &Values2 = LoopDispositions[S];
14087 for (auto &V : llvm::reverse(Values2)) {
14088 if (V.getPointer() == L) {
14089 V.setInt(D);
14090 break;
14091 }
14092 }
14093 return D;
14094}
14095
14096ScalarEvolution::LoopDisposition
14097ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14098 switch (S->getSCEVType()) {
14099 case scConstant:
14100 case scVScale:
14101 return LoopInvariant;
14102 case scAddRecExpr: {
14103 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14104
14105 // If L is the addrec's loop, it's computable.
14106 if (AR->getLoop() == L)
14107 return LoopComputable;
14108
14109 // Add recurrences are never invariant in the function-body (null loop).
14110 if (!L)
14111 return LoopVariant;
14112
14113 // Everything that is not defined at loop entry is variant.
14114 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14115 return LoopVariant;
14116 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14117 " dominate the contained loop's header?");
14118
14119 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14120 if (AR->getLoop()->contains(L))
14121 return LoopInvariant;
14122
14123 // This recurrence is variant w.r.t. L if any of its operands
14124 // are variant.
14125 for (const auto *Op : AR->operands())
14126 if (!isLoopInvariant(Op, L))
14127 return LoopVariant;
14128
14129 // Otherwise it's loop-invariant.
14130 return LoopInvariant;
14131 }
14132 case scTruncate:
14133 case scZeroExtend:
14134 case scSignExtend:
14135 case scPtrToInt:
14136 case scAddExpr:
14137 case scMulExpr:
14138 case scUDivExpr:
14139 case scUMaxExpr:
14140 case scSMaxExpr:
14141 case scUMinExpr:
14142 case scSMinExpr:
14143 case scSequentialUMinExpr: {
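 // An n-ary expression is variant if any operand is variant, computable if
 // any operand is computable (and none is variant), and invariant otherwise.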
14144 bool HasVarying = false;
14145 for (const auto *Op : S->operands()) {
14146 LoopDisposition D = getLoopDisposition(Op, L);
14147 if (D == LoopVariant)
14148 return LoopVariant;
14149 if (D == LoopComputable)
14150 HasVarying = true;
14151 }
14152 return HasVarying ? LoopComputable : LoopInvariant;
14153 }
14154 case scUnknown:
14155 // All non-instruction values are loop invariant. All instructions are loop
14156 // invariant if they are not contained in the specified loop.
14157 // Instructions are never considered invariant in the function body
14158 // (null loop) because they are defined within the "loop".
14159 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14160 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14161 return LoopInvariant;
14162 case scCouldNotCompute:
14163 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14164 }
14165 llvm_unreachable("Unknown SCEV kind!");
14166}
14167
14168bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14169 return getLoopDisposition(S, L) == LoopInvariant;
14170}
14171
14172bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14173 return getLoopDisposition(S, L) == LoopComputable;
14174}
14175
14176ScalarEvolution::BlockDisposition
14177ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14178 auto &Values = BlockDispositions[S];
14179 for (auto &V : Values) {
14180 if (V.getPointer() == BB)
14181 return V.getInt();
14182 }
14183 Values.emplace_back(BB, DoesNotDominateBlock);
14184 BlockDisposition D = computeBlockDisposition(S, BB);
14185 auto &Values2 = BlockDispositions[S];
14186 for (auto &V : llvm::reverse(Values2)) {
14187 if (V.getPointer() == BB) {
14188 V.setInt(D);
14189 break;
14190 }
14191 }
14192 return D;
14193}
14194
14195ScalarEvolution::BlockDisposition
14196ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14197 switch (S->getSCEVType()) {
14198 case scConstant:
14199 case scVScale:
14200 return ProperlyDominatesBlock;
14201 case scAddRecExpr: {
14202 // This uses a "dominates" query instead of "properly dominates" query
14203 // to test for proper dominance too, because the instruction which
14204 // produces the addrec's value is a PHI, and a PHI effectively properly
14205 // dominates its entire containing block.
14206 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14207 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14208 return DoesNotDominateBlock;
14209
14210 // Fall through into SCEVNAryExpr handling.
14211 [[fallthrough]];
14212 }
14213 case scTruncate:
14214 case scZeroExtend:
14215 case scSignExtend:
14216 case scPtrToInt:
14217 case scAddExpr:
14218 case scMulExpr:
14219 case scUDivExpr:
14220 case scUMaxExpr:
14221 case scSMaxExpr:
14222 case scUMinExpr:
14223 case scSMinExpr:
14224 case scSequentialUMinExpr: {
14225 bool Proper = true;
14226 for (const SCEV *NAryOp : S->operands()) {
14227 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14228 if (D == DoesNotDominateBlock)
14229 return DoesNotDominateBlock;
14230 if (D == DominatesBlock)
14231 Proper = false;
14232 }
14233 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14234 }
14235 case scUnknown:
14236 if (Instruction *I =
14237 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14238 if (I->getParent() == BB)
14239 return DominatesBlock;
14240 if (DT.properlyDominates(I->getParent(), BB))
14241 return ProperlyDominatesBlock;
14242 return DoesNotDominateBlock;
14243 }
14244 return ProperlyDominatesBlock;
14245 case scCouldNotCompute:
14246 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14247 }
14248 llvm_unreachable("Unknown SCEV kind!");
14249}
14250
14251bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14252 return getBlockDisposition(S, BB) >= DominatesBlock;
14253}
14254
14255bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14256 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14257}
14258
14259bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14260 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14261}
14262
14263void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14264 bool Predicated) {
14265 auto &BECounts =
14266 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14267 auto It = BECounts.find(L);
14268 if (It != BECounts.end()) {
14269 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14270 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14271 if (!isa<SCEVConstant>(S)) {
14272 auto UserIt = BECountUsers.find(S);
14273 assert(UserIt != BECountUsers.end());
14274 UserIt->second.erase({L, Predicated});
14275 }
14276 }
14277 }
14278 BECounts.erase(It);
14279 }
14280}
14281
14282void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14284 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14285
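 // Transitively collect every SCEV that uses one of the expressions being
 // forgotten; their cached results depend on those expressions as well.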
14286 while (!Worklist.empty()) {
14287 const SCEV *Curr = Worklist.pop_back_val();
14288 auto Users = SCEVUsers.find(Curr);
14289 if (Users != SCEVUsers.end())
14290 for (const auto *User : Users->second)
14291 if (ToForget.insert(User).second)
14292 Worklist.push_back(User);
14293 }
14294
14295 for (const auto *S : ToForget)
14296 forgetMemoizedResultsImpl(S);
14297
14298 for (auto I = PredicatedSCEVRewrites.begin();
14299 I != PredicatedSCEVRewrites.end();) {
14300 std::pair<const SCEV *, const Loop *> Entry = I->first;
14301 if (ToForget.count(Entry.first))
14302 PredicatedSCEVRewrites.erase(I++);
14303 else
14304 ++I;
14305 }
14306}
14307
14308void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14309 LoopDispositions.erase(S);
14310 BlockDispositions.erase(S);
14311 UnsignedRanges.erase(S);
14312 SignedRanges.erase(S);
14313 HasRecMap.erase(S);
14314 ConstantMultipleCache.erase(S);
14315
14316 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14317 UnsignedWrapViaInductionTried.erase(AR);
14318 SignedWrapViaInductionTried.erase(AR);
14319 }
14320
14321 auto ExprIt = ExprValueMap.find(S);
14322 if (ExprIt != ExprValueMap.end()) {
14323 for (Value *V : ExprIt->second) {
14324 auto ValueIt = ValueExprMap.find_as(V);
14325 if (ValueIt != ValueExprMap.end())
14326 ValueExprMap.erase(ValueIt);
14327 }
14328 ExprValueMap.erase(ExprIt);
14329 }
14330
14331 auto ScopeIt = ValuesAtScopes.find(S);
14332 if (ScopeIt != ValuesAtScopes.end()) {
14333 for (const auto &Pair : ScopeIt->second)
14334 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14335 llvm::erase(ValuesAtScopesUsers[Pair.second],
14336 std::make_pair(Pair.first, S));
14337 ValuesAtScopes.erase(ScopeIt);
14338 }
14339
14340 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14341 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14342 for (const auto &Pair : ScopeUserIt->second)
14343 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14344 ValuesAtScopesUsers.erase(ScopeUserIt);
14345 }
14346
14347 auto BEUsersIt = BECountUsers.find(S);
14348 if (BEUsersIt != BECountUsers.end()) {
14349 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14350 auto Copy = BEUsersIt->second;
14351 for (const auto &Pair : Copy)
14352 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14353 BECountUsers.erase(BEUsersIt);
14354 }
14355
14356 auto FoldUser = FoldCacheUser.find(S);
14357 if (FoldUser != FoldCacheUser.end())
14358 for (auto &KV : FoldUser->second)
14359 FoldCache.erase(KV);
14360 FoldCacheUser.erase(S);
14361}
14362
14363void
14364ScalarEvolution::getUsedLoops(const SCEV *S,
14365 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14366 struct FindUsedLoops {
14367 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14368 : LoopsUsed(LoopsUsed) {}
14369 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14370 bool follow(const SCEV *S) {
14371 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14372 LoopsUsed.insert(AR->getLoop());
14373 return true;
14374 }
14375
14376 bool isDone() const { return false; }
14377 };
14378
14379 FindUsedLoops F(LoopsUsed);
14380 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14381}
14382
14383void ScalarEvolution::getReachableBlocks(
14384 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14385 SmallVector<BasicBlock *> Worklist;
14386 Worklist.push_back(&F.getEntryBlock());
14387 while (!Worklist.empty()) {
14388 BasicBlock *BB = Worklist.pop_back_val();
14389 if (!Reachable.insert(BB).second)
14390 continue;
14391
14392 Value *Cond;
14393 BasicBlock *TrueBB, *FalseBB;
14394 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14395 m_BasicBlock(FalseBB)))) {
14396 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14397 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14398 continue;
14399 }
14400
14401 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14402 const SCEV *L = getSCEV(Cmp->getOperand(0));
14403 const SCEV *R = getSCEV(Cmp->getOperand(1));
14404 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14405 Worklist.push_back(TrueBB);
14406 continue;
14407 }
14408 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14409 R)) {
14410 Worklist.push_back(FalseBB);
14411 continue;
14412 }
14413 }
14414 }
14415
14416 append_range(Worklist, successors(BB));
14417 }
14418}
14419
14420void ScalarEvolution::verify() const {
14421 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14422 ScalarEvolution SE2(F, TLI, AC, DT, LI);
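 // SE2 is a fresh ScalarEvolution instance over the same IR; the cached
 // results below are recomputed in SE2 and compared against this instance.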
14423
14424 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14425
14426 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14427 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14428 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14429
14430 const SCEV *visitConstant(const SCEVConstant *Constant) {
14431 return SE.getConstant(Constant->getAPInt());
14432 }
14433
14434 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14435 return SE.getUnknown(Expr->getValue());
14436 }
14437
14438 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14439 return SE.getCouldNotCompute();
14440 }
14441 };
14442
14443 SCEVMapper SCM(SE2);
14444 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14445 SE2.getReachableBlocks(ReachableBlocks, F);
14446
14447 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14448 if (containsUndefs(Old) || containsUndefs(New)) {
14449 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14450 // not propagate undef aggressively). This means we can (and do) fail
14451 // verification in cases where a transform makes a value go from "undef"
14452 // to "undef+1" (say). The transform is fine, since in both cases the
14453 // result is "undef", but SCEV thinks the value increased by 1.
14454 return nullptr;
14455 }
14456
14457 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14458 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14459 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14460 return nullptr;
14461
14462 return Delta;
14463 };
14464
14465 while (!LoopStack.empty()) {
14466 auto *L = LoopStack.pop_back_val();
14467 llvm::append_range(LoopStack, *L);
14468
14469 // Only verify BECounts in reachable loops. For an unreachable loop,
14470 // any BECount is legal.
14471 if (!ReachableBlocks.contains(L->getHeader()))
14472 continue;
14473
14474 // Only verify cached BECounts. Computing new BECounts may change the
14475 // results of subsequent SCEV uses.
14476 auto It = BackedgeTakenCounts.find(L);
14477 if (It == BackedgeTakenCounts.end())
14478 continue;
14479
14480 auto *CurBECount =
14481 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14482 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14483
14484 if (CurBECount == SE2.getCouldNotCompute() ||
14485 NewBECount == SE2.getCouldNotCompute()) {
14486 // NB! This situation is legal, but is very suspicious -- whatever pass
14487 // changed the loop to make a trip count go from could not compute to
14488 // computable or vice-versa *should have* invalidated SCEV. However, we
14489 // choose not to assert here (for now) since we don't want false
14490 // positives.
14491 continue;
14492 }
14493
14494 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14495 SE.getTypeSizeInBits(NewBECount->getType()))
14496 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14497 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14498 SE.getTypeSizeInBits(NewBECount->getType()))
14499 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14500
14501 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14502 if (Delta && !Delta->isZero()) {
14503 dbgs() << "Trip Count for " << *L << " Changed!\n";
14504 dbgs() << "Old: " << *CurBECount << "\n";
14505 dbgs() << "New: " << *NewBECount << "\n";
14506 dbgs() << "Delta: " << *Delta << "\n";
14507 std::abort();
14508 }
14509 }
14510
14511 // Collect all valid loops currently in LoopInfo.
14512 SmallPtrSet<Loop *, 32> ValidLoops;
14513 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14514 while (!Worklist.empty()) {
14515 Loop *L = Worklist.pop_back_val();
14516 if (ValidLoops.insert(L).second)
14517 Worklist.append(L->begin(), L->end());
14518 }
14519 for (const auto &KV : ValueExprMap) {
14520#ifndef NDEBUG
14521 // Check for SCEV expressions referencing invalid/deleted loops.
14522 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14523 assert(ValidLoops.contains(AR->getLoop()) &&
14524 "AddRec references invalid loop");
14525 }
14526#endif
14527
14528 // Check that the value is also part of the reverse map.
14529 auto It = ExprValueMap.find(KV.second);
14530 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14531 dbgs() << "Value " << *KV.first
14532 << " is in ValueExprMap but not in ExprValueMap\n";
14533 std::abort();
14534 }
14535
14536 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14537 if (!ReachableBlocks.contains(I->getParent()))
14538 continue;
14539 const SCEV *OldSCEV = SCM.visit(KV.second);
14540 const SCEV *NewSCEV = SE2.getSCEV(I);
14541 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14542 if (Delta && !Delta->isZero()) {
14543 dbgs() << "SCEV for value " << *I << " changed!\n"
14544 << "Old: " << *OldSCEV << "\n"
14545 << "New: " << *NewSCEV << "\n"
14546 << "Delta: " << *Delta << "\n";
14547 std::abort();
14548 }
14549 }
14550 }
14551
14552 for (const auto &KV : ExprValueMap) {
14553 for (Value *V : KV.second) {
14554 const SCEV *S = ValueExprMap.lookup(V);
14555 if (!S) {
14556 dbgs() << "Value " << *V
14557 << " is in ExprValueMap but not in ValueExprMap\n";
14558 std::abort();
14559 }
14560 if (S != KV.first) {
14561 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14562 << *KV.first << "\n";
14563 std::abort();
14564 }
14565 }
14566 }
14567
14568 // Verify integrity of SCEV users.
14569 for (const auto &S : UniqueSCEVs) {
14570 for (const auto *Op : S.operands()) {
14571 // We do not store dependencies of constants.
14572 if (isa<SCEVConstant>(Op))
14573 continue;
14574 auto It = SCEVUsers.find(Op);
14575 if (It != SCEVUsers.end() && It->second.count(&S))
14576 continue;
14577 dbgs() << "Use of operand " << *Op << " by user " << S
14578 << " is not being tracked!\n";
14579 std::abort();
14580 }
14581 }
14582
14583 // Verify integrity of ValuesAtScopes users.
14584 for (const auto &ValueAndVec : ValuesAtScopes) {
14585 const SCEV *Value = ValueAndVec.first;
14586 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14587 const Loop *L = LoopAndValueAtScope.first;
14588 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14589 if (!isa<SCEVConstant>(ValueAtScope)) {
14590 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14591 if (It != ValuesAtScopesUsers.end() &&
14592 is_contained(It->second, std::make_pair(L, Value)))
14593 continue;
14594 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14595 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14596 std::abort();
14597 }
14598 }
14599 }
14600
14601 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14602 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14603 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14604 const Loop *L = LoopAndValue.first;
14605 const SCEV *Value = LoopAndValue.second;
14606 assert(!isa<SCEVConstant>(Value));
14607 auto It = ValuesAtScopes.find(Value);
14608 if (It != ValuesAtScopes.end() &&
14609 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14610 continue;
14611 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14612 << *ValueAtScope << " missing in ValuesAtScopes\n";
14613 std::abort();
14614 }
14615 }
14616
14617 // Verify integrity of BECountUsers.
14618 auto VerifyBECountUsers = [&](bool Predicated) {
14619 auto &BECounts =
14620 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14621 for (const auto &LoopAndBEInfo : BECounts) {
14622 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14623 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14624 if (!isa<SCEVConstant>(S)) {
14625 auto UserIt = BECountUsers.find(S);
14626 if (UserIt != BECountUsers.end() &&
14627 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14628 continue;
14629 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14630 << " missing from BECountUsers\n";
14631 std::abort();
14632 }
14633 }
14634 }
14635 }
14636 };
14637 VerifyBECountUsers(/* Predicated */ false);
14638 VerifyBECountUsers(/* Predicated */ true);
14639
14640 // Verify integrity of the loop disposition cache.
14641 for (auto &[S, Values] : LoopDispositions) {
14642 for (auto [Loop, CachedDisposition] : Values) {
14643 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14644 if (CachedDisposition != RecomputedDisposition) {
14645 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14646 << " is incorrect: cached " << CachedDisposition << ", actual "
14647 << RecomputedDisposition << "\n";
14648 std::abort();
14649 }
14650 }
14651 }
14652
14653 // Verify integrity of the block disposition cache.
14654 for (auto &[S, Values] : BlockDispositions) {
14655 for (auto [BB, CachedDisposition] : Values) {
14656 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14657 if (CachedDisposition != RecomputedDisposition) {
14658 dbgs() << "Cached disposition of " << *S << " for block %"
14659 << BB->getName() << " is incorrect: cached " << CachedDisposition
14660 << ", actual " << RecomputedDisposition << "\n";
14661 std::abort();
14662 }
14663 }
14664 }
14665
14666 // Verify FoldCache/FoldCacheUser caches.
14667 for (auto [FoldID, Expr] : FoldCache) {
14668 auto I = FoldCacheUser.find(Expr);
14669 if (I == FoldCacheUser.end()) {
14670 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14671 << "!\n";
14672 std::abort();
14673 }
14674 if (!is_contained(I->second, FoldID)) {
14675 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14676 std::abort();
14677 }
14678 }
14679 for (auto [Expr, IDs] : FoldCacheUser) {
14680 for (auto &FoldID : IDs) {
14681 const SCEV *S = FoldCache.lookup(FoldID);
14682 if (!S) {
14683 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14684 << "!\n";
14685 std::abort();
14686 }
14687 if (S != Expr) {
14688 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14689 << " != " << *Expr << "!\n";
14690 std::abort();
14691 }
14692 }
14693 }
14694
14695 // Verify that ConstantMultipleCache computations are correct. We check that
14696 // cached multiples and recomputed multiples are multiples of each other to
14697 // verify correctness. It is possible that a recomputed multiple is different
14698 // from the cached multiple due to strengthened no wrap flags or changes in
14699 // KnownBits computations.
14700 for (auto [S, Multiple] : ConstantMultipleCache) {
14701 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14702 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14703 Multiple.urem(RecomputedMultiple) != 0 &&
14704 RecomputedMultiple.urem(Multiple) != 0)) {
14705 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14706 << *S << " : Computed " << RecomputedMultiple
14707 << " but cache contains " << Multiple << "!\n";
14708 std::abort();
14709 }
14710 }
14711}
14712
14713bool ScalarEvolution::invalidate(
14714 Function &F, const PreservedAnalyses &PA,
14715 FunctionAnalysisManager::Invalidator &Inv) {
14716 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14717 // of its dependencies is invalidated.
14718 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14719 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14720 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14721 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14722 Inv.invalidate<LoopAnalysis>(F, PA);
14723}
14724
14725AnalysisKey ScalarEvolutionAnalysis::Key;
14726
14727ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14728 FunctionAnalysisManager &AM) {
14729 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14730 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14731 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14732 auto &LI = AM.getResult<LoopAnalysis>(F);
14733 return ScalarEvolution(F, TLI, AC, DT, LI);
14734}
14735
14736PreservedAnalyses
14737ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14738 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14739 return PreservedAnalyses::all();
14740}
14741
14742PreservedAnalyses
14743ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14744 // For compatibility with opt's -analyze feature under legacy pass manager
14745 // which was not ported to NPM. This keeps tests using
14746 // update_analyze_test_checks.py working.
14747 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14748 << F.getName() << "':\n";
14749 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14750 return PreservedAnalyses::all();
14751}
14752
14753INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14754 "Scalar Evolution Analysis", false, true)
14755INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14756INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14757INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14758INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14759INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14760 "Scalar Evolution Analysis", false, true)
14761
14762char ScalarEvolutionWrapperPass::ID = 0;
14763
14764ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {}
14765
14766bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14767 SE.reset(new ScalarEvolution(
14768 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14769 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14770 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14771 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14772 return false;
14773}
14774
14775void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14776
14777void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14778 SE->print(OS);
14779}
14780
14781void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14782 if (!VerifySCEV)
14783 return;
14784
14785 SE->verify();
14786}
14787
14788void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14789 AU.setPreservesAll();
14790 AU.addRequiredTransitive<AssumptionCacheTracker>();
14791 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14792 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14793 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14794}
14795
14796const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14797 const SCEV *RHS) {
14798 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14799}
14800
14801const SCEVPredicate *
14802ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14803 const SCEV *LHS, const SCEV *RHS) {
14804 FoldingSetNodeID ID;
14805 assert(LHS->getType() == RHS->getType() &&
14806 "Type mismatch between LHS and RHS");
14807 // Unique this node based on the arguments
14808 ID.AddInteger(SCEVPredicate::P_Compare);
14809 ID.AddInteger(Pred);
14810 ID.AddPointer(LHS);
14811 ID.AddPointer(RHS);
14812 void *IP = nullptr;
14813 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14814 return S;
14815 SCEVComparePredicate *Eq = new (SCEVAllocator)
14816 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14817 UniquePreds.InsertNode(Eq, IP);
14818 return Eq;
14819}
14820
14821const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14822 const SCEVAddRecExpr *AR,
14823 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14824 FoldingSetNodeID ID;
14825 // Unique this node based on the arguments
14826 ID.AddInteger(SCEVPredicate::P_Wrap);
14827 ID.AddPointer(AR);
14828 ID.AddInteger(AddedFlags);
14829 void *IP = nullptr;
14830 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14831 return S;
14832 auto *OF = new (SCEVAllocator)
14833 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14834 UniquePreds.InsertNode(OF, IP);
14835 return OF;
14836}
14837
14838namespace {
14839
14840class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14841public:
14842
14843 /// Rewrites \p S in the context of a loop L and the SCEV predication
14844 /// infrastructure.
14845 ///
14846 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14847 /// equivalences present in \p Pred.
14848 ///
14849 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14850 /// \p NewPreds such that the result will be an AddRecExpr.
14851 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14852 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14853 const SCEVPredicate *Pred) {
14854 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14855 return Rewriter.visit(S);
14856 }
14857
14858 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14859 if (Pred) {
14860 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14861 for (const auto *Pred : U->getPredicates())
14862 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14863 if (IPred->getLHS() == Expr &&
14864 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14865 return IPred->getRHS();
14866 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14867 if (IPred->getLHS() == Expr &&
14868 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14869 return IPred->getRHS();
14870 }
14871 }
14872 return convertToAddRecWithPreds(Expr);
14873 }
14874
14875 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14876 const SCEV *Operand = visit(Expr->getOperand());
14877 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14878 if (AR && AR->getLoop() == L && AR->isAffine()) {
14879 // This couldn't be folded because the operand didn't have the nuw
14880 // flag. Add the nusw flag as an assumption that we could make.
14881 const SCEV *Step = AR->getStepRecurrence(SE);
14882 Type *Ty = Expr->getType();
14883 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14884 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14885 SE.getSignExtendExpr(Step, Ty), L,
14886 AR->getNoWrapFlags());
14887 }
14888 return SE.getZeroExtendExpr(Operand, Expr->getType());
14889 }
14890
14891 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14892 const SCEV *Operand = visit(Expr->getOperand());
14893 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14894 if (AR && AR->getLoop() == L && AR->isAffine()) {
14895 // This couldn't be folded because the operand didn't have the nsw
14896 // flag. Add the nssw flag as an assumption that we could make.
14897 const SCEV *Step = AR->getStepRecurrence(SE);
14898 Type *Ty = Expr->getType();
14899 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14900 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14901 SE.getSignExtendExpr(Step, Ty), L,
14902 AR->getNoWrapFlags());
14903 }
14904 return SE.getSignExtendExpr(Operand, Expr->getType());
14905 }
14906
14907private:
14908 explicit SCEVPredicateRewriter(
14909 const Loop *L, ScalarEvolution &SE,
14910 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14911 const SCEVPredicate *Pred)
14912 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14913
14914 bool addOverflowAssumption(const SCEVPredicate *P) {
14915 if (!NewPreds) {
14916 // Check if we've already made this assumption.
14917 return Pred && Pred->implies(P, SE);
14918 }
14919 NewPreds->push_back(P);
14920 return true;
14921 }
14922
14923 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14924 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14925 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14926 return addOverflowAssumption(A);
14927 }
14928
14929 // If \p Expr represents a PHINode, we try to see if it can be represented
14930 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14931 // to add this predicate as a runtime overflow check, we return the AddRec.
14932 // If \p Expr does not meet these conditions (is not a PHI node, or we
14933 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14934 // return \p Expr.
14935 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14936 if (!isa<PHINode>(Expr->getValue()))
14937 return Expr;
14938 std::optional<
14939 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14940 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14941 if (!PredicatedRewrite)
14942 return Expr;
14943 for (const auto *P : PredicatedRewrite->second){
14944 // Wrap predicates from outer loops are not supported.
14945 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14946 if (L != WP->getExpr()->getLoop())
14947 return Expr;
14948 }
14949 if (!addOverflowAssumption(P))
14950 return Expr;
14951 }
14952 return PredicatedRewrite->first;
14953 }
14954
14955 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14956 const SCEVPredicate *Pred;
14957 const Loop *L;
14958};
14959
14960} // end anonymous namespace
14961
14962const SCEV *
14964 const SCEVPredicate &Preds) {
14965 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14966}
14967
14968const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14969 const SCEV *S, const Loop *L,
14970 SmallVectorImpl<const SCEVPredicate *> &Preds) {
14971 SmallVector<const SCEVPredicate *, 4> TransformPreds;
14972 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14973 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14974
14975 if (!AddRec)
14976 return nullptr;
14977
14978 // Check if any of the transformed predicates is known to be false. In that
14979 // case, it doesn't make sense to convert to a predicated AddRec, as the
14980 // versioned loop will never execute.
14981 for (const SCEVPredicate *Pred : TransformPreds) {
14982 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
14983 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
14984 continue;
14985
14986 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
14987 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
14988 if (isa<SCEVCouldNotCompute>(ExitCount))
14989 continue;
14990
14991 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
14992 if (!Step->isOne())
14993 continue;
14994
14995 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
14996 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
14997 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
14998 return nullptr;
14999 }
15000
15001 // Since the transformation was successful, we can now transfer the SCEV
15002 // predicates.
15003 Preds.append(TransformPreds.begin(), TransformPreds.end());
15004
15005 return AddRec;
15006}
15007
15008/// SCEV predicates
15009SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
15010 SCEVPredicateKind Kind)
15011 : FastID(ID), Kind(Kind) {}
15012
15013SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15014 const ICmpInst::Predicate Pred,
15015 const SCEV *LHS, const SCEV *RHS)
15016 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15017 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15018 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15019}
15020
15021bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15022 ScalarEvolution &SE) const {
15023 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15024
15025 if (!Op)
15026 return false;
15027
15028 if (Pred != ICmpInst::ICMP_EQ)
15029 return false;
15030
15031 return Op->LHS == LHS && Op->RHS == RHS;
15032}
15033
15034bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15035
15036void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15037 if (Pred == ICmpInst::ICMP_EQ)
15038 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15039 else
15040 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15041 << *RHS << "\n";
15042
15043}
15044
15045SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15046 const SCEVAddRecExpr *AR,
15047 IncrementWrapFlags Flags)
15048 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15049
15050const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15051
15052bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15053 ScalarEvolution &SE) const {
15054 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15055 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15056 return false;
15057
15058 if (Op->AR == AR)
15059 return true;
15060
15061 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15062 Flags != SCEVWrapPredicate::IncrementNUSW)
15063 return false;
15064
15065 const SCEV *Start = AR->getStart();
15066 const SCEV *OpStart = Op->AR->getStart();
15067 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15068 return false;
15069
15070 // Reject pointers to different address spaces.
15071 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15072 return false;
15073
15074 const SCEV *Step = AR->getStepRecurrence(SE);
15075 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15076 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15077 return false;
15078
15079 // If both steps are positive, this predicate implies N if N's start and
15080 // step are ULE/SLE (for NUSW/NSSW) than this predicate's start and step.
15081 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15082 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15083 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15084
15085 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15086 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15087 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15088 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15089 : SE.getNoopOrSignExtend(Start, WiderTy);
15090 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15091 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15092 SE.isKnownPredicate(Pred, OpStart, Start);
15093}
15094
15095bool SCEVWrapPredicate::isAlwaysTrue() const {
15096 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15097 IncrementWrapFlags IFlags = Flags;
15098
15099 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15100 IFlags = clearFlags(IFlags, IncrementNSSW);
15101
15102 return IFlags == IncrementAnyWrap;
15103}
15104
15105void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15106 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15107 if (SCEVWrapPredicate::IncrementNUSW == setFlags(Flags, IncrementNUSW))
15108 OS << "<nusw>";
15109 if (SCEVWrapPredicate::IncrementNSSW == setFlags(Flags, IncrementNSSW))
15110 OS << "<nssw>";
15111 OS << "\n";
15112}
15113
15114SCEVWrapPredicate::IncrementWrapFlags
15115SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15116 ScalarEvolution &SE) {
15117 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15118 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15119
15120 // We can safely transfer the NSW flag as NSSW.
15121 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15122 ImpliedFlags = IncrementNSSW;
15123
15124 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15125 // If the increment is positive, the SCEV NUW flag will also imply the
15126 // WrapPredicate NUSW flag.
15127 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15128 if (Step->getValue()->getValue().isNonNegative())
15129 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15130 }
15131
15132 return ImpliedFlags;
15133}
15134
15135/// Union predicates don't get cached so create a dummy set ID for it.
15136SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15137 ScalarEvolution &SE)
15138 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15139 for (const auto *P : Preds)
15140 add(P, SE);
15141}
15142
15143bool SCEVUnionPredicate::isAlwaysTrue() const {
15144 return all_of(Preds,
15145 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15146}
15147
15148bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15149 ScalarEvolution &SE) const {
15150 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15151 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15152 return this->implies(I, SE);
15153 });
15154
15155 return any_of(Preds,
15156 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15157}
15158
15159void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15160 for (const auto *Pred : Preds)
15161 Pred->print(OS, Depth);
15162}
15163
15164void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15165 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15166 for (const auto *Pred : Set->Preds)
15167 add(Pred, SE);
15168 return;
15169 }
15170
15171 // Only add predicate if it is not already implied by this union predicate.
15172 if (implies(N, SE))
15173 return;
15174
15175 // Build a new vector containing the current predicates, except the ones that
15176 // are implied by the new predicate N.
15177 SmallVector<const SCEVPredicate *> PrunedPreds;
15178 for (auto *P : Preds) {
15179 if (N->implies(P, SE))
15180 continue;
15181 PrunedPreds.push_back(P);
15182 }
15183 Preds = std::move(PrunedPreds);
15184 Preds.push_back(N);
15185}
15186
15187PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15188 Loop &L)
15189 : SE(SE), L(L) {
15190 SmallVector<const SCEVPredicate *, 4> Empty;
15191 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15192}
15193
15194void ScalarEvolution::registerUser(const SCEV *User,
15195 ArrayRef<const SCEV *> Ops) {
15196 for (const auto *Op : Ops)
15197 // We do not expect that forgetting cached data for SCEVConstants will ever
15198 // open any prospects for sharpening or introduce any correctness issues,
15199 // so we don't bother storing their dependencies.
15200 if (!isa<SCEVConstant>(Op))
15201 SCEVUsers[Op].insert(User);
15202}
15203
15204const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15205 const SCEV *Expr = SE.getSCEV(V);
15206 RewriteEntry &Entry = RewriteMap[Expr];
15207
15208 // If we already have an entry and the version matches, return it.
15209 if (Entry.second && Generation == Entry.first)
15210 return Entry.second;
15211
15212 // We found an entry but it's stale. Rewrite the stale entry
15213 // according to the current predicate.
15214 if (Entry.second)
15215 Expr = Entry.second;
15216
15217 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15218 Entry = {Generation, NewSCEV};
15219
15220 return NewSCEV;
15221}
15222
15223const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15224 if (!BackedgeCount) {
15225 SmallVector<const SCEVPredicate *, 4> Preds;
15226 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15227 for (const auto *P : Preds)
15228 addPredicate(*P);
15229 }
15230 return BackedgeCount;
15231}
15232
15233const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15234 if (!SymbolicMaxBackedgeCount) {
15235 SmallVector<const SCEVPredicate *, 4> Preds;
15236 SymbolicMaxBackedgeCount =
15237 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15238 for (const auto *P : Preds)
15239 addPredicate(*P);
15240 }
15241 return SymbolicMaxBackedgeCount;
15242}
15243
15244unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15245 if (!SmallConstantMaxTripCount) {
15246 SmallVector<const SCEVPredicate *, 4> Preds;
15247 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15248 for (const auto *P : Preds)
15249 addPredicate(*P);
15250 }
15251 return *SmallConstantMaxTripCount;
15252}
15253
15254void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15255 if (Preds->implies(&Pred, SE))
15256 return;
15257
15258 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15259 NewPreds.push_back(&Pred);
15260 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15261 updateGeneration();
15262}
15263
15264const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15265 return *Preds;
15266}
15267
15268void PredicatedScalarEvolution::updateGeneration() {
15269 // If the generation number wrapped recompute everything.
15270 if (++Generation == 0) {
15271 for (auto &II : RewriteMap) {
15272 const SCEV *Rewritten = II.second.second;
15273 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15274 }
15275 }
15276}
15277
15278void PredicatedScalarEvolution::setNoOverflow(
15279 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15280 const SCEV *Expr = getSCEV(V);
15281 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15282
15283 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15284
15285 // Clear the statically implied flags.
15286 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15287 addPredicate(*SE.getWrapPredicate(AR, Flags));
15288
15289 auto II = FlagsMap.insert({V, Flags});
15290 if (!II.second)
15291 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15292}
15293
15294bool PredicatedScalarEvolution::hasNoOverflow(
15295 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15296 const SCEV *Expr = getSCEV(V);
15297 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15298
15299 Flags = SCEVWrapPredicate::clearFlags(
15300 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15301
15302 auto II = FlagsMap.find(V);
15303
15304 if (II != FlagsMap.end())
15305 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15306
15307 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15308}
15309
15310const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15311 const SCEV *Expr = this->getSCEV(V);
15312 SmallVector<const SCEVPredicate *, 4> NewPreds;
15313 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15314
15315 if (!New)
15316 return nullptr;
15317
15318 for (const auto *P : NewPreds)
15319 addPredicate(*P);
15320
15321 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15322 return New;
15323}
15324
15325PredicatedScalarEvolution::PredicatedScalarEvolution(
15326 const PredicatedScalarEvolution &Init)
15327 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15328 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15329 SE)),
15330 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15331 for (auto I : Init.FlagsMap)
15332 FlagsMap.insert(I);
15333}
15334
15335void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15336 // For each block.
15337 for (auto *BB : L.getBlocks())
15338 for (auto &I : *BB) {
15339 if (!SE.isSCEVable(I.getType()))
15340 continue;
15341
15342 auto *Expr = SE.getSCEV(&I);
15343 auto II = RewriteMap.find(Expr);
15344
15345 if (II == RewriteMap.end())
15346 continue;
15347
15348 // Don't print things that are not interesting.
15349 if (II->second.second == Expr)
15350 continue;
15351
15352 OS.indent(Depth) << "[PSE]" << I << ":\n";
15353 OS.indent(Depth + 2) << *Expr << "\n";
15354 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15355 }
15356}
15357
15358// Match the mathematical pattern A - (A / B) * B, where A and B can be
15359// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15360// for URem with constant power-of-2 second operands.
15361// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15362// 4, A / B becomes X / 8).
15363bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15364 const SCEV *&RHS) {
15365 if (Expr->getType()->isPointerTy())
15366 return false;
15367
15368 // Try to match 'zext (trunc A to iB) to iY', which is used
15369 // for URem with constant power-of-2 second operands. Make sure the size of
15370 // the operand A matches the size of the whole expressions.
15371 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15372 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15373 LHS = Trunc->getOperand();
15374 // Bail out if the type of the LHS is larger than the type of the
15375 // expression for now.
15376 if (getTypeSizeInBits(LHS->getType()) >
15377 getTypeSizeInBits(Expr->getType()))
15378 return false;
15379 if (LHS->getType() != Expr->getType())
15380 LHS = getZeroExtendExpr(LHS, Expr->getType());
15381 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15382 << getTypeSizeInBits(Trunc->getType()));
15383 return true;
15384 }
15385 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15386 if (Add == nullptr || Add->getNumOperands() != 2)
15387 return false;
15388
15389 const SCEV *A = Add->getOperand(1);
15390 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15391
15392 if (Mul == nullptr)
15393 return false;
15394
15395 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15396 // (SomeExpr + (-(SomeExpr / B) * B)).
15397 if (Expr == getURemExpr(A, B)) {
15398 LHS = A;
15399 RHS = B;
15400 return true;
15401 }
15402 return false;
15403 };
15404
15405 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15406 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15407 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15408 MatchURemWithDivisor(Mul->getOperand(2));
15409
15410 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15411 if (Mul->getNumOperands() == 2)
15412 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15413 MatchURemWithDivisor(Mul->getOperand(0)) ||
15414 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15415 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15416 return false;
15417}
15418
15419ScalarEvolution::LoopGuards
15420ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15421 BasicBlock *Header = L->getHeader();
15422 BasicBlock *Pred = L->getLoopPredecessor();
15423 LoopGuards Guards(SE);
15424 if (!Pred)
15425 return Guards;
15426 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15427 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15428 return Guards;
15429}
15430
15431void ScalarEvolution::LoopGuards::collectFromPHI(
15432 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15433 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15434 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15435 unsigned Depth) {
15436 if (!SE.isSCEVable(Phi.getType()))
15437 return;
15438
15439 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
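 // For one incoming edge, collect the guards holding on the incoming block
 // and look for a rewrite of the incoming value into a min/max expression
 // with a constant first operand; return that (constant, kind) pair.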
15440 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15441 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15442 if (!VisitedBlocks.insert(InBlock).second)
15443 return {nullptr, scCouldNotCompute};
15444 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15445 if (Inserted)
15446 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15447 Depth + 1);
15448 auto &RewriteMap = G->second.RewriteMap;
15449 if (RewriteMap.empty())
15450 return {nullptr, scCouldNotCompute};
15451 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15452 if (S == RewriteMap.end())
15453 return {nullptr, scCouldNotCompute};
15454 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15455 if (!SM)
15456 return {nullptr, scCouldNotCompute};
15457 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15458 return {C0, SM->getSCEVType()};
15459 return {nullptr, scCouldNotCompute};
15460 };
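 // Merge the patterns found on two incoming edges, keeping the weaker bound
 // (e.g. the smaller constant for a umax guard); give up if the kinds differ.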
15461 auto MergeMinMaxConst = [](MinMaxPattern P1,
15462 MinMaxPattern P2) -> MinMaxPattern {
15463 auto [C1, T1] = P1;
15464 auto [C2, T2] = P2;
15465 if (!C1 || !C2 || T1 != T2)
15466 return {nullptr, scCouldNotCompute};
15467 switch (T1) {
15468 case scUMaxExpr:
15469 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15470 case scSMaxExpr:
15471 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15472 case scUMinExpr:
15473 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15474 case scSMinExpr:
15475 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15476 default:
15477 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15478 }
15479 };
15480 auto P = GetMinMaxConst(0);
15481 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15482 if (!P.first)
15483 break;
15484 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15485 }
15486 if (P.first) {
15487 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15488 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15489 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15490 Guards.RewriteMap.insert({LHS, RHS});
15491 }
15492}
15493
15494void ScalarEvolution::LoopGuards::collectFromBlock(
15495 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15496 const BasicBlock *Block, const BasicBlock *Pred,
15497 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15498 SmallVector<const SCEV *> ExprsToRewrite;
15499 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15500 const SCEV *RHS,
15501 DenseMap<const SCEV *, const SCEV *>
15502 &RewriteMap) {
15503 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15504 // replacement SCEV which isn't directly implied by the structure of that
15505 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15506 // legal. See the scoping rules for flags in the header to understand why.
15507
15508 // If LHS is a constant, apply information to the other expression.
15509 if (isa<SCEVConstant>(LHS)) {
15510 std::swap(LHS, RHS);
15511 Predicate = CmpInst::getSwappedPredicate(Predicate);
15512 }
15513
15514 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15515 // create this form when combining two checks of the form (X u< C2 + C1) and
15516 // (X >=u C1).
15517 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15518 &ExprsToRewrite]() {
15519 const SCEVConstant *C1;
15520 const SCEVUnknown *LHSUnknown;
15521 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15522 if (!match(LHS,
15523 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15524 !C2)
15525 return false;
15526
15527 auto ExactRegion =
15528 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15529 .sub(C1->getAPInt());
15530
15531 // Bail out, unless we have a non-wrapping, monotonic range.
15532 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15533 return false;
15534 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15535 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15536 I->second = SE.getUMaxExpr(
15537 SE.getConstant(ExactRegion.getUnsignedMin()),
15538 SE.getUMinExpr(RewrittenLHS,
15539 SE.getConstant(ExactRegion.getUnsignedMax())));
15540 ExprsToRewrite.push_back(LHSUnknown);
15541 return true;
15542 };
15543 if (MatchRangeCheckIdiom())
15544 return;
15545
15546 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15547 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15548 // the non-constant operand and in \p LHS the constant operand.
15549 auto IsMinMaxSCEVWithNonNegativeConstant =
15550 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15551 const SCEV *&RHS) {
15552 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15553 if (MinMax->getNumOperands() != 2)
15554 return false;
15555 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15556 if (C->getAPInt().isNegative())
15557 return false;
15558 SCTy = MinMax->getSCEVType();
15559 LHS = MinMax->getOperand(0);
15560 RHS = MinMax->getOperand(1);
15561 return true;
15562 }
15563 }
15564 return false;
15565 };
15566
15567 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15568 // constant, and returns their APInt in ExprVal and in DivisorVal.
15569 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15570 APInt &ExprVal, APInt &DivisorVal) {
15571 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15572 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15573 if (!ConstExpr || !ConstDivisor)
15574 return false;
15575 ExprVal = ConstExpr->getAPInt();
15576 DivisorVal = ConstDivisor->getAPInt();
15577 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15578 };
15579
15580 // Return a new SCEV that modifies \p Expr to the closest number that is
15581 // divisible by \p Divisor and greater than or equal to Expr.
15582 // For now, only handle constant Expr and Divisor.
15583 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15584 const SCEV *Divisor) {
15585 APInt ExprVal;
15586 APInt DivisorVal;
15587 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15588 return Expr;
15589 APInt Rem = ExprVal.urem(DivisorVal);
15590 if (!Rem.isZero())
15591 // return the SCEV: Expr + Divisor - Expr % Divisor
15592 return SE.getConstant(ExprVal + DivisorVal - Rem);
15593 return Expr;
15594 };
15595
15596 // Return a new SCEV that modifies \p Expr to the closest number that is
15597 // divisible by \p Divisor and less than or equal to Expr.
15598 // For now, only handle constant Expr and Divisor.
15599 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15600 const SCEV *Divisor) {
15601 APInt ExprVal;
15602 APInt DivisorVal;
15603 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15604 return Expr;
15605 APInt Rem = ExprVal.urem(DivisorVal);
15606 // return the SCEV: Expr - Expr % Divisor
15607 return SE.getConstant(ExprVal - Rem);
15608 };
15609
15610 // Recursively apply divisibility by \p Divisor to a MinMaxExpr with
15611 // constant operands. This is done by aligning the constant value up/down
15612 // to the Divisor.
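// For example, with Divisor = 8, umin(30, %a) becomes umin(24, %a) (constant
// aligned down) and umax(30, %a) becomes umax(32, %a) (constant aligned up).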
15613 std::function<const SCEV *(const SCEV *, const SCEV *)>
15614 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15615 const SCEV *Divisor) {
15616 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15617 SCEVTypes SCTy;
15618 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15619 MinMaxRHS))
15620 return MinMaxExpr;
15621 auto IsMin =
15622 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15623 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15624 "Expected non-negative operand!");
15625 auto *DivisibleExpr =
15626 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15627 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15628 SmallVector<const SCEV *> Ops = {
15629 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15630 return SE.getMinMaxExpr(SCTy, Ops);
15631 };
15632
15633 // If we have LHS == 0, check if LHS is computing a property of some unknown
15634 // SCEV %v, in which case we rewrite %v to express that property explicitly.
15635 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15636 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15637 // explicitly express that.
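// For example, a dominating guard (%a urem 8) == 0 maps %a to (%a /u 8) * 8
// in the rewrite map.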
15638 const SCEV *URemLHS = nullptr;
15639 const SCEV *URemRHS = nullptr;
15640 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15641 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15642 auto I = RewriteMap.find(LHSUnknown);
15643 const SCEV *RewrittenLHS =
15644 I != RewriteMap.end() ? I->second : LHSUnknown;
15645 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15646 const auto *Multiple =
15647 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15648 RewriteMap[LHSUnknown] = Multiple;
15649 ExprsToRewrite.push_back(LHSUnknown);
15650 return;
15651 }
15652 }
15653 }
15654
15655 // Do not apply information for constants or if RHS contains an AddRec.
15656 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15657 return;
15658
15659 // If RHS is SCEVUnknown, make sure the information is applied to it.
15660 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15661 std::swap(LHS, RHS);
15662 Predicate = CmpInst::getSwappedPredicate(Predicate);
15663 }
15664
15665 // Puts the rewrite rule \p From -> \p To into the rewrite map. Also, if
15666 // \p From and \p FromRewritten are the same (i.e. no rewrite has been
15667 // registered for \p From yet), adds \p From to the list of rewritten
15668 // expressions.
15669 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15670 const SCEV *To) {
15671 if (From == FromRewritten)
15672 ExprsToRewrite.push_back(From);
15673 RewriteMap[From] = To;
15674 };
15675
15676 // Checks whether \p S has already been rewritten. In that case returns the
15677 // existing rewrite because we want to chain further rewrites onto the
15678 // already rewritten value. Otherwise returns \p S.
15679 auto GetMaybeRewritten = [&](const SCEV *S) {
15680 return RewriteMap.lookup_or(S, S);
15681 };
15682
15683 // Check for the SCEV expression (A /u B) * B, where B is a constant, inside
15684 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15685 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15686 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15687 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15688 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15689 // DividesBy.
15690 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15691 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15692 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15693 if (Mul->getNumOperands() != 2)
15694 return false;
15695 auto *MulLHS = Mul->getOperand(0);
15696 auto *MulRHS = Mul->getOperand(1);
15697 if (isa<SCEVConstant>(MulLHS))
15698 std::swap(MulLHS, MulRHS);
15699 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15700 if (Div->getOperand(1) == MulRHS) {
15701 DividesBy = MulRHS;
15702 return true;
15703 }
15704 }
15705 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15706 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15707 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15708 return false;
15709 };
15710
15711 // Return true if \p Expr is known to be divisible by \p DividesBy.
15712 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15713 [&](const SCEV *Expr, const SCEV *DividesBy) {
15714 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15715 return true;
15716 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15717 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15718 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15719 return false;
15720 };
15721
15722 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15723 const SCEV *DividesBy = nullptr;
15724 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15725 // Check that the whole expression is divided by DividesBy
15726 DividesBy =
15727 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15728
15729 // Collect rewrites for LHS and its transitive operands based on the
15730 // condition.
15731 // For min/max expressions, also apply the guard to its operands:
15732 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15733 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15734 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15735 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15736
15737 // We cannot express strict predicates in SCEV, so instead we replace them
15738 // with non-strict ones against plus or minus one of RHS depending on the
15739 // predicate.
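// For example, %a u< %b is handled as %a u<= (%b - 1), with %b first clamped
// to umax(%b, 1) so the subtraction cannot wrap.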
15740 const SCEV *One = SE.getOne(RHS->getType());
15741 switch (Predicate) {
15742 case CmpInst::ICMP_ULT:
15743 if (RHS->getType()->isPointerTy())
15744 return;
15745 RHS = SE.getUMaxExpr(RHS, One);
15746 [[fallthrough]];
15747 case CmpInst::ICMP_SLT: {
15748 RHS = SE.getMinusSCEV(RHS, One);
15749 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15750 break;
15751 }
15752 case CmpInst::ICMP_UGT:
15753 case CmpInst::ICMP_SGT:
15754 RHS = SE.getAddExpr(RHS, One);
15755 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15756 break;
15757 case CmpInst::ICMP_ULE:
15758 case CmpInst::ICMP_SLE:
15759 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15760 break;
15761 case CmpInst::ICMP_UGE:
15762 case CmpInst::ICMP_SGE:
15763 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15764 break;
15765 default:
15766 break;
15767 }
15768
15769 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15770 SmallPtrSet<const SCEV *, 16> Visited;
15771
15772 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15773 append_range(Worklist, S->operands());
15774 };
15775
15776 while (!Worklist.empty()) {
15777 const SCEV *From = Worklist.pop_back_val();
15778 if (isa<SCEVConstant>(From))
15779 continue;
15780 if (!Visited.insert(From).second)
15781 continue;
15782 const SCEV *FromRewritten = GetMaybeRewritten(From);
15783 const SCEV *To = nullptr;
15784
15785 switch (Predicate) {
15786 case CmpInst::ICMP_ULT:
15787 case CmpInst::ICMP_ULE:
15788 To = SE.getUMinExpr(FromRewritten, RHS);
15789 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15790 EnqueueOperands(UMax);
15791 break;
15792 case CmpInst::ICMP_SLT:
15793 case CmpInst::ICMP_SLE:
15794 To = SE.getSMinExpr(FromRewritten, RHS);
15795 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15796 EnqueueOperands(SMax);
15797 break;
15798 case CmpInst::ICMP_UGT:
15799 case CmpInst::ICMP_UGE:
15800 To = SE.getUMaxExpr(FromRewritten, RHS);
15801 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15802 EnqueueOperands(UMin);
15803 break;
15804 case CmpInst::ICMP_SGT:
15805 case CmpInst::ICMP_SGE:
15806 To = SE.getSMaxExpr(FromRewritten, RHS);
15807 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15808 EnqueueOperands(SMin);
15809 break;
15810 case CmpInst::ICMP_EQ:
15811 if (isa<SCEVConstant>(RHS))
15812 To = RHS;
15813 break;
15814 case CmpInst::ICMP_NE:
15815 if (match(RHS, m_scev_Zero())) {
15816 const SCEV *OneAlignedUp =
15817 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15818 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15819 }
15820 break;
15821 default:
15822 break;
15823 }
15824
15825 if (To)
15826 AddRewrite(From, FromRewritten, To);
15827 }
15828 };
15829
15830 SmallVector<std::pair<Value *, bool>> Terms;
15831 // First, collect information from assumptions dominating the loop.
15832 for (auto &AssumeVH : SE.AC.assumptions()) {
15833 if (!AssumeVH)
15834 continue;
15835 auto *AssumeI = cast<CallInst>(AssumeVH);
15836 if (!SE.DT.dominates(AssumeI, Block))
15837 continue;
15838 Terms.emplace_back(AssumeI->getOperand(0), true);
15839 }
15840
15841 // Second, collect information from llvm.experimental.guards dominating the loop.
15842 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15843 SE.F.getParent(), Intrinsic::experimental_guard);
15844 if (GuardDecl)
15845 for (const auto *GU : GuardDecl->users())
15846 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15847 if (Guard->getFunction() == Block->getParent() &&
15848 SE.DT.dominates(Guard, Block))
15849 Terms.emplace_back(Guard->getArgOperand(0), true);
15850
15851 // Third, collect conditions from dominating branches. Starting at the loop
15852 // predecessor, climb up the predecessor chain as long as we can find
15853 // predecessors that have a unique successor leading to the
15854 // original header.
15855 // TODO: share this logic with isLoopEntryGuardedByCond.
15856 unsigned NumCollectedConditions = 0;
15857 VisitedBlocks.insert(Block);
15858 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15859 for (; Pair.first;
15860 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15861 VisitedBlocks.insert(Pair.second);
15862 const BranchInst *LoopEntryPredicate =
15863 dyn_cast<BranchInst>(Pair.first->getTerminator());
15864 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15865 continue;
15866
15867 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15868 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15869 NumCollectedConditions++;
15870
15871 // If we are recursively collecting guards, stop after 2
15872 // conditions to limit compile-time impact for now.
15873 if (Depth > 0 && NumCollectedConditions == 2)
15874 break;
15875 }
15876 // Finally, if we stopped climbing the predecessor chain because
15877 // there wasn't a unique one to continue, try to collect conditions
15878 // for PHINodes by recursively following all of their incoming
15879 // blocks and trying to merge the found conditions into a new one
15880 // for the Phi.
15881 if (Pair.second->hasNPredecessorsOrMore(2) &&
15882 Depth < MaxLoopGuardCollectionDepth) {
15883 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15884 for (auto &Phi : Pair.second->phis())
15885 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15886 }
15887
15888 // Now apply the information from the collected conditions to
15889 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15890 // earliest condition is processed first. This ensures the SCEVs with the
15891 // shortest dependency chains are constructed first.
15892 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15893 SmallVector<Value *, 8> Worklist;
15894 SmallPtrSet<Value *, 8> Visited;
15895 Worklist.push_back(Term);
15896 while (!Worklist.empty()) {
15897 Value *Cond = Worklist.pop_back_val();
15898 if (!Visited.insert(Cond).second)
15899 continue;
15900
15901 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15902 auto Predicate =
15903 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15904 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15905 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15906 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15907 continue;
15908 }
15909
15910 Value *L, *R;
15911 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15912 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15913 Worklist.push_back(L);
15914 Worklist.push_back(R);
15915 }
15916 }
15917 }
15918
15919 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15920 // the replacement expressions are contained in the ranges of the replaced
15921 // expressions.
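// For example, rewriting %x (unconstrained range) to umin(%x, 63) does not
// block flag preservation, since the replacement's range [0, 63] is contained
// in the full range of %x.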
15922 Guards.PreserveNUW = true;
15923 Guards.PreserveNSW = true;
15924 for (const SCEV *Expr : ExprsToRewrite) {
15925 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15926 Guards.PreserveNUW &=
15927 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15928 Guards.PreserveNSW &=
15929 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15930 }
15931
15932 // Now that all rewrite information is collected, rewrite the collected
15933 // expressions with the information in the map. This applies information to
15934 // sub-expressions.
15935 if (ExprsToRewrite.size() > 1) {
15936 for (const SCEV *Expr : ExprsToRewrite) {
15937 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15938 Guards.RewriteMap.erase(Expr);
15939 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15940 }
15941 }
15942}
15943
15944 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15945 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15946 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15947 /// replacement is loop invariant in the loop of the AddRec.
15948 class SCEVLoopGuardRewriter
15949 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15950 const DenseMap<const SCEV *, const SCEV *> &Map;
15951
15952 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15953
15954 public:
15955 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15956 const ScalarEvolution::LoopGuards &Guards)
15957 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15958 if (Guards.PreserveNUW)
15959 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15960 if (Guards.PreserveNSW)
15961 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15962 }
15963
15964 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15965
15966 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15967 return Map.lookup_or(Expr, Expr);
15968 }
15969
15970 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15971 if (const SCEV *S = Map.lookup(Expr))
15972 return S;
15973
15974 // If we didn't find the exact ZExt expr in the map, check if there's
15975 // an entry for a smaller ZExt we can use instead.
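// For example, when visiting (zext i8 %v to i64) and the map only contains an
// entry for (zext i8 %v to i32), that entry is used and its result is
// zero-extended to i64.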
15976 Type *Ty = Expr->getType();
15977 const SCEV *Op = Expr->getOperand(0);
15978 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15979 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15980 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15981 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15982 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15983 if (const SCEV *S = Map.lookup(NarrowExt))
15984 return SE.getZeroExtendExpr(S, Ty);
15985 Bitwidth = Bitwidth / 2;
15986 }
15987
15988 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15989 Expr);
15990 }
15991
15992 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15993 if (const SCEV *S = Map.lookup(Expr))
15994 return S;
15995 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15996 Expr);
15997 }
15998
15999 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16000 if (const SCEV *S = Map.lookup(Expr))
16001 return S;
16002 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16003 }
16004
16005 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16006 if (const SCEV *S = Map.lookup(Expr))
16007 return S;
16008 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16009 }
16010
16011 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16012 SmallVector<const SCEV *, 2> Operands;
16013 bool Changed = false;
16014 for (const auto *Op : Expr->operands()) {
16015 Operands.push_back(
16016 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16017 Changed |= Op != Operands.back();
16018 }
16019 // We are only replacing operands with equivalent values, so transfer the
16020 // flags from the original expression.
16021 return !Changed ? Expr
16022 : SE.getAddExpr(Operands,
16023 ScalarEvolution::maskFlags(
16024 Expr->getNoWrapFlags(), FlagMask));
16025 }
16026
16027 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16028 SmallVector<const SCEV *, 2> Operands;
16029 bool Changed = false;
16030 for (const auto *Op : Expr->operands()) {
16031 Operands.push_back(
16032 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16033 Changed |= Op != Operands.back();
16034 }
16035 // We are only replacing operands with equivalent values, so transfer the
16036 // flags from the original expression.
16037 return !Changed ? Expr
16038 : SE.getMulExpr(Operands,
16039 ScalarEvolution::maskFlags(
16040 Expr->getNoWrapFlags(), FlagMask));
16041 }
16042 };
16043
16044 if (RewriteMap.empty())
16045 return Expr;
16046
16047 SCEVLoopGuardRewriter Rewriter(SE, *this);
16048 return Rewriter.visit(Expr);
16049}
16050
16051const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16052 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16053}
16054
16055 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16056 const LoopGuards &Guards) {
16057 return Guards.rewrite(Expr);
16058}
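// Editorial usage sketch (not part of this file): a client that wants a
// guard-sharpened backedge-taken count could write, assuming ScalarEvolution
// &SE and Loop *L are in scope:
//   const SCEV *BTC = SE.getBackedgeTakenCount(L);
//   ScalarEvolution::LoopGuards Guards = ScalarEvolution::LoopGuards::collect(L, SE);
//   const SCEV *SharpenedBTC = SE.applyLoopGuards(BTC, Guards);
// Collecting the guards once and reusing them is preferable when several
// expressions from the same loop need to be rewritten.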
@ Poison
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:546
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition: Debug.h:119
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:39
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:247
Virtual Register Rewriter
Definition: VirtRegMap.cpp:269
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:83
Class for arbitrary precision integers.
Definition: APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1971
LLVM_ABI APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1573
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1540
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1391
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1512
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:206
APInt abs() const
Get the absolute value.
Definition: APInt.h:1795
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition: APInt.h:1201
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:466
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1488
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:219
unsigned countTrailingZeros() const
Definition: APInt.h:1647
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:356
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:827
LLVM_ABI APInt multiplicativeInverse() const
Definition: APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1150
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:873
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1221
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:50
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:294
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:312
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:255
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:412
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:224
iterator end() const
Definition: ArrayRef.h:136
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:147
iterator begin() const
Definition: ArrayRef.h:135
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:52
LLVM Basic Block Representation.
Definition: BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:459
const Instruction & front() const
Definition: BasicBlock.h:482
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:437
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:213
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:233
Value * getRHS() const
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:149
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:384
void setValPtr(Value *P)
Definition: ValueHandle.h:391
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:950
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:678
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:707
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:708
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:702
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:701
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:705
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:703
@ ICMP_EQ
equal
Definition: InstrTypes.h:699
@ ICMP_NE
not equal
Definition: InstrTypes.h:700
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:706
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:704
bool isSigned() const
Definition: InstrTypes.h:932
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:829
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:944
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:791
bool isUnsigned() const
Definition: InstrTypes.h:938
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:928
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
Definition: CmpPredicate.h:23
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Definition: CmpPredicate.h:47
static LLVM_ABI Constant * getNot(Constant *C)
Definition: Constants.cpp:2641
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2300
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1274
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2647
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2635
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2272
This is the shared class of boolean and integer constants.
Definition: Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:214
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:875
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:163
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:154
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:882
This class represents a range of values.
Definition: ConstantRange.h:47
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
LLVM_ABI ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:43
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:708
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:850
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
Definition: DataLayout.cpp:753
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:877
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:674
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:203
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:177
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition: DenseMap.h:245
bool erase(const KeyT &Val)
Definition: DenseMap.h:319
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:185
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:173
iterator end()
Definition: DenseMap.h:87
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:168
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:230
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:284
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:165
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:334
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:135
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:293
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:330
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:314
const BasicBlock & getEntryBlock() const
Definition: Function.h:807
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:727
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:663
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:408
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:405
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
Definition: DerivedTypes.h:42
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:319
An instruction for reading from memory.
Definition: Instructions.h:180
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:570
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:597
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:40
bool isLoopInvariant(const Value *V, bool HasCoroSuspendInst=false) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1077
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:33
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:720
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1885
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
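A minimal sketch, not taken from this file, of how the PredicatedScalarEvolution interface above could be driven; the helper name analyzeWithPredicates and the assumption that SE, L, and IV come from the surrounding pass are illustrative only.
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void analyzeWithPredicates(ScalarEvolution &SE, Loop &L, Value *IV) {
  PredicatedScalarEvolution PSE(SE, L);
  // Try to view IV as an add recurrence, collecting SCEV predicates if needed.
  if (const SCEVAddRecExpr *AR = PSE.getAsAddRec(IV)) {
    // The predicated backedge-taken count may be computable even when the
    // unpredicated one is not.
    const SCEV *BTC = PSE.getBackedgeTakenCount();
    (void)AR;
    (void)BTC;
  }
}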
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:275
constexpr bool isValid() const
Definition: Register.h:107
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
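A minimal sketch, assuming AR and SE are supplied by the caller, of the add-recurrence queries listed above.
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static void inspectAddRec(const SCEVAddRecExpr *AR, ScalarEvolution &SE) {
  if (!AR->isAffine())
    return; // Only handle {Start,+,Step} recurrences here.
  const SCEV *Step = AR->getStepRecurrence(SE);            // Per-iteration increment.
  const SCEVAddRecExpr *PostInc = AR->getPostIncExpr(SE);  // One iteration ahead.
  // Value of the recurrence at iteration 7 of its loop.
  const SCEV *AtIt =
      AR->evaluateAtIteration(SE.getConstant(AR->getType(), 7), SE);
  (void)Step; (void)PostInc; (void)AtIt;
}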
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class for min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and rewrites it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached, so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
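A minimal sketch, with L, SE, and Expr assumed to come from the caller, of collecting loop guards once and reusing them; applyLoopGuards (further down) is the convenience wrapper for a single rewrite.
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static const SCEV *tightenWithGuards(const SCEV *Expr, const Loop *L,
                                     ScalarEvolution &SE) {
  auto Guards = ScalarEvolution::LoopGuards::collect(L, SE);
  // May fold in bounds implied by branches guarding the loop entry.
  return Guards.rewrite(Expr);
}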
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
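A minimal sketch (SE and L assumed given) of querying the backedge-taken count and converting it to a trip count; reportTripCount is a hypothetical helper.
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void reportTripCount(ScalarEvolution &SE, const Loop *L) {
  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(BTC))
    return; // No analyzable count for this loop.
  // Trip count = backedge-taken count + 1, in a type that cannot overflow.
  const SCEV *TripCount = SE.getTripCountFromExitCount(BTC);
  (void)TripCount;
}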
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is the operation BinOp between LHS and RHS provably known not to have a signed/unsigned overflow (Signed)?...
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
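A minimal sketch (SE, A, and B assumed given) contrasting two of the predicate queries listed here: isKnownPredicate answers "provably true?", while evaluatePredicate can also report "provably false".
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
#include <optional>
using namespace llvm;

static bool isProvablyULT(ScalarEvolution &SE, const SCEV *A, const SCEV *B) {
  if (SE.isKnownPredicate(ICmpInst::ICMP_ULT, A, B))
    return true; // Holds on every execution we can reason about.
  if (std::optional<bool> Res = SE.evaluatePredicate(ICmpInst::ICMP_ULT, A, B))
    return *Res; // Known true or known false.
  return false;  // Unknown; be conservative.
}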
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Rewrites the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
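A minimal sketch (SE and V assumed given) of building canonical expressions through the folders; the helper name is illustrative, and the folds shown are the simplest cases.
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Value.h"
using namespace llvm;

static void buildExprs(ScalarEvolution &SE, Value *V) {
  if (!SE.isSCEVable(V->getType()))
    return;
  const SCEV *X = SE.getSCEV(V);
  const SCEV *Sum = SE.getAddExpr(X, SE.getZero(V->getType())); // Folds to X.
  const SCEV *Twice = SE.getMulExpr(X, SE.getConstant(V->getType(), 2));
  (void)Sum; (void)Twice;
}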
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimized out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:380
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:401
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:476
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:541
bool empty() const
Definition: SmallVector.h:82
size_t size() const
Definition: SmallVector.h:79
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:574
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:938
void reserve(size_type N)
Definition: SmallVector.h:664
iterator erase(const_iterator CI)
Definition: SmallVector.h:738
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:684
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:806
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
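A minimal sketch of the ADT containers above in the worklist style this file uses; the visiting logic is illustrative, not a copy of any routine here.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void visitOnce(const SCEV *Root) {
  SmallVector<const SCEV *, 8> Worklist;
  SmallPtrSet<const SCEV *, 8> Visited;
  Worklist.push_back(Root);
  while (!Worklist.empty()) {
    const SCEV *S = Worklist.pop_back_val();
    if (!Visited.insert(S).second)
      continue; // Already visited.
    append_range(Worklist, S->operands());
  }
}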
An instruction for storing to memory.
Definition: Instructions.h:296
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:626
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:657
TypeSize getSizeInBits() const
Definition: DataLayout.h:637
Class to represent struct types.
Definition: DerivedTypes.h:218
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:267
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:255
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:240
A Use represents the edge between a Value definition and its users.
Definition: Use.h:35
op_range operands()
Definition: User.h:292
Use & Op()
Definition: User.h:196
Value * getOperand(unsigned i) const
Definition: User.h:232
LLVM Value Representation.
Definition: Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5305
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1098
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:322
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:172
const ParentTy * getParent() const
Definition: ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2248
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2253
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2258
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2263
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:798
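A minimal sketch of the APIntOps helpers above; the concrete widths and values are arbitrary.
#include "llvm/ADT/APInt.h"
using namespace llvm;

static void apintHelpers() {
  APInt A(32, 12), B(32, 30);
  APInt G = APIntOps::GreatestCommonDivisor(A, B); // 6
  APInt M = APIntOps::umax(A, B);                  // 30
  (void)G; (void)M;
}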
@ Entry
Definition: COFF.h:862
@ Exit
Definition: COFF.h:863
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Definition: Intrinsics.cpp:762
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:876
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
class_match< const SCEV > m_SCEV()
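A minimal sketch, assuming the matchers above come from llvm/Analysis/ScalarEvolutionPatternMatch.h, of recognizing an affine add recurrence with a constant, non-zero step.
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
using namespace llvm;

static bool hasConstantNonZeroStep(const SCEV *S) {
  using namespace SCEVPatternMatch;
  const APInt *Step;
  // Matches {<any start>,+,Step} for any loop, binding the constant step.
  return match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_APInt(Step))) &&
         !Step->isZero();
}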
@ ReallyHidden
Definition: CommandLine.h:139
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:444
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:464
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:338
@ Offset
Definition: DWP.cpp:477
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:403
void stable_sort(R &&Range)
Definition: STLExtras.h:2077
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1744
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7502
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
constexpr from_range_t from_range
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2155
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:252
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition: STLExtras.h:2072
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition: bit.h:157
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of WO's result is used only along the paths control dependen...
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2147
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1751
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition: iterator.h:336
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:428
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1174
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1164
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:207
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:288
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1973
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:2049
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:312
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:223
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1886
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1980
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:257
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1916
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
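A minimal sketch of SCEVExprContains, used here to test whether an expression mentions any add recurrence (similar in spirit to containsAddRecurrence above).
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static bool mentionsAddRec(const SCEV *Root) {
  return SCEVExprContains(Root,
                          [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
}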
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:856
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:858
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:29
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:294
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:101
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:427
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:189
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:138
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:122
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:98
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
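A minimal sketch (V and DL assumed given) of the KnownBits queries above, in the way this analysis consults ValueTracking to refine ranges.
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

static void useKnownBits(const Value *V, const DataLayout &DL) {
  if (!V->getType()->isIntegerTy())
    return;
  KnownBits Known(V->getType()->getScalarSizeInBits());
  computeKnownBits(V, Known, DL);
  if (Known.isNonNegative()) {
    APInt Max = Known.getMaxValue(); // Largest value consistent with known bits.
    (void)Max;
  }
}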
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.