1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
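// As a concrete illustration (an editorial example; SE is an available
// ScalarEvolution analysis, and IndVar / FourTimesIPlus7 are assumed IR
// values used only for this sketch): a canonical induction variable that
// starts at 0 and is incremented by 1 on each iteration of %loop is
// represented as the add recurrence {0,+,1}<%loop>, and a derived expression
// such as 4*i + 7 folds to the single recurrence {7,+,4}<%loop>. Because
// expressions are uniqued, building the same expression twice yields the
// same SCEV pointer:
//
//   const SCEV *IV = SE.getSCEV(IndVar);               // {0,+,1}<%loop>
//   const SCEV *A  = SE.getAddExpr(
//       SE.getMulExpr(SE.getConstant(IV->getType(), 4), IV),
//       SE.getConstant(IV->getType(), 7));             // {7,+,4}<%loop>
//   assert(A == SE.getSCEV(FourTimesIPlus7) && "uniqued: pointer equality");
//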
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
60#include "llvm/Analysis/ScalarEvolution.h"
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403 case scSequentialUMinExpr:
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435 case scSequentialUMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
504 SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
512 const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
523 SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
610 GlobalValue::isInternalLinkage(LT));
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
735 case scSequentialUMinExpr: {
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768 LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
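/// As a sketch of how a caller might instantiate this helper (the lambda
/// bodies here are illustrative assumptions, not a quote of the call sites
/// later in this file): for a multiply, Fold multiplies the two constants,
/// the identity predicate accepts 1 (a factor of 1 is dropped), and the
/// absorber predicate accepts 0 (a factor of 0 folds the whole product):
///
///   const SCEV *Folded = constantFoldAndGroupOps(
///       SE, LI, DT, Ops,
///       [](const APInt &C1, const APInt &C2) { return C1 * C2; },
///       /*IsIdentity*/ [](const APInt &C) { return C.isOne(); },
///       /*IsAbsorber*/ [](const APInt &C) { return C.isZero(); });
///
/// For an add the analogues would be addition, identity 0, and an absorber
/// predicate that never matches.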
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume, K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906 // but this formula requires less than W + K bits. Also, the first formula
907 // requires a division step, whereas this formula only requires multiplies and
908 // shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
922
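  // A small worked example (editorial illustration): for K = 3 we have
  // K! = 6 = 2^1 * 3, so T = 1 and the odd part of K! is 3. The code below
  // then evaluates
  //
  //   BC(It, 3) = It * (It - 1) * (It - 2) / 2 / 3
  //
  // by forming the product at width W + 1, dividing by 2^T = 2, truncating
  // back to W bits, and multiplying by the multiplicative inverse of 3
  // modulo 2^W (the exact division by the odd part of K!).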
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
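///
/// For instance (an editorial example): evaluating {0,+,1,+,1} at iteration
/// n yields 0*BC(n,0) + 1*BC(n,1) + 1*BC(n,2) = n + n*(n-1)/2 = n*(n+1)/2,
/// i.e. the running sum 1 + 2 + ... + n produced by a loop whose increment
/// itself grows by one each iteration.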
980const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
981 ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987 const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008 unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operand during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1018 FoldingSetNodeID ID;
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1039 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
1070
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1097 SmallVector<const SCEV *, 2> Operands;
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1107 SmallVector<const SCEV *, 2> Operands;
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1131const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142 unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1150 FoldingSetNodeID ID;
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1163 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1167 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1171 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
1186 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1188 SmallVector<const SCEV *, 4> Operands;
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1194 isa<SCEVTruncateExpr>(S))
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked in the beginning that ID is not in the cache, it is
1206 // possible that during recursion and different modification ID was inserted
1207 // into the cache. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 SmallVector<const SCEV *, 4> Operands;
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1244 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1245 SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1249 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1250 SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259 ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1264 return SE->getConstant(APInt::getMinValue(BitWidth) -
1265 SE->getUnsignedRangeMax(Step));
1266}
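// Editorial example of the limit computed above: for an 8-bit step known to
// lie in [1, 3], getUnsignedOverflowLimitForStep returns 0 - 3 = 253 with
// predicate ULT, i.e. as long as the recurrence value stays below 253 before
// an increment, adding the step cannot wrap around 2^8.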
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1302 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326 // Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note, that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1363 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1365 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1366 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
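//
// Instantiating the argument above on the motivating example (an editorial
// illustration): with S = 1, X = 4 and T = 1, {S-T,+,X} is {0,+,4}; knowing
// that {0,+,4} does not wrap discharges (2), and knowing that {0,+,4} + 1
// cannot unsigned-overflow (for example because {0,+,4} is always u< -1)
// discharges (1); (3) follows from (1), so
// zext({1,+,4}) == {zext(1),+,zext(4)}.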
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502 const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
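// Editorial example of the split performed above (values are illustrative):
// for C = 13 (0b1101) and operands x, y that are both multiples of 4, the
// tail (x + y) has TZ = 2 trailing zero bits, so D = C mod 2^TZ = 1.
// Rewriting 13 + x + y as 1 + (12 + x + y) keeps the right-hand addend a
// multiple of 4, and adding D < 2^TZ to a multiple of 2^TZ cannot wrap,
// signed or unsigned.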
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1522static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1523 const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1534 const ScalarEvolution::FoldID &ID, const SCEV *S,
1537 &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1569 if (!isa<SCEVCouldNotCompute>(S))
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575 unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1591 FoldingSetNodeID ID;
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1606 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1610 ConstantRange CR = getUnsignedRange(X);
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1622 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1633 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the later case, the analysis code will
1644 // cope with a conservative value, and it will take care to purge
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1660 SCEV::FlagAnyWrap, Depth + 1);
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1662 SCEV::FlagAnyWrap,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1672 SCEV::FlagAnyWrap, Depth + 1),
1673 SCEV::FlagAnyWrap, Depth + 1);
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1689 SCEV::FlagAnyWrap, Depth + 1),
1690 SCEV::FlagAnyWrap, Depth + 1);
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but
1719 // it's one of two possible causes for a change which was
1720 // reverted. Be conservative for the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1730 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731 getSignedRangeMin(Step));
1732 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1733 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
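 // For instance, zext({7,+,4}) splits as zext(3) + zext({4,+,4}): every value
 // of {4,+,4} stays a multiple of 4 and 3 < 4, so re-adding 3 after the
 // extension can never carry and the split cannot introduce an unsigned wrap.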
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1759 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (matchURem(Op, LHS, RHS))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not unsign overflow then we can, by definition,
1791 // commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Often address arithmetics contain expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1811 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1814 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not unsign overflow then we can, by definition,
1824 // commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
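 // Concretely, with K = 2, N = 8, M = 32 this turns
 // zext(4 * (trunc X to i8)) to i32 into 4 * (zext (trunc X to i6) to i32):
 // shifting an i8 value left by 2 discards its top 2 bits anyway.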
1843 if (SM->getNumOperands() == 2)
1844 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845 if (MulLHS->getAPInt().isPowerOf2())
1846 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848 MulLHS->getAPInt().logBase2();
1849 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850 return getMulExpr(
1851 getZeroExtendExpr(MulLHS, Ty),
1852 getZeroExtendExpr(
1853 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(MinMax))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
1888
1889const SCEV *
1890ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1903 if (!isa<SCEVSignExtendExpr>(S))
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not sign overflow then we can, by definition,
1962 // commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1983 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1986 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 // in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly casted to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2034 SCEV::FlagAnyWrap, Depth + 1);
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2036 SCEV::FlagAnyWrap,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1),
2047 SCEV::FlagAnyWrap, Depth + 1);
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2063 SCEV::FlagAnyWrap, Depth + 1),
2064 SCEV::FlagAnyWrap, Depth + 1);
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but
2090 // it's one of two possible causes for a change which was
2091 // reverted. Be conservative for the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2110 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(MinMax))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
2150
2151const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
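/// For example, a negative constant is extended with sext, an SMax for which
/// neither cast folds prefers the sext form, and anyext(trunc(X)) back to X's
/// own type simply folds to X; absent better information the zext form is used.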
2169const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217/// the given scale, update the given map. This is a helper function for
2218/// getAddExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228/// This will allow getAddExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236/// may be exposed. This helps getAddExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239static bool
2240CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
2302
2303bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2323 Signed ? &ScalarEvolution::getSignExtendExpr
2324 : &ScalarEvolution::getZeroExtendExpr;
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
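 // The check redoes the arithmetic in a type twice as wide: if extending the
 // narrow result equals the wide result, the narrow operation cannot have
 // wrapped. E.g. for unsigned i8, 100 + 100: zext(i8 200) == 200 == 100 + 100
 // computed in i16, so no overflow; for 200 + 100 the sides differ (44 vs 300).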
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 // TODO: Support mul.
2342 if (BinOp == Instruction::Mul)
2343 return false;
2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2345 // TODO: Lift this limitation.
2346 if (!RHSC)
2347 return false;
2348 APInt C = RHSC->getAPInt();
2349 unsigned NumBits = C.getBitWidth();
2350 bool IsSub = (BinOp == Instruction::Sub);
2351 bool IsNegativeConst = (Signed && C.isNegative());
2352 // Compute the direction and magnitude by which we need to check overflow.
2353 bool OverflowDown = IsSub ^ IsNegativeConst;
2354 APInt Magnitude = C;
2355 if (IsNegativeConst) {
2356 if (C == APInt::getSignedMinValue(NumBits))
2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2358 // want to deal with that.
2359 return false;
2360 Magnitude = -C;
2361 }
2362
2363 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2364 if (OverflowDown) {
2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2367 : APInt::getMinValue(NumBits);
2368 APInt Limit = Min + Magnitude;
2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2370 } else {
2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2373 : APInt::getMaxValue(NumBits);
2374 APInt Limit = Max - Magnitude;
2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2376 }
2377}
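// For example, proving that an unsigned i8 "LHS + 100" cannot wrap reduces to
// showing LHS <= 255 - 100 at the context instruction, while "LHS - 100"
// reduces to showing MIN + 100 <= LHS.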
2378
2379std::optional<SCEV::NoWrapFlags>
2380ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2381 const OverflowingBinaryOperator *OBO) {
2382 // It cannot be done any better.
2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2384 return std::nullopt;
2385
2386 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2387
2388 if (OBO->hasNoUnsignedWrap())
2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390 if (OBO->hasNoSignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2392
2393 bool Deduced = false;
2394
2395 if (OBO->getOpcode() != Instruction::Add &&
2396 OBO->getOpcode() != Instruction::Sub &&
2397 OBO->getOpcode() != Instruction::Mul)
2398 return std::nullopt;
2399
2400 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2401 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2402
2403 const Instruction *CtxI =
2404 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2405 if (!OBO->hasNoUnsignedWrap() &&
2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2407 /* Signed */ false, LHS, RHS, CtxI)) {
2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409 Deduced = true;
2410 }
2411
2412 if (!OBO->hasNoSignedWrap() &&
2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2414 /* Signed */ true, LHS, RHS, CtxI)) {
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416 Deduced = true;
2417 }
2418
2419 if (Deduced)
2420 return Flags;
2421 return std::nullopt;
2422}
2423
2424// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2425// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2426// can't-overflow flags for the operation if possible.
2427static SCEV::NoWrapFlags
2428StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2429 const ArrayRef<const SCEV *> Ops,
2430 SCEV::NoWrapFlags Flags) {
2431 using namespace std::placeholders;
2432
2433 using OBO = OverflowingBinaryOperator;
2434
2435 bool CanAnalyze =
2436 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2437 (void)CanAnalyze;
2438 assert(CanAnalyze && "don't call from other places!");
2439
2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2441 SCEV::NoWrapFlags SignOrUnsignWrap =
2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2443
2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2445 auto IsKnownNonNegative = [&](const SCEV *S) {
2446 return SE->isKnownNonNegative(S);
2447 };
2448
2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2450 Flags =
2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2452
2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 if (SignOrUnsignWrap != SignOrUnsignMask &&
2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2457 isa<SCEVConstant>(Ops[0])) {
2458
2459 auto Opcode = [&] {
2460 switch (Type) {
2461 case scAddExpr:
2462 return Instruction::Add;
2463 case scMulExpr:
2464 return Instruction::Mul;
2465 default:
2466 llvm_unreachable("Unexpected SCEV op.");
2467 }
2468 }();
2469
2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2471
2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2475 Opcode, C, OBO::NoSignedWrap);
2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2478 }
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoUnsignedWrap);
2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2486 }
2487 }
2488
2489 // <0,+,nonnegative><nw> is also nuw
2490 // TODO: Add corresponding nsw case
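 // Rationale: such a recurrence starts at 0 and moves only upward by a
 // non-negative step, and <nw> already forbids wrapping back past the start,
 // so it can never wrap around the unsigned range.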
2491 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2494 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2495
2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
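 // Rationale: (X /u Y) * Y rounds X down to a multiple of Y, so the product is
 // never greater than X and the multiplication cannot wrap unsigned.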
2497 if (!ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Type == scMulExpr &&
2498 Ops.size() == 2) {
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2500 if (UDiv->getOperand(1) == Ops[1])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2503 if (UDiv->getOperand(1) == Ops[0])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 }
2506
2507 return Flags;
2508}
2509
2510bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2512}
2513
2514/// Get a canonical add expression, or something simpler if possible.
2515const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2516 SCEV::NoWrapFlags OrigFlags,
2517 unsigned Depth) {
2518 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2519 "only nuw or nsw allowed");
2520 assert(!Ops.empty() && "Cannot get empty add!");
2521 if (Ops.size() == 1) return Ops[0];
2522#ifndef NDEBUG
2523 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2524 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2525 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2526 "SCEVAddExpr operand types don't match!");
2527 unsigned NumPtrs = count_if(
2528 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2529 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2530#endif
2531
2532 const SCEV *Folded = constantFoldAndGroupOps(
2533 *this, LI, DT, Ops,
2534 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2535 [](const APInt &C) { return C.isZero(); }, // identity
2536 [](const APInt &C) { return false; }); // absorber
2537 if (Folded)
2538 return Folded;
2539
2540 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2541
2542 // Delay expensive flag strengthening until necessary.
2543 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2544 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2545 };
2546
2547 // Limit recursion calls depth.
2548 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2549 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2550
2551 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2552 // Don't strengthen flags if we have no new information.
2553 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2554 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2555 Add->setNoWrapFlags(ComputeFlags(Ops));
2556 return S;
2557 }
2558
2559 // Okay, check to see if the same value occurs in the operand list more than
2560 // once. If so, merge them together into a multiply expression. Since we
2561 // sorted the list, these values are required to be adjacent.
2562 Type *Ty = Ops[0]->getType();
2563 bool FoundMatch = false;
2564 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2565 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2566 // Scan ahead to count how many equal operands there are.
2567 unsigned Count = 2;
2568 while (i+Count != e && Ops[i+Count] == Ops[i])
2569 ++Count;
2570 // Merge the values into a multiply.
2571 const SCEV *Scale = getConstant(Ty, Count);
2572 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2573 if (Ops.size() == Count)
2574 return Mul;
2575 Ops[i] = Mul;
2576 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2577 --i; e -= Count - 1;
2578 FoundMatch = true;
2579 }
2580 if (FoundMatch)
2581 return getAddExpr(Ops, OrigFlags, Depth + 1);
2582
2583 // Check for truncates. If all the operands are truncated from the same
2584 // type, see if factoring out the truncate would permit the result to be
2585 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2586 // if the contents of the resulting outer trunc fold to something simple.
2587 auto FindTruncSrcType = [&]() -> Type * {
2588 // We're ultimately looking to fold an addrec of truncs and muls of only
2589 // constants and truncs, so if we find any other types of SCEV
2590 // as operands of the addrec then we bail and return nullptr here.
2591 // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2593 return T->getOperand()->getType();
2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2597 return T->getOperand()->getType();
2598 }
2599 return nullptr;
2600 };
2601 if (auto *SrcType = FindTruncSrcType()) {
2602 SmallVector<const SCEV *, 8> LargeOps;
2603 bool Ok = true;
2604 // Check all the operands to see if they can be represented in the
2605 // source type of the truncate.
2606 for (const SCEV *Op : Ops) {
2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2608 if (T->getOperand()->getType() != SrcType) {
2609 Ok = false;
2610 break;
2611 }
2612 LargeOps.push_back(T->getOperand());
2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2616 SmallVector<const SCEV *, 8> LargeMulOps;
2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2618 if (const SCEVTruncateExpr *T =
2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeMulOps.push_back(T->getOperand());
2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else {
2628 Ok = false;
2629 break;
2630 }
2631 }
2632 if (Ok)
2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2634 } else {
2635 Ok = false;
2636 break;
2637 }
2638 }
2639 if (Ok) {
2640 // Evaluate the expression in the larger type.
2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2642 // If it folds to something simple, use it. Otherwise, don't.
2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2644 return getTruncateExpr(Fold, Ty);
2645 }
2646 }
2647
2648 if (Ops.size() == 2) {
2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2650 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2651 // C1).
2652 const SCEV *A = Ops[0];
2653 const SCEV *B = Ops[1];
2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2655 auto *C = dyn_cast<SCEVConstant>(A);
2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2658 auto C2 = C->getAPInt();
2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2660
2661 APInt ConstAdd = C1 + C2;
2662 auto AddFlags = AddExpr->getNoWrapFlags();
2663 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2665 ConstAdd.ule(C1)) {
2666 PreservedFlags =
2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2668 }
2669
2670 // Adding a constant with the same sign and no larger magnitude is NSW, if the
2671 // original AddExpr was NSW.
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2673 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2674 ConstAdd.abs().ule(C1.abs())) {
2675 PreservedFlags =
2676 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2677 }
2678
2679 if (PreservedFlags != SCEV::FlagAnyWrap) {
2680 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2681 NewOps[0] = getConstant(ConstAdd);
2682 return getAddExpr(NewOps, PreservedFlags);
2683 }
2684 }
2685
2686 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2687 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2688 const SCEVAddExpr *InnerAdd;
2689 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2690 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2691 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2692 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2693 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2694 SCEV::FlagAnyWrap),
2695 SCEV::FlagNUW)) {
2696 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2697 }
2698 }
2699 }
2700
2701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
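 // E.g. X + (-1 * (X urem 8)) folds to 8 * (X /u 8), i.e. X rounded down to a
 // multiple of 8.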
2702 if (Ops.size() == 2) {
2703 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704 if (Mul && Mul->getNumOperands() == 2 &&
2705 Mul->getOperand(0)->isAllOnesValue()) {
2706 const SCEV *X;
2707 const SCEV *Y;
2708 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709 return getMulExpr(Y, getUDivExpr(X, Y));
2710 }
2711 }
2712 }
2713
2714 // Skip past any other cast SCEVs.
2715 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2716 ++Idx;
2717
2718 // If there are add operands they would be next.
2719 if (Idx < Ops.size()) {
2720 bool DeletedAdd = false;
2721 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2722 // common NUW flag for expression after inlining. Other flags cannot be
2723 // preserved, because they may depend on the original order of operations.
2724 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2725 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2726 if (Ops.size() > AddOpsInlineThreshold ||
2727 Add->getNumOperands() > AddOpsInlineThreshold)
2728 break;
2729 // If we have an add, expand the add operands onto the end of the operands
2730 // list.
2731 Ops.erase(Ops.begin()+Idx);
2732 append_range(Ops, Add->operands());
2733 DeletedAdd = true;
2734 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2735 }
2736
2737 // If we deleted at least one add, we added operands to the end of the list,
2738 // and they are not necessarily sorted. Recurse to resort and resimplify
2739 // any operands we just acquired.
2740 if (DeletedAdd)
2741 return getAddExpr(Ops, CommonFlags, Depth + 1);
2742 }
2743
2744 // Skip over the add expression until we get to a multiply.
2745 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2746 ++Idx;
2747
2748 // Check to see if there are any folding opportunities present with
2749 // operands multiplied by constant values.
2750 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 uint64_t BitWidth = getTypeSizeInBits(Ty);
2752 SmallDenseMap<const SCEV *, APInt, 16> M;
2753 SmallVector<const SCEV *, 8> NewOps;
2754 APInt AccumulatedConstant(BitWidth, 0);
2755 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2756 Ops, APInt(BitWidth, 1), *this)) {
2757 struct APIntCompare {
2758 bool operator()(const APInt &LHS, const APInt &RHS) const {
2759 return LHS.ult(RHS);
2760 }
2761 };
2762
2763 // Some interesting folding opportunity is present, so it's worthwhile to
2764 // re-generate the operands list. Group the operands by constant scale,
2765 // to avoid multiplying by the same constant scale multiple times.
2766 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2767 for (const SCEV *NewOp : NewOps)
2768 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2769 // Re-generate the operands list.
2770 Ops.clear();
2771 if (AccumulatedConstant != 0)
2772 Ops.push_back(getConstant(AccumulatedConstant));
2773 for (auto &MulOp : MulOpLists) {
2774 if (MulOp.first == 1) {
2775 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2776 } else if (MulOp.first != 0) {
2777 Ops.push_back(getMulExpr(
2778 getConstant(MulOp.first),
2779 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2780 SCEV::FlagAnyWrap, Depth + 1));
2781 }
2782 }
2783 if (Ops.empty())
2784 return getZero(Ty);
2785 if (Ops.size() == 1)
2786 return Ops[0];
2787 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2788 }
2789 }
2790
2791 // If we are adding something to a multiply expression, make sure the
2792 // something is not already an operand of the multiply. If so, merge it into
2793 // the multiply.
2794 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2795 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2796 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2797 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2798 if (isa<SCEVConstant>(MulOpSCEV))
2799 continue;
2800 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2801 if (MulOpSCEV == Ops[AddOp]) {
2802 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2803 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2804 if (Mul->getNumOperands() != 2) {
2805 // If the multiply has more than two operands, we must get the
2806 // Y*Z term.
2807 SmallVector<const SCEV *, 4> MulOps(
2808 Mul->operands().take_front(MulOp));
2809 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2810 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2811 }
2812 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2813 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2814 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2815 SCEV::FlagAnyWrap, Depth + 1);
2816 if (Ops.size() == 2) return OuterMul;
2817 if (AddOp < Idx) {
2818 Ops.erase(Ops.begin()+AddOp);
2819 Ops.erase(Ops.begin()+Idx-1);
2820 } else {
2821 Ops.erase(Ops.begin()+Idx);
2822 Ops.erase(Ops.begin()+AddOp-1);
2823 }
2824 Ops.push_back(OuterMul);
2825 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2826 }
2827
2828 // Check this multiply against other multiplies being added together.
2829 for (unsigned OtherMulIdx = Idx+1;
2830 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 ++OtherMulIdx) {
2832 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2833 // If MulOp occurs in OtherMul, we can fold the two multiplies
2834 // together.
2835 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2836 OMulOp != e; ++OMulOp)
2837 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2838 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2839 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2840 if (Mul->getNumOperands() != 2) {
2841 SmallVector<const SCEV *, 4> MulOps(
2842 Mul->operands().take_front(MulOp));
2843 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2844 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2845 }
2846 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2847 if (OtherMul->getNumOperands() != 2) {
2848 SmallVector<const SCEV *, 4> MulOps(
2849 OtherMul->operands().take_front(OMulOp));
2850 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2851 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2852 }
2853 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2854 const SCEV *InnerMulSum =
2855 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2856 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2857 SCEV::FlagAnyWrap, Depth + 1);
2858 if (Ops.size() == 2) return OuterMul;
2859 Ops.erase(Ops.begin()+Idx);
2860 Ops.erase(Ops.begin()+OtherMulIdx-1);
2861 Ops.push_back(OuterMul);
2862 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2863 }
2864 }
2865 }
2866 }
2867
2868 // If there are any add recurrences in the operands list, see if any other
2869 // added values are loop invariant. If so, we can fold them into the
2870 // recurrence.
2871 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2872 ++Idx;
2873
2874 // Scan over all recurrences, trying to fold loop invariants into them.
2875 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2876 // Scan all of the other operands to this add and add them to the vector if
2877 // they are loop invariant w.r.t. the recurrence.
2878 SmallVector<const SCEV *, 8> LIOps;
2879 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2880 const Loop *AddRecLoop = AddRec->getLoop();
2881 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2882 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2883 LIOps.push_back(Ops[i]);
2884 Ops.erase(Ops.begin()+i);
2885 --i; --e;
2886 }
2887
2888 // If we found some loop invariants, fold them into the recurrence.
2889 if (!LIOps.empty()) {
2890 // Compute nowrap flags for the addition of the loop-invariant ops and
2891 // the addrec. Temporarily push it as an operand for that purpose. These
2892 // flags are valid in the scope of the addrec only.
2893 LIOps.push_back(AddRec);
2894 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2895 LIOps.pop_back();
2896
2897 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2898 LIOps.push_back(AddRec->getStart());
2899
2900 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2901
2902 // It is not in general safe to propagate flags valid on an add within
2903 // the addrec scope to one outside it. We must prove that the inner
2904 // scope is guaranteed to execute if the outer one does to be able to
2905 // safely propagate. We know the program is undefined if poison is
2906 // produced on the inner scoped addrec. We also know that *for this use*
2907 // the outer scoped add can't overflow (because of the flags we just
2908 // computed for the inner scoped add) without the program being undefined.
2909 // Proving that entry to the outer scope necessitates entry to the inner
2910 // scope, thus proves the program undefined if the flags would be violated
2911 // in the outer scope.
2912 SCEV::NoWrapFlags AddFlags = Flags;
2913 if (AddFlags != SCEV::FlagAnyWrap) {
2914 auto *DefI = getDefiningScopeBound(LIOps);
2915 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2916 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2917 AddFlags = SCEV::FlagAnyWrap;
2918 }
2919 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2920
2921 // Build the new addrec. Propagate the NUW and NSW flags if both the
2922 // outer add and the inner addrec are guaranteed to have no overflow.
2923 // Always propagate NW.
2924 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2925 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2926
2927 // If all of the other operands were loop invariant, we are done.
2928 if (Ops.size() == 1) return NewRec;
2929
2930 // Otherwise, add the folded AddRec by the non-invariant parts.
2931 for (unsigned i = 0;; ++i)
2932 if (Ops[i] == AddRec) {
2933 Ops[i] = NewRec;
2934 break;
2935 }
2936 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2937 }
2938
2939 // Okay, if there weren't any loop invariants to be folded, check to see if
2940 // there are multiple AddRec's with the same loop induction variable being
2941 // added together. If so, we can fold them.
2942 for (unsigned OtherIdx = Idx+1;
2943 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2944 ++OtherIdx) {
2945 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2946 // so that the 1st found AddRecExpr is dominated by all others.
2947 assert(DT.dominates(
2948 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2949 AddRec->getLoop()->getHeader()) &&
2950 "AddRecExprs are not sorted in reverse dominance order?");
2951 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2952 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2953 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2954 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 ++OtherIdx) {
2956 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2957 if (OtherAddRec->getLoop() == AddRecLoop) {
2958 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2959 i != e; ++i) {
2960 if (i >= AddRecOps.size()) {
2961 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2962 break;
2963 }
2964 SmallVector<const SCEV *, 2> TwoOps = {
2965 AddRecOps[i], OtherAddRec->getOperand(i)};
2966 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2969 }
2970 }
2971 // Step size has changed, so we cannot guarantee no self-wraparound.
2972 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2973 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2974 }
2975 }
2976
2977 // Otherwise couldn't fold anything into this recurrence. Move onto the
2978 // next one.
2979 }
2980
2981 // Okay, it looks like we really DO need an add expr. Check to see if we
2982 // already have one, otherwise create a new one.
2983 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2984}
2985
2986const SCEV *
2987ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2988 SCEV::NoWrapFlags Flags) {
2989 FoldingSetNodeID ID;
2990 ID.AddInteger(scAddExpr);
2991 for (const SCEV *Op : Ops)
2992 ID.AddPointer(Op);
2993 void *IP = nullptr;
2994 SCEVAddExpr *S =
2995 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2996 if (!S) {
2997 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2998 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2999 S = new (SCEVAllocator)
3000 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3001 UniqueSCEVs.InsertNode(S, IP);
3002 registerUser(S, Ops);
3003 }
3004 S->setNoWrapFlags(Flags);
3005 return S;
3006}
3007
3008const SCEV *
3009ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3010 const Loop *L, SCEV::NoWrapFlags Flags) {
3011 FoldingSetNodeID ID;
3012 ID.AddInteger(scAddRecExpr);
3013 for (const SCEV *Op : Ops)
3014 ID.AddPointer(Op);
3015 ID.AddPointer(L);
3016 void *IP = nullptr;
3017 SCEVAddRecExpr *S =
3018 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3019 if (!S) {
3020 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3021 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3022 S = new (SCEVAllocator)
3023 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3024 UniqueSCEVs.InsertNode(S, IP);
3025 LoopUsers[L].push_back(S);
3026 registerUser(S, Ops);
3027 }
3028 setNoWrapFlags(S, Flags);
3029 return S;
3030}
3031
3032const SCEV *
3033ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3034 SCEV::NoWrapFlags Flags) {
3035 FoldingSetNodeID ID;
3036 ID.AddInteger(scMulExpr);
3037 for (const SCEV *Op : Ops)
3038 ID.AddPointer(Op);
3039 void *IP = nullptr;
3040 SCEVMulExpr *S =
3041 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3042 if (!S) {
3043 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3044 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3045 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3046 O, Ops.size());
3047 UniqueSCEVs.InsertNode(S, IP);
3048 registerUser(S, Ops);
3049 }
3050 S->setNoWrapFlags(Flags);
3051 return S;
3052}
3053
3054static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3055 uint64_t k = i*j;
3056 if (j > 1 && k / j != i) Overflow = true;
3057 return k;
3058}
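// E.g. umul_ov(1ULL << 32, 1ULL << 33, Overflow) wraps to 0, and because
// 0 / (1ULL << 33) != 1ULL << 32 the Overflow flag gets set.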
3059
3060/// Compute the result of "n choose k", the binomial coefficient. If an
3061/// intermediate computation overflows, Overflow will be set and the return will
3062/// be garbage. Overflow is not cleared on absence of overflow.
3063static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3064 // We use the multiplicative formula:
3065 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3066 // At each iteration, we take the n-th term of the numerator and divide by the
3067 // (k-n)th term of the denominator. This division will always produce an
3068 // integral result, and helps reduce the chance of overflow in the
3069 // intermediate computations. However, we can still overflow even when the
3070 // final result would fit.
3071
3072 if (n == 0 || n == k) return 1;
3073 if (k > n) return 0;
3074
3075 if (k > n/2)
3076 k = n-k;
3077
3078 uint64_t r = 1;
3079 for (uint64_t i = 1; i <= k; ++i) {
3080 r = umul_ov(r, n-(i-1), Overflow);
3081 r /= i;
3082 }
3083 return r;
3084}
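// E.g. Choose(6, 2, Ov) leaves k at 2, computes r = 6 after i = 1 and
// r = 6 * 5 / 2 = 15 after i = 2, matching C(6,2) = 15 with no overflow.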
3085
3086/// Determine if any of the operands in this SCEV are a constant or if
3087/// any of the add or multiply expressions in this SCEV contain a constant.
3088static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3089 struct FindConstantInAddMulChain {
3090 bool FoundConstant = false;
3091
3092 bool follow(const SCEV *S) {
3093 FoundConstant |= isa<SCEVConstant>(S);
3094 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3095 }
3096
3097 bool isDone() const {
3098 return FoundConstant;
3099 }
3100 };
3101
3102 FindConstantInAddMulChain F;
3103 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3104 ST.visitAll(StartExpr);
3105 return F.FoundConstant;
3106}
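// E.g. (3 + %x) * %y contains a constant in its add/mul chain, whereas
// (%x /u 4) * %y does not: the traversal only descends through adds and muls,
// so the constant 4 hidden under the udiv is never visited.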
3107
3108/// Get a canonical multiply expression, or something simpler if possible.
3109const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3110 SCEV::NoWrapFlags OrigFlags,
3111 unsigned Depth) {
3112 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3113 "only nuw or nsw allowed");
3114 assert(!Ops.empty() && "Cannot get empty mul!");
3115 if (Ops.size() == 1) return Ops[0];
3116#ifndef NDEBUG
3117 Type *ETy = Ops[0]->getType();
3118 assert(!ETy->isPointerTy());
3119 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3120 assert(Ops[i]->getType() == ETy &&
3121 "SCEVMulExpr operand types don't match!");
3122#endif
3123
3124 const SCEV *Folded = constantFoldAndGroupOps(
3125 *this, LI, DT, Ops,
3126 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3127 [](const APInt &C) { return C.isOne(); }, // identity
3128 [](const APInt &C) { return C.isZero(); }); // absorber
3129 if (Folded)
3130 return Folded;
3131
3132 // Delay expensive flag strengthening until necessary.
3133 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3134 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3135 };
3136
3137 // Limit recursion calls depth.
3138 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3139 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3140
3141 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3142 // Don't strengthen flags if we have no new information.
3143 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3144 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3145 Mul->setNoWrapFlags(ComputeFlags(Ops));
3146 return S;
3147 }
3148
3149 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3150 if (Ops.size() == 2) {
3151 // C1*(C2+V) -> C1*C2 + C1*V
3152 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3153 // If any of Add's ops are Adds or Muls with a constant, apply this
3154 // transformation as well.
3155 //
3156 // TODO: There are some cases where this transformation is not
3157 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3158 // this transformation should be narrowed down.
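 // E.g. 3 * (5 + %x) is rewritten as 15 + 3 * %x, which often exposes further
 // folding with sibling add operands.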
3159 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3160 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3161 SCEV::FlagAnyWrap, Depth + 1);
3162 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3163 SCEV::FlagAnyWrap, Depth + 1);
3164 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3165 }
3166
3167 if (Ops[0]->isAllOnesValue()) {
3168 // If we have a mul by -1 of an add, try distributing the -1 among the
3169 // add operands.
3170 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3171 SmallVector<const SCEV *, 4> NewOps;
3172 bool AnyFolded = false;
3173 for (const SCEV *AddOp : Add->operands()) {
3174 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3175 Depth + 1);
3176 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3177 NewOps.push_back(Mul);
3178 }
3179 if (AnyFolded)
3180 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3181 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3182 // Negation preserves a recurrence's no self-wrap property.
3183 SmallVector<const SCEV *, 4> Operands;
3184 for (const SCEV *AddRecOp : AddRec->operands())
3185 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3186 Depth + 1));
3187 // Let M be the minimum representable signed value. AddRec with nsw
3188 // multiplied by -1 can have signed overflow if and only if it takes a
3189 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3190 // maximum signed value. In all other cases signed overflow is
3191 // impossible.
3192 auto FlagsMask = SCEV::FlagNW;
3193 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3194 auto MinInt =
3195 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3196 if (getSignedRangeMin(AddRec) != MinInt)
3197 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3198 }
3199 return getAddRecExpr(Operands, AddRec->getLoop(),
3200 AddRec->getNoWrapFlags(FlagsMask));
3201 }
3202 }
3203
3204 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3205 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
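 // For example: 4 * (zext i8 (1 + %a) to i32)
 // --> (zext i8 (4 + 4 * %a) to i32), provided 4 * (1 + %a) does not
 // unsigned-wrap in i8.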
3206 const SCEVAddExpr *InnerAdd;
3207 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3208 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3209 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3210 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3211 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3212 SCEV::FlagAnyWrap),
3213 SCEV::FlagNUW)) {
3214 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3215 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3216 };
3217 }
3218
3219 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3220 // D is a multiple of C2, and C1 is a multiple of C2.
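 // For example, with C1 = 8, C2 = 2 and %d a multiple of 2:
 // 8 * (%d /u 2) --> 4 * %d.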
3221 const SCEV *D;
3222 APInt C1V = LHSC->getAPInt();
3223 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN.
3224 if (C1V.isNegative() && !C1V.isMinSignedValue())
3225 C1V = C1V.abs();
3226 const SCEVConstant *C2;
3227 if (C1V.isPowerOf2() &&
3228 match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3229 C2->getAPInt().isPowerOf2() && C1V.uge(C2->getAPInt()) &&
3230 C1V.logBase2() <= getMinTrailingZeros(D)) {
3231 const SCEV *NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3232 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3233 }
3234 }
3235 }
3236
3237 // Skip over the add expression until we get to a multiply.
3238 unsigned Idx = 0;
3239 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3240 ++Idx;
3241
3242 // If there are mul operands inline them all into this expression.
3243 if (Idx < Ops.size()) {
3244 bool DeletedMul = false;
3245 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3246 if (Ops.size() > MulOpsInlineThreshold)
3247 break;
3248 // If we have a mul, expand its operands onto the end of the operands
3249 // list.
3250 Ops.erase(Ops.begin()+Idx);
3251 append_range(Ops, Mul->operands());
3252 DeletedMul = true;
3253 }
3254
3255 // If we deleted at least one mul, we added operands to the end of the
3256 // list, and they are not necessarily sorted. Recurse to resort and
3257 // resimplify any operands we just acquired.
3258 if (DeletedMul)
3259 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3260 }
3261
3262 // If there are any add recurrences in the operands list, see if any other
3263 // operands are loop invariant. If so, we can fold them into the
3264 // recurrence.
3265 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3266 ++Idx;
3267
3268 // Scan over all recurrences, trying to fold loop invariants into them.
3269 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3270 // Scan all of the other operands to this mul and add them to the vector
3271 // if they are loop invariant w.r.t. the recurrence.
3272 SmallVector<const SCEV *, 8> LIOps;
3273 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3274 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3275 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3276 LIOps.push_back(Ops[i]);
3277 Ops.erase(Ops.begin()+i);
3278 --i; --e;
3279 }
3280
3281 // If we found some loop invariants, fold them into the recurrence.
3282 if (!LIOps.empty()) {
3283 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3284 SmallVector<const SCEV *, 4> NewOps;
3285 NewOps.reserve(AddRec->getNumOperands());
3286 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3287
3288 // If both the mul and addrec are nuw, we can preserve nuw.
3289 // If both the mul and addrec are nsw, we can only preserve nsw if either
3290 // a) they are also nuw, or
3291 // b) all multiplications of addrec operands with scale are nsw.
3292 SCEV::NoWrapFlags Flags =
3293 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3294
3295 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3296 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3297 SCEV::FlagAnyWrap, Depth + 1));
3298
3299 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3300 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3301 Instruction::Mul, getSignedRange(Scale),
3302 OverflowingBinaryOperator::NoSignedWrap);
3303 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3304 Flags = clearFlags(Flags, SCEV::FlagNSW);
3305 }
3306 }
3307
3308 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3309
3310 // If all of the other operands were loop invariant, we are done.
3311 if (Ops.size() == 1) return NewRec;
3312
3313 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3314 for (unsigned i = 0;; ++i)
3315 if (Ops[i] == AddRec) {
3316 Ops[i] = NewRec;
3317 break;
3318 }
3319 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3320 }
3321
3322 // Okay, if there weren't any loop invariants to be folded, check to see
3323 // if there are multiple AddRec's with the same loop induction variable
3324 // being multiplied together. If so, we can fold them.
3325
3326 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3327 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3328 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3329 // ]]],+,...up to x=2n}.
3330 // Note that the arguments to choose() are always integers with values
3331 // known at compile time, never SCEV objects.
3332 //
3333 // The implementation avoids pointless extra computations when the two
3334 // addrec's are of different length (mathematically, it's equivalent to
3335 // an infinite stream of zeros on the right).
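 // For example, {0,+,1}<L> * {0,+,1}<L> (i.e. n * n) folds to {0,+,1,+,2}<L>.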
3336 bool OpsModified = false;
3337 for (unsigned OtherIdx = Idx+1;
3338 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3339 ++OtherIdx) {
3340 const SCEVAddRecExpr *OtherAddRec =
3341 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3342 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3343 continue;
3344
3345 // Limit max number of arguments to avoid creation of unreasonably big
3346 // SCEVAddRecs with very complex operands.
3347 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3348 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3349 continue;
3350
3351 bool Overflow = false;
3352 Type *Ty = AddRec->getType();
3353 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3354 SmallVector<const SCEV *, 7> AddRecOps;
3355 for (int x = 0, xe = AddRec->getNumOperands() +
3356 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3357 SmallVector<const SCEV *, 7> SumOps;
3358 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3359 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3360 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3361 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3362 z < ze && !Overflow; ++z) {
3363 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3364 uint64_t Coeff;
3365 if (LargerThan64Bits)
3366 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3367 else
3368 Coeff = Coeff1*Coeff2;
3369 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3370 const SCEV *Term1 = AddRec->getOperand(y-z);
3371 const SCEV *Term2 = OtherAddRec->getOperand(z);
3372 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3373 SCEV::FlagAnyWrap, Depth + 1));
3374 }
3375 }
3376 if (SumOps.empty())
3377 SumOps.push_back(getZero(Ty));
3378 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3379 }
3380 if (!Overflow) {
3381 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3382 SCEV::FlagAnyWrap);
3383 if (Ops.size() == 2) return NewAddRec;
3384 Ops[Idx] = NewAddRec;
3385 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3386 OpsModified = true;
3387 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3388 if (!AddRec)
3389 break;
3390 }
3391 }
3392 if (OpsModified)
3393 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3394
3395 // Otherwise couldn't fold anything into this recurrence. Move onto the
3396 // next one.
3397 }
3398
3399 // Okay, it looks like we really DO need a mul expr. Check to see if we
3400 // already have one, otherwise create a new one.
3401 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3402}
3403
3404/// Represents an unsigned remainder expression based on unsigned division.
3405 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3406 const SCEV *RHS) {
3407 assert(getEffectiveSCEVType(LHS->getType()) ==
3408 getEffectiveSCEVType(RHS->getType()) &&
3409 "SCEVURemExpr operand types don't match!");
3410
3411 // Short-circuit easy cases
3412 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3413 // If constant is one, the result is trivial
3414 if (RHSC->getValue()->isOne())
3415 return getZero(LHS->getType()); // X urem 1 --> 0
3416
3417 // If constant is a power of two, fold into a zext(trunc(LHS)).
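 // For example: %x urem 8 --> zext (trunc %x to i3) back to %x's type.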
3418 if (RHSC->getAPInt().isPowerOf2()) {
3419 Type *FullTy = LHS->getType();
3420 Type *TruncTy =
3421 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3422 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3423 }
3424 }
3425
3426 // Fallback: %x urem %y --> %x -<nuw> ((%x udiv %y) *<nuw> %y)
3427 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3428 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3429 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3430}
3431
3432/// Get a canonical unsigned division expression, or something simpler if
3433/// possible.
3434 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3435 const SCEV *RHS) {
3436 assert(!LHS->getType()->isPointerTy() &&
3437 "SCEVUDivExpr operand can't be pointer!");
3438 assert(LHS->getType() == RHS->getType() &&
3439 "SCEVUDivExpr operand types don't match!");
3440
3441 FoldingSetNodeID ID;
3442 ID.AddInteger(scUDivExpr);
3443 ID.AddPointer(LHS);
3444 ID.AddPointer(RHS);
3445 void *IP = nullptr;
3446 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3447 return S;
3448
3449 // 0 udiv Y == 0
3450 if (match(LHS, m_scev_Zero()))
3451 return LHS;
3452
3453 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3454 if (RHSC->getValue()->isOne())
3455 return LHS; // X udiv 1 --> x
3456 // If the denominator is zero, the result of the udiv is undefined. Don't
3457 // try to analyze it, because the resolution chosen here may differ from
3458 // the resolution chosen in other parts of the compiler.
3459 if (!RHSC->getValue()->isZero()) {
3460 // Determine if the division can be folded into the operands of the
3461 // dividend.
3462 // TODO: Generalize this to non-constants by using known-bits information.
3463 Type *Ty = LHS->getType();
3464 unsigned LZ = RHSC->getAPInt().countl_zero();
3465 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3466 // For non-power-of-two values, effectively round the value up to the
3467 // nearest power of two.
3468 if (!RHSC->getAPInt().isPowerOf2())
3469 ++MaxShiftAmt;
3470 IntegerType *ExtTy =
3471 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3472 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3473 if (const SCEVConstant *Step =
3474 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3475 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
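 // For example: {8,+,4}<L> /u 4 --> {2,+,1}<L> when the zero-extend check
 // below succeeds.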
3476 const APInt &StepInt = Step->getAPInt();
3477 const APInt &DivInt = RHSC->getAPInt();
3478 if (!StepInt.urem(DivInt) &&
3479 getZeroExtendExpr(AR, ExtTy) ==
3480 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3481 getZeroExtendExpr(Step, ExtTy),
3482 AR->getLoop(), SCEV::FlagAnyWrap)) {
3483 SmallVector<const SCEV *, 4> Operands;
3484 for (const SCEV *Op : AR->operands())
3485 Operands.push_back(getUDivExpr(Op, RHS));
3486 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3487 }
3488 /// Get a canonical UDivExpr for a recurrence.
3489 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3490 // We can currently only fold X%N if X is constant.
3491 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3492 if (StartC && !DivInt.urem(StepInt) &&
3493 getZeroExtendExpr(AR, ExtTy) ==
3494 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3495 getZeroExtendExpr(Step, ExtTy),
3496 AR->getLoop(), SCEV::FlagAnyWrap)) {
3497 const APInt &StartInt = StartC->getAPInt();
3498 const APInt &StartRem = StartInt.urem(StepInt);
3499 if (StartRem != 0) {
3500 const SCEV *NewLHS =
3501 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3502 AR->getLoop(), SCEV::FlagNW);
3503 if (LHS != NewLHS) {
3504 LHS = NewLHS;
3505
3506 // Reset the ID to include the new LHS, and check if it is
3507 // already cached.
3508 ID.clear();
3509 ID.AddInteger(scUDivExpr);
3510 ID.AddPointer(LHS);
3511 ID.AddPointer(RHS);
3512 IP = nullptr;
3513 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3514 return S;
3515 }
3516 }
3517 }
3518 }
3519 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3520 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3521 SmallVector<const SCEV *, 4> Operands;
3522 for (const SCEV *Op : M->operands())
3523 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3524 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3525 // Find an operand that's safely divisible.
3526 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3527 const SCEV *Op = M->getOperand(i);
3528 const SCEV *Div = getUDivExpr(Op, RHSC);
3529 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3530 Operands = SmallVector<const SCEV *, 4>(M->operands());
3531 Operands[i] = Div;
3532 return getMulExpr(Operands);
3533 }
3534 }
3535 }
3536
3537 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3538 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3539 if (auto *DivisorConstant =
3540 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3541 bool Overflow = false;
3542 APInt NewRHS =
3543 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3544 if (Overflow) {
3545 return getConstant(RHSC->getType(), 0, false);
3546 }
3547 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3548 }
3549 }
3550
3551 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3552 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3553 SmallVector<const SCEV *, 4> Operands;
3554 for (const SCEV *Op : A->operands())
3555 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3556 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3557 Operands.clear();
3558 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3559 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3560 if (isa<SCEVUDivExpr>(Op) ||
3561 getMulExpr(Op, RHS) != A->getOperand(i))
3562 break;
3563 Operands.push_back(Op);
3564 }
3565 if (Operands.size() == A->getNumOperands())
3566 return getAddExpr(Operands);
3567 }
3568 }
3569
3570 // Fold if both operands are constant.
3571 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3572 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3573 }
3574 }
3575
3576 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3577 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3578 AE && AE->getNumOperands() == 2) {
3579 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3580 const APInt &NegC = VC->getAPInt();
3581 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3582 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3583 if (MME && MME->getNumOperands() == 2 &&
3584 isa<SCEVConstant>(MME->getOperand(0)) &&
3585 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3586 MME->getOperand(1) == RHS)
3587 return getZero(LHS->getType());
3588 }
3589 }
3590 }
3591
3592 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3593 // changes). Make sure we get a new one.
3594 IP = nullptr;
3595 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3596 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3597 LHS, RHS);
3598 UniqueSCEVs.InsertNode(S, IP);
3599 registerUser(S, {LHS, RHS});
3600 return S;
3601}
3602
3603APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3604 APInt A = C1->getAPInt().abs();
3605 APInt B = C2->getAPInt().abs();
3606 uint32_t ABW = A.getBitWidth();
3607 uint32_t BBW = B.getBitWidth();
3608
3609 if (ABW > BBW)
3610 B = B.zext(ABW);
3611 else if (ABW < BBW)
3612 A = A.zext(BBW);
3613
3614 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3615}
3616
3617/// Get a canonical unsigned division expression, or something simpler if
3618/// possible. There is no representation for an exact udiv in SCEV IR, but we
3619/// can attempt to remove factors from the LHS and RHS. We can't do this when
3620/// it's not exact because the udiv may be clearing bits.
3621 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3622 const SCEV *RHS) {
3623 // TODO: we could try to find factors in all sorts of things, but for now we
3624 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3625 // end of this file for inspiration.
3626
3627 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3628 if (!Mul || !Mul->hasNoUnsignedWrap())
3629 return getUDivExpr(LHS, RHS);
3630
3631 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3632 // If the mulexpr multiplies by a constant, then that constant must be the
3633 // first element of the mulexpr.
3634 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3635 if (LHSCst == RHSCst) {
3636 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3637 return getMulExpr(Operands);
3638 }
3639
3640 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3641 // that there's a factor provided by one of the other terms. We need to
3642 // check.
3643 APInt Factor = gcd(LHSCst, RHSCst);
3644 if (!Factor.isIntN(1)) {
3645 LHSCst =
3646 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3647 RHSCst =
3648 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3649 SmallVector<const SCEV *, 2> Operands;
3650 Operands.push_back(LHSCst);
3651 append_range(Operands, Mul->operands().drop_front());
3652 LHS = getMulExpr(Operands);
3653 RHS = RHSCst;
3654 Mul = dyn_cast<SCEVMulExpr>(LHS);
3655 if (!Mul)
3656 return getUDivExactExpr(LHS, RHS);
3657 }
3658 }
3659 }
3660
3661 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3662 if (Mul->getOperand(i) == RHS) {
3663 SmallVector<const SCEV *, 2> Operands;
3664 append_range(Operands, Mul->operands().take_front(i));
3665 append_range(Operands, Mul->operands().drop_front(i + 1));
3666 return getMulExpr(Operands);
3667 }
3668 }
3669
3670 return getUDivExpr(LHS, RHS);
3671}
3672
3673/// Get an add recurrence expression for the specified loop. Simplify the
3674/// expression as much as possible.
3675const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3676 const Loop *L,
3677 SCEV::NoWrapFlags Flags) {
3678 SmallVector<const SCEV *, 4> Operands;
3679 Operands.push_back(Start);
3680 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3681 if (StepChrec->getLoop() == L) {
3682 append_range(Operands, StepChrec->operands());
3683 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3684 }
3685
3686 Operands.push_back(Step);
3687 return getAddRecExpr(Operands, L, Flags);
3688}
3689
3690/// Get an add recurrence expression for the specified loop. Simplify the
3691/// expression as much as possible.
3692const SCEV *
3694 const Loop *L, SCEV::NoWrapFlags Flags) {
3695 if (Operands.size() == 1) return Operands[0];
3696#ifndef NDEBUG
3697 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3698 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3699 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3700 "SCEVAddRecExpr operand types don't match!");
3701 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3702 }
3703 for (const SCEV *Op : Operands)
3705 "SCEVAddRecExpr operand is not available at loop entry!");
3706#endif
3707
3708 if (Operands.back()->isZero()) {
3709 Operands.pop_back();
3710 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3711 }
3712
3713 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3714 // use that information to infer NUW and NSW flags. However, computing a
3715 // BE count requires calling getAddRecExpr, so we may not yet have a
3716 // meaningful BE count at this point (and if we don't, we'd be stuck
3717 // with a SCEVCouldNotCompute as the cached BE count).
3718
3719 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3720
3721 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3722 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3723 const Loop *NestedLoop = NestedAR->getLoop();
3724 if (L->contains(NestedLoop)
3725 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3726 : (!NestedLoop->contains(L) &&
3727 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3728 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3729 Operands[0] = NestedAR->getStart();
3730 // AddRecs require their operands be loop-invariant with respect to their
3731 // loops. Don't perform this transformation if it would break this
3732 // requirement.
3733 bool AllInvariant = all_of(
3734 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3735
3736 if (AllInvariant) {
3737 // Create a recurrence for the outer loop with the same step size.
3738 //
3739 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3740 // inner recurrence has the same property.
3741 SCEV::NoWrapFlags OuterFlags =
3742 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3743
3744 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3745 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3746 return isLoopInvariant(Op, NestedLoop);
3747 });
3748
3749 if (AllInvariant) {
3750 // Ok, both add recurrences are valid after the transformation.
3751 //
3752 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3753 // the outer recurrence has the same property.
3754 SCEV::NoWrapFlags InnerFlags =
3755 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3756 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3757 }
3758 }
3759 // Reset Operands to its original state.
3760 Operands[0] = NestedAR;
3761 }
3762 }
3763
3764 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3765 // already have one, otherwise create a new one.
3766 return getOrCreateAddRecExpr(Operands, L, Flags);
3767}
3768
3769const SCEV *
3770 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3771 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3772 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3773 // getSCEV(Base)->getType() has the same address space as Base->getType()
3774 // because SCEV::getType() preserves the address space.
3775 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3776 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3777 if (NW != GEPNoWrapFlags::none()) {
3778 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3779 // but to do that, we have to ensure that said flag is valid in the entire
3780 // defined scope of the SCEV.
3781 // TODO: non-instructions have global scope. We might be able to prove
3782 // some global scope cases
3783 auto *GEPI = dyn_cast<Instruction>(GEP);
3784 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3785 NW = GEPNoWrapFlags::none();
3786 }
3787
3788 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3789 if (NW.hasNoUnsignedSignedWrap())
3790 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3791 if (NW.hasNoUnsignedWrap())
3792 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3793
3794 Type *CurTy = GEP->getType();
3795 bool FirstIter = true;
3796 SmallVector<const SCEV *, 4> Offsets;
3797 for (const SCEV *IndexExpr : IndexExprs) {
3798 // Compute the (potentially symbolic) offset in bytes for this index.
3799 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3800 // For a struct, add the member offset.
3801 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3802 unsigned FieldNo = Index->getZExtValue();
3803 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3804 Offsets.push_back(FieldOffset);
3805
3806 // Update CurTy to the type of the field at Index.
3807 CurTy = STy->getTypeAtIndex(Index);
3808 } else {
3809 // Update CurTy to its element type.
3810 if (FirstIter) {
3811 assert(isa<PointerType>(CurTy) &&
3812 "The first index of a GEP indexes a pointer");
3813 CurTy = GEP->getSourceElementType();
3814 FirstIter = false;
3815 } else {
3816 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3817 }
3818 // For an array, add the element offset, explicitly scaled.
3819 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3820 // Getelementptr indices are signed.
3821 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3822
3823 // Multiply the index by the element size to compute the element offset.
3824 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3825 Offsets.push_back(LocalOffset);
3826 }
3827 }
3828
3829 // Handle degenerate case of GEP without offsets.
3830 if (Offsets.empty())
3831 return BaseExpr;
3832
3833 // Add the offsets together, assuming nsw if inbounds.
3834 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3835 // Add the base address and the offset. We cannot use the nsw flag, as the
3836 // base address is unsigned. However, if we know that the offset is
3837 // non-negative, we can use nuw.
3838 bool NUW = NW.hasNoUnsignedWrap() ||
3839 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3840 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3841 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3842 assert(BaseExpr->getType() == GEPExpr->getType() &&
3843 "GEP should not change type mid-flight.");
3844 return GEPExpr;
3845}
3846
3847SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3848 ArrayRef<const SCEV *> Ops) {
3849 FoldingSetNodeID ID;
3850 ID.AddInteger(SCEVType);
3851 for (const SCEV *Op : Ops)
3852 ID.AddPointer(Op);
3853 void *IP = nullptr;
3854 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3855}
3856
3857const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3858 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3859 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3860}
3861
3862 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3863 SmallVectorImpl<const SCEV *> &Ops) {
3864 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3865 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3866 if (Ops.size() == 1) return Ops[0];
3867#ifndef NDEBUG
3868 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3869 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3870 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3871 "Operand types don't match!");
3872 assert(Ops[0]->getType()->isPointerTy() ==
3873 Ops[i]->getType()->isPointerTy() &&
3874 "min/max should be consistently pointerish");
3875 }
3876#endif
3877
3878 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3879 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3880
3881 const SCEV *Folded = constantFoldAndGroupOps(
3882 *this, LI, DT, Ops,
3883 [&](const APInt &C1, const APInt &C2) {
3884 switch (Kind) {
3885 case scSMaxExpr:
3886 return APIntOps::smax(C1, C2);
3887 case scSMinExpr:
3888 return APIntOps::smin(C1, C2);
3889 case scUMaxExpr:
3890 return APIntOps::umax(C1, C2);
3891 case scUMinExpr:
3892 return APIntOps::umin(C1, C2);
3893 default:
3894 llvm_unreachable("Unknown SCEV min/max opcode");
3895 }
3896 },
3897 [&](const APInt &C) {
3898 // identity
3899 if (IsMax)
3900 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3901 else
3902 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3903 },
3904 [&](const APInt &C) {
3905 // absorber
3906 if (IsMax)
3907 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3908 else
3909 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3910 });
3911 if (Folded)
3912 return Folded;
3913
3914 // Check if we have created the same expression before.
3915 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3916 return S;
3917 }
3918
3919 // Find the first operation of the same kind
3920 unsigned Idx = 0;
3921 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3922 ++Idx;
3923
3924 // Check to see if one of the operands is of the same kind. If so, expand its
3925 // operands onto our operand list, and recurse to simplify.
3926 if (Idx < Ops.size()) {
3927 bool DeletedAny = false;
3928 while (Ops[Idx]->getSCEVType() == Kind) {
3929 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3930 Ops.erase(Ops.begin()+Idx);
3931 append_range(Ops, SMME->operands());
3932 DeletedAny = true;
3933 }
3934
3935 if (DeletedAny)
3936 return getMinMaxExpr(Kind, Ops);
3937 }
3938
3939 // Okay, check to see if the same value occurs in the operand list twice. If
3940 // so, delete one. Since we sorted the list, these values are required to
3941 // be adjacent.
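 // For example: smax(%x, %y, %y) --> smax(%x, %y), and smax(5, %x) --> %x
 // when %x is known to be >= 5.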
3942 llvm::CmpInst::Predicate GEPred =
3943 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3944 llvm::CmpInst::Predicate LEPred =
3945 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3946 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3947 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3948 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3949 if (Ops[i] == Ops[i + 1] ||
3950 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3951 // X op Y op Y --> X op Y
3952 // X op Y --> X, if we know X, Y are ordered appropriately
3953 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3954 --i;
3955 --e;
3956 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3957 Ops[i + 1])) {
3958 // X op Y --> Y, if we know X, Y are ordered appropriately
3959 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3960 --i;
3961 --e;
3962 }
3963 }
3964
3965 if (Ops.size() == 1) return Ops[0];
3966
3967 assert(!Ops.empty() && "Reduced smax down to nothing!");
3968
3969 // Okay, it looks like we really DO need an expr. Check to see if we
3970 // already have one, otherwise create a new one.
3971 FoldingSetNodeID ID;
3972 ID.AddInteger(Kind);
3973 for (const SCEV *Op : Ops)
3974 ID.AddPointer(Op);
3975 void *IP = nullptr;
3976 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3977 if (ExistingSCEV)
3978 return ExistingSCEV;
3979 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3980 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3981 SCEV *S = new (SCEVAllocator)
3982 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3983
3984 UniqueSCEVs.InsertNode(S, IP);
3985 registerUser(S, Ops);
3986 return S;
3987}
3988
3989namespace {
3990
3991class SCEVSequentialMinMaxDeduplicatingVisitor final
3992 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3993 std::optional<const SCEV *>> {
3994 using RetVal = std::optional<const SCEV *>;
3995 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3996
3997 ScalarEvolution &SE;
3998 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3999 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4000 SmallPtrSet<const SCEV *, 16> SeenOps;
4001
4002 bool canRecurseInto(SCEVTypes Kind) const {
4003 // We can only recurse into the SCEV expression of the same effective type
4004 // as the type of our root SCEV expression.
4005 return RootKind == Kind || NonSequentialRootKind == Kind;
4006 };
4007
4008 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4010 "Only for min/max expressions.");
4011 SCEVTypes Kind = S->getSCEVType();
4012
4013 if (!canRecurseInto(Kind))
4014 return S;
4015
4016 auto *NAry = cast<SCEVNAryExpr>(S);
4017 SmallVector<const SCEV *> NewOps;
4018 bool Changed = visit(Kind, NAry->operands(), NewOps);
4019
4020 if (!Changed)
4021 return S;
4022 if (NewOps.empty())
4023 return std::nullopt;
4024
4025 return isa<SCEVSequentialMinMaxExpr>(S)
4026 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4027 : SE.getMinMaxExpr(Kind, NewOps);
4028 }
4029
4030 RetVal visit(const SCEV *S) {
4031 // Has the whole operand been seen already?
4032 if (!SeenOps.insert(S).second)
4033 return std::nullopt;
4034 return Base::visit(S);
4035 }
4036
4037public:
4038 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4039 SCEVTypes RootKind)
4040 : SE(SE), RootKind(RootKind),
4041 NonSequentialRootKind(
4042 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4043 RootKind)) {}
4044
4045 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4046 SmallVectorImpl<const SCEV *> &NewOps) {
4047 bool Changed = false;
4048 SmallVector<const SCEV *> Ops;
4049 Ops.reserve(OrigOps.size());
4050
4051 for (const SCEV *Op : OrigOps) {
4052 RetVal NewOp = visit(Op);
4053 if (NewOp != Op)
4054 Changed = true;
4055 if (NewOp)
4056 Ops.emplace_back(*NewOp);
4057 }
4058
4059 if (Changed)
4060 NewOps = std::move(Ops);
4061 return Changed;
4062 }
4063
4064 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4065
4066 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4067
4068 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4069
4070 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4071
4072 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4073
4074 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4075
4076 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4077
4078 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4079
4080 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4081
4082 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4083
4084 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4085 return visitAnyMinMaxExpr(Expr);
4086 }
4087
4088 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4089 return visitAnyMinMaxExpr(Expr);
4090 }
4091
4092 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4093 return visitAnyMinMaxExpr(Expr);
4094 }
4095
4096 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4097 return visitAnyMinMaxExpr(Expr);
4098 }
4099
4100 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4101 return visitAnyMinMaxExpr(Expr);
4102 }
4103
4104 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4105
4106 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4107};
4108
4109} // namespace
4110
4111 static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4112 switch (Kind) {
4113 case scConstant:
4114 case scVScale:
4115 case scTruncate:
4116 case scZeroExtend:
4117 case scSignExtend:
4118 case scPtrToInt:
4119 case scAddExpr:
4120 case scMulExpr:
4121 case scUDivExpr:
4122 case scAddRecExpr:
4123 case scUMaxExpr:
4124 case scSMaxExpr:
4125 case scUMinExpr:
4126 case scSMinExpr:
4127 case scUnknown:
4128 // If any operand is poison, the whole expression is poison.
4129 return true;
4130 case scSequentialUMinExpr:
4131 // FIXME: if the *first* operand is poison, the whole expression is poison.
4132 return false; // Pessimistically, say that it does not propagate poison.
4133 case scCouldNotCompute:
4134 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4135 }
4136 llvm_unreachable("Unknown SCEV kind!");
4137}
4138
4139namespace {
4140// The only way poison may be introduced in a SCEV expression is from a
4141// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4142// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4143// introduce poison -- they encode guaranteed, non-speculated knowledge.
4144//
4145// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4146// with the notable exception of umin_seq, where only poison from the first
4147// operand is (unconditionally) propagated.
4148struct SCEVPoisonCollector {
4149 bool LookThroughMaybePoisonBlocking;
4150 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4151 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4152 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4153
4154 bool follow(const SCEV *S) {
4155 if (!LookThroughMaybePoisonBlocking &&
4156 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4157 return false;
4158
4159 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4160 if (!isGuaranteedNotToBePoison(SU->getValue()))
4161 MaybePoison.insert(SU);
4162 }
4163 return true;
4164 }
4165 bool isDone() const { return false; }
4166};
4167} // namespace
4168
4169/// Return true if V is poison given that AssumedPoison is already poison.
4170static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4171 // First collect all SCEVs that might result in AssumedPoison to be poison.
4172 // We need to look through potentially poison-blocking operations here,
4173 // because we want to find all SCEVs that *might* result in poison, not only
4174 // those that are *required* to.
4175 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4176 visitAll(AssumedPoison, PC1);
4177
4178 // AssumedPoison is never poison. As the assumption is false, the implication
4179 // is true. Don't bother walking the other SCEV in this case.
4180 if (PC1.MaybePoison.empty())
4181 return true;
4182
4183 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4184 // as well. We cannot look through potentially poison-blocking operations
4185 // here, as their arguments only *may* make the result poison.
4186 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4187 visitAll(S, PC2);
4188
4189 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4190 // it will also make S poison by being part of PC2.MaybePoison.
4191 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4192}
4193
4194 void ScalarEvolution::getPoisonGeneratingValues(
4195 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4196 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4197 visitAll(S, PC);
4198 for (const SCEVUnknown *SU : PC.MaybePoison)
4199 Result.insert(SU->getValue());
4200}
4201
4202 bool ScalarEvolution::canReuseInstruction(
4203 const SCEV *S, Instruction *I,
4204 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4205 // If the instruction cannot be poison, it's always safe to reuse.
4207 return true;
4208
4209 // Otherwise, it is possible that I is more poisonous than S. Collect the
4210 // poison-contributors of S, and then check whether I has any additional
4211 // poison-contributors. Poison that is contributed through poison-generating
4212 // flags is handled by dropping those flags instead.
4213 SmallPtrSet<const Value *, 8> PoisonVals;
4214 getPoisonGeneratingValues(PoisonVals, S);
4215
4216 SmallVector<Value *> Worklist;
4217 SmallPtrSet<Value *, 8> Visited;
4218 Worklist.push_back(I);
4219 while (!Worklist.empty()) {
4220 Value *V = Worklist.pop_back_val();
4221 if (!Visited.insert(V).second)
4222 continue;
4223
4224 // Avoid walking large instruction graphs.
4225 if (Visited.size() > 16)
4226 return false;
4227
4228 // Either the value can't be poison, or the S would also be poison if it
4229 // is.
4230 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4231 continue;
4232
4233 auto *I = dyn_cast<Instruction>(V);
4234 if (!I)
4235 return false;
4236
4237 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4238 // can't replace an arbitrary add with disjoint or, even if we drop the
4239 // flag. We would need to convert the or into an add.
4240 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4241 if (PDI->isDisjoint())
4242 return false;
4243
4244 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4245 // because SCEV currently assumes it can't be poison. Remove this special
4246 // case once we properly model when vscale can be poison.
4247 if (auto *II = dyn_cast<IntrinsicInst>(I);
4248 II && II->getIntrinsicID() == Intrinsic::vscale)
4249 continue;
4250
4251 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4252 return false;
4253
4254 // If the instruction can't create poison, we can recurse to its operands.
4255 if (I->hasPoisonGeneratingAnnotations())
4256 DropPoisonGeneratingInsts.push_back(I);
4257
4258 llvm::append_range(Worklist, I->operands());
4259 }
4260 return true;
4261}
4262
4263const SCEV *
4264 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4265 SmallVectorImpl<const SCEV *> &Ops) {
4266 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4267 "Not a SCEVSequentialMinMaxExpr!");
4268 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4269 if (Ops.size() == 1)
4270 return Ops[0];
4271#ifndef NDEBUG
4272 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4273 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4274 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4275 "Operand types don't match!");
4276 assert(Ops[0]->getType()->isPointerTy() ==
4277 Ops[i]->getType()->isPointerTy() &&
4278 "min/max should be consistently pointerish");
4279 }
4280#endif
4281
4282 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4283 // so we can *NOT* do any kind of sorting of the expressions!
4284
4285 // Check if we have created the same expression before.
4286 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4287 return S;
4288
4289 // FIXME: there are *some* simplifications that we can do here.
4290
4291 // Keep only the first instance of an operand.
4292 {
4293 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4294 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4295 if (Changed)
4296 return getSequentialMinMaxExpr(Kind, Ops);
4297 }
4298
4299 // Check to see if one of the operands is of the same kind. If so, expand its
4300 // operands onto our operand list, and recurse to simplify.
4301 {
4302 unsigned Idx = 0;
4303 bool DeletedAny = false;
4304 while (Idx < Ops.size()) {
4305 if (Ops[Idx]->getSCEVType() != Kind) {
4306 ++Idx;
4307 continue;
4308 }
4309 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4310 Ops.erase(Ops.begin() + Idx);
4311 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4312 SMME->operands().end());
4313 DeletedAny = true;
4314 }
4315
4316 if (DeletedAny)
4317 return getSequentialMinMaxExpr(Kind, Ops);
4318 }
4319
4320 const SCEV *SaturationPoint;
4321 ICmpInst::Predicate Pred;
4322 switch (Kind) {
4323 case scSequentialUMinExpr:
4324 SaturationPoint = getZero(Ops[0]->getType());
4325 Pred = ICmpInst::ICMP_ULE;
4326 break;
4327 default:
4328 llvm_unreachable("Not a sequential min/max type.");
4329 }
4330
4331 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4332 if (!isGuaranteedNotToCauseUB(Ops[i]))
4333 continue;
4334 // We can replace %x umin_seq %y with %x umin %y if either:
4335 // * %y being poison implies %x is also poison.
4336 // * %x cannot be the saturating value (e.g. zero for umin).
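 // For example: %x umin_seq %y --> %x umin %y when %x is known to be non-zero.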
4337 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4338 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4339 SaturationPoint)) {
4340 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4341 Ops[i - 1] = getMinMaxExpr(
4342 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4343 SeqOps);
4344 Ops.erase(Ops.begin() + i);
4345 return getSequentialMinMaxExpr(Kind, Ops);
4346 }
4347 // Fold %x umin_seq %y to %x if %x ule %y.
4348 // TODO: We might be able to prove the predicate for a later operand.
4349 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4350 Ops.erase(Ops.begin() + i);
4351 return getSequentialMinMaxExpr(Kind, Ops);
4352 }
4353 }
4354
4355 // Okay, it looks like we really DO need an expr. Check to see if we
4356 // already have one, otherwise create a new one.
4357 FoldingSetNodeID ID;
4358 ID.AddInteger(Kind);
4359 for (const SCEV *Op : Ops)
4360 ID.AddPointer(Op);
4361 void *IP = nullptr;
4362 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4363 if (ExistingSCEV)
4364 return ExistingSCEV;
4365
4366 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4367 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4368 SCEV *S = new (SCEVAllocator)
4369 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4370
4371 UniqueSCEVs.InsertNode(S, IP);
4372 registerUser(S, Ops);
4373 return S;
4374}
4375
4376const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4377 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4378 return getSMaxExpr(Ops);
4379}
4380
4381 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4382 return getMinMaxExpr(scSMaxExpr, Ops);
4383 }
4384
4385const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4386 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4387 return getUMaxExpr(Ops);
4388}
4389
4390 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4391 return getMinMaxExpr(scUMaxExpr, Ops);
4392 }
4393
4395 const SCEV *RHS) {
4396 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4397 return getSMinExpr(Ops);
4398}
4399
4400 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4401 return getMinMaxExpr(scSMinExpr, Ops);
4402 }
4403
4404const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4405 bool Sequential) {
4406 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4407 return getUMinExpr(Ops, Sequential);
4408}
4409
4410 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4411 bool Sequential) {
4412 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4413 : getMinMaxExpr(scUMinExpr, Ops);
4414 }
4415
4416const SCEV *
4417 ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4418 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4419 if (Size.isScalable())
4420 Res = getMulExpr(Res, getVScale(IntTy));
4421 return Res;
4422}
4423
4425 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4426}
4427
4429 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4430}
4431
4433 StructType *STy,
4434 unsigned FieldNo) {
4435 // We can bypass creating a target-independent constant expression and then
4436 // folding it back into a ConstantInt. This is just a compile-time
4437 // optimization.
4438 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4439 assert(!SL->getSizeInBits().isScalable() &&
4440 "Cannot get offset for structure containing scalable vector types");
4441 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4442}
4443
4445 // Don't attempt to do anything other than create a SCEVUnknown object
4446 // here. createSCEV only calls getUnknown after checking for all other
4447 // interesting possibilities, and any other code that calls getUnknown
4448 // is doing so in order to hide a value from SCEV canonicalization.
4449
4451 ID.AddInteger(scUnknown);
4452 ID.AddPointer(V);
4453 void *IP = nullptr;
4454 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4455 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4456 "Stale SCEVUnknown in uniquing map!");
4457 return S;
4458 }
4459 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4460 FirstUnknown);
4461 FirstUnknown = cast<SCEVUnknown>(S);
4462 UniqueSCEVs.InsertNode(S, IP);
4463 return S;
4464}
4465
4466//===----------------------------------------------------------------------===//
4467// Basic SCEV Analysis and PHI Idiom Recognition Code
4468//
4469
4470/// Test if values of the given type are analyzable within the SCEV
4471/// framework. This primarily includes integer types, and it can optionally
4472/// include pointer types if the ScalarEvolution class has access to
4473/// target-specific information.
4475 // Integers and pointers are always SCEVable.
4476 return Ty->isIntOrPtrTy();
4477}
4478
4479/// Return the size in bits of the specified type, for which isSCEVable must
4480/// return true.
4482 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4483 if (Ty->isPointerTy())
4485 return getDataLayout().getTypeSizeInBits(Ty);
4486}
4487
4488/// Return a type with the same bitwidth as the given type and which represents
4489/// how SCEV will treat the given type, for which isSCEVable must return
4490/// true. For pointer types, this is the pointer index sized integer type.
4492 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4493
4494 if (Ty->isIntegerTy())
4495 return Ty;
4496
4497 // The only other supported type is pointer.
4498 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4499 return getDataLayout().getIndexType(Ty);
4500}
4501
4503 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4504}
4505
4507 const SCEV *B) {
4508 /// For a valid use point to exist, the defining scope of one operand
4509 /// must dominate the other.
4510 bool PreciseA, PreciseB;
4511 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4512 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4513 if (!PreciseA || !PreciseB)
4514 // Can't tell.
4515 return false;
4516 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4517 DT.dominates(ScopeB, ScopeA);
4518}
4519
4521 return CouldNotCompute.get();
4522}
4523
4524bool ScalarEvolution::checkValidity(const SCEV *S) const {
4525 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4526 auto *SU = dyn_cast<SCEVUnknown>(S);
4527 return SU && SU->getValue() == nullptr;
4528 });
4529
4530 return !ContainsNulls;
4531}
4532
4534 HasRecMapType::iterator I = HasRecMap.find(S);
4535 if (I != HasRecMap.end())
4536 return I->second;
4537
4538 bool FoundAddRec =
4539 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4540 HasRecMap.insert({S, FoundAddRec});
4541 return FoundAddRec;
4542}
4543
4544/// Return the ValueOffsetPair set for \p S. \p S can be represented
4545/// by the value and offset from any ValueOffsetPair in the set.
4546ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4547 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4548 if (SI == ExprValueMap.end())
4549 return {};
4550 return SI->second.getArrayRef();
4551}
4552
4553/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4554/// cannot be used separately. eraseValueFromMap should be used to remove
4555/// V from ValueExprMap and ExprValueMap at the same time.
4556void ScalarEvolution::eraseValueFromMap(Value *V) {
4557 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4558 if (I != ValueExprMap.end()) {
4559 auto EVIt = ExprValueMap.find(I->second);
4560 bool Removed = EVIt->second.remove(V);
4561 (void) Removed;
4562 assert(Removed && "Value not in ExprValueMap?");
4563 ValueExprMap.erase(I);
4564 }
4565}
4566
4567void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4568 // A recursive query may have already computed the SCEV. It should be
4569 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4570 // inferred nowrap flags.
4571 auto It = ValueExprMap.find_as(V);
4572 if (It == ValueExprMap.end()) {
4573 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4574 ExprValueMap[S].insert(V);
4575 }
4576}
4577
4578/// Return an existing SCEV if it exists, otherwise analyze the expression and
4579/// create a new one.
4581 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4582
4583 if (const SCEV *S = getExistingSCEV(V))
4584 return S;
4585 return createSCEVIter(V);
4586}
4587
4589 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4590
4591 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4592 if (I != ValueExprMap.end()) {
4593 const SCEV *S = I->second;
4594 assert(checkValidity(S) &&
4595 "existing SCEV has not been properly invalidated");
4596 return S;
4597 }
4598 return nullptr;
4599}
4600
4601/// Return a SCEV corresponding to -V = -1*V
4603 SCEV::NoWrapFlags Flags) {
4604 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4605 return getConstant(
4606 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4607
4608 Type *Ty = V->getType();
4609 Ty = getEffectiveSCEVType(Ty);
4610 return getMulExpr(V, getMinusOne(Ty), Flags);
4611}
4612
4613/// If Expr computes ~A, return A else return nullptr
4614static const SCEV *MatchNotExpr(const SCEV *Expr) {
4615 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4616 if (!Add || Add->getNumOperands() != 2 ||
4617 !Add->getOperand(0)->isAllOnesValue())
4618 return nullptr;
4619
4620 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4621 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4622 !AddRHS->getOperand(0)->isAllOnesValue())
4623 return nullptr;
4624
4625 return AddRHS->getOperand(1);
4626}
4627
4628/// Return a SCEV corresponding to ~V = -1-V
4630 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4631
4632 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4633 return getConstant(
4634 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4635
4636 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4637 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4638 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4639 SmallVector<const SCEV *, 2> MatchedOperands;
4640 for (const SCEV *Operand : MME->operands()) {
4641 const SCEV *Matched = MatchNotExpr(Operand);
4642 if (!Matched)
4643 return (const SCEV *)nullptr;
4644 MatchedOperands.push_back(Matched);
4645 }
4646 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4647 MatchedOperands);
4648 };
4649 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4650 return Replaced;
4651 }
4652
4653 Type *Ty = V->getType();
4654 Ty = getEffectiveSCEVType(Ty);
4655 return getMinusSCEV(getMinusOne(Ty), V);
4656}
4657
4659 assert(P->getType()->isPointerTy());
4660
4661 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4662 // The base of an AddRec is the first operand.
4663 SmallVector<const SCEV *> Ops{AddRec->operands()};
4664 Ops[0] = removePointerBase(Ops[0]);
4665 // Don't try to transfer nowrap flags for now. We could in some cases
4666 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4667 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4668 }
4669 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4670 // The base of an Add is the pointer operand.
4671 SmallVector<const SCEV *> Ops{Add->operands()};
4672 const SCEV **PtrOp = nullptr;
4673 for (const SCEV *&AddOp : Ops) {
4674 if (AddOp->getType()->isPointerTy()) {
4675 assert(!PtrOp && "Cannot have multiple pointer ops");
4676 PtrOp = &AddOp;
4677 }
4678 }
4679 *PtrOp = removePointerBase(*PtrOp);
4680 // Don't try to transfer nowrap flags for now. We could in some cases
4681 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4682 return getAddExpr(Ops);
4683 }
4684 // Any other expression must be a pointer base.
4685 return getZero(P->getType());
4686}
4687
4688const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4689 SCEV::NoWrapFlags Flags,
4690 unsigned Depth) {
4691 // Fast path: X - X --> 0.
4692 if (LHS == RHS)
4693 return getZero(LHS->getType());
4694
4695 // If we subtract two pointers with different pointer bases, bail.
4696 // Eventually, we're going to add an assertion to getMulExpr that we
4697 // can't multiply by a pointer.
4698 if (RHS->getType()->isPointerTy()) {
4699 if (!LHS->getType()->isPointerTy() ||
4700 getPointerBase(LHS) != getPointerBase(RHS))
4701 return getCouldNotCompute();
4702 LHS = removePointerBase(LHS);
4703 RHS = removePointerBase(RHS);
4704 }
4705
4706 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4707 // makes it so that we cannot make much use of NUW.
4708 auto AddFlags = SCEV::FlagAnyWrap;
4709 const bool RHSIsNotMinSigned =
4710 !getSignedRangeMin(RHS).isMinSignedValue();
4711 if (hasFlags(Flags, SCEV::FlagNSW)) {
4712 // Let M be the minimum representable signed value. Then (-1)*RHS
4713 // signed-wraps if and only if RHS is M. That can happen even for
4714 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4715 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4716 // (-1)*RHS, we need to prove that RHS != M.
4717 //
4718 // If LHS is non-negative and we know that LHS - RHS does not
4719 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4720 // either by proving that RHS > M or that LHS >= 0.
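 // For example, in i8: if RHS could be -128, then (-1) * RHS itself
 // signed-wraps, so NSW cannot be transferred to the rewritten form.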
4721 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4722 AddFlags = SCEV::FlagNSW;
4723 }
4724 }
4725
4726 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4727 // RHS is NSW and LHS >= 0.
4728 //
4729 // The difficulty here is that the NSW flag may have been proven
4730 // relative to a loop that is to be found in a recurrence in LHS and
4731 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4732 // larger scope than intended.
4733 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4734
4735 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4736}
4737
4739 unsigned Depth) {
4740 Type *SrcTy = V->getType();
4741 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4742 "Cannot truncate or zero extend with non-integer arguments!");
4743 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4744 return V; // No conversion
4745 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4746 return getTruncateExpr(V, Ty, Depth);
4747 return getZeroExtendExpr(V, Ty, Depth);
4748}
4749
4751 unsigned Depth) {
4752 Type *SrcTy = V->getType();
4753 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4754 "Cannot truncate or zero extend with non-integer arguments!");
4755 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4756 return V; // No conversion
4757 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4758 return getTruncateExpr(V, Ty, Depth);
4759 return getSignExtendExpr(V, Ty, Depth);
4760}
4761
4762const SCEV *
4764 Type *SrcTy = V->getType();
4765 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4766 "Cannot noop or zero extend with non-integer arguments!");
4768 "getNoopOrZeroExtend cannot truncate!");
4769 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4770 return V; // No conversion
4771 return getZeroExtendExpr(V, Ty);
4772}
4773
4774const SCEV *
4776 Type *SrcTy = V->getType();
4777 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4778 "Cannot noop or sign extend with non-integer arguments!");
4780 "getNoopOrSignExtend cannot truncate!");
4781 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4782 return V; // No conversion
4783 return getSignExtendExpr(V, Ty);
4784}
4785
4786const SCEV *
4788 Type *SrcTy = V->getType();
4789 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4790 "Cannot noop or any extend with non-integer arguments!");
4792 "getNoopOrAnyExtend cannot truncate!");
4793 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4794 return V; // No conversion
4795 return getAnyExtendExpr(V, Ty);
4796}
4797
4798const SCEV *
4800 Type *SrcTy = V->getType();
4801 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4802 "Cannot truncate or noop with non-integer arguments!");
4804 "getTruncateOrNoop cannot extend!");
4805 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4806 return V; // No conversion
4807 return getTruncateExpr(V, Ty);
4808}
4809
4811 const SCEV *RHS) {
4812 const SCEV *PromotedLHS = LHS;
4813 const SCEV *PromotedRHS = RHS;
4814
4815 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4816 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4817 else
4818 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4819
4820 return getUMaxExpr(PromotedLHS, PromotedRHS);
4821}
4822
4824 const SCEV *RHS,
4825 bool Sequential) {
4826 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4827 return getUMinFromMismatchedTypes(Ops, Sequential);
4828}
4829
4830const SCEV *
4832 bool Sequential) {
4833 assert(!Ops.empty() && "At least one operand must be!");
4834 // Trivial case.
4835 if (Ops.size() == 1)
4836 return Ops[0];
4837
4838 // Find the max type first.
4839 Type *MaxType = nullptr;
4840 for (const auto *S : Ops)
4841 if (MaxType)
4842 MaxType = getWiderType(MaxType, S->getType());
4843 else
4844 MaxType = S->getType();
4845 assert(MaxType && "Failed to find maximum type!");
4846
4847 // Extend all ops to max type.
4848 SmallVector<const SCEV *, 2> PromotedOps;
4849 for (const auto *S : Ops)
4850 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4851
4852 // Generate umin.
4853 return getUMinExpr(PromotedOps, Sequential);
4854}
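// Illustrative behaviour (a sketch with assumed operands): given %a of type
// i32 and %b of type i64, getUMinFromMismatchedTypes({%a, %b}) determines the
// widest type (i64), promotes the narrower operand with getNoopOrZeroExtend,
// and returns (umin (zext i32 %a to i64), %b). Zero-extension is the right
// promotion here because umin/umax compare their operands as unsigned values.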
4855
4856const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4857 // A pointer operand may evaluate to a nonpointer expression, such as null.
4858 if (!V->getType()->isPointerTy())
4859 return V;
4860
4861 while (true) {
4862 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4863 V = AddRec->getStart();
4864 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4865 const SCEV *PtrOp = nullptr;
4866 for (const SCEV *AddOp : Add->operands()) {
4867 if (AddOp->getType()->isPointerTy()) {
4868 assert(!PtrOp && "Cannot have multiple pointer ops");
4869 PtrOp = AddOp;
4870 }
4871 }
4872 assert(PtrOp && "Must have pointer op");
4873 V = PtrOp;
4874 } else // Not something we can look further into.
4875 return V;
4876 }
4877}
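// Example (a sketch with an assumed SCEV): for V = {(8 + %base),+,4}<%loop>,
// the loop above first replaces the AddRec by its start (8 + %base) and then
// selects the unique pointer operand of the add, so getPointerBase returns
// %base. Non-pointer inputs are returned unchanged by the early exit.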
4878
4879/// Push users of the given Instruction onto the given Worklist.
4880static void PushDefUseChildren(Instruction *I,
4881 SmallVectorImpl<Instruction *> &Worklist,
4882 SmallPtrSetImpl<Instruction *> &Visited) {
4883 // Push the def-use children onto the Worklist stack.
4884 for (User *U : I->users()) {
4885 auto *UserInsn = cast<Instruction>(U);
4886 if (Visited.insert(UserInsn).second)
4887 Worklist.push_back(UserInsn);
4888 }
4889}
4890
4891namespace {
4892
4893/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4894/// expression if its loop is L. If its loop is not L, use the AddRec itself
4895/// when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4896/// The rewrite also cannot be done if the SCEV contains a SCEVUnknown that
4897/// is not invariant in L.
4898class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4899public:
4900 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4901 bool IgnoreOtherLoops = true) {
4902 SCEVInitRewriter Rewriter(L, SE);
4903 const SCEV *Result = Rewriter.visit(S);
4904 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4905 return SE.getCouldNotCompute();
4906 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4907 ? SE.getCouldNotCompute()
4908 : Result;
4909 }
4910
4911 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4912 if (!SE.isLoopInvariant(Expr, L))
4913 SeenLoopVariantSCEVUnknown = true;
4914 return Expr;
4915 }
4916
4917 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4918 // Only re-write AddRecExprs for this loop.
4919 if (Expr->getLoop() == L)
4920 return Expr->getStart();
4921 SeenOtherLoops = true;
4922 return Expr;
4923 }
4924
4925 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4926
4927 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4928
4929private:
4930 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4931 : SCEVRewriteVisitor(SE), L(L) {}
4932
4933 const Loop *L;
4934 bool SeenLoopVariantSCEVUnknown = false;
4935 bool SeenOtherLoops = false;
4936};
4937
4938/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its
4939/// post-increment expression if its loop is L; otherwise use the AddRec
4940/// itself.
4941/// The rewrite cannot be done if the SCEV contains a loop-variant SCEVUnknown.
4942class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4943public:
4944 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4945 SCEVPostIncRewriter Rewriter(L, SE);
4946 const SCEV *Result = Rewriter.visit(S);
4947 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4948 ? SE.getCouldNotCompute()
4949 : Result;
4950 }
4951
4952 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4953 if (!SE.isLoopInvariant(Expr, L))
4954 SeenLoopVariantSCEVUnknown = true;
4955 return Expr;
4956 }
4957
4958 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4959 // Only re-write AddRecExprs for this loop.
4960 if (Expr->getLoop() == L)
4961 return Expr->getPostIncExpr(SE);
4962 SeenOtherLoops = true;
4963 return Expr;
4964 }
4965
4966 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4967
4968 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4969
4970private:
4971 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4972 : SCEVRewriteVisitor(SE), L(L) {}
4973
4974 const Loop *L;
4975 bool SeenLoopVariantSCEVUnknown = false;
4976 bool SeenOtherLoops = false;
4977};
4978
4979/// This class evaluates the compare condition by matching it against the
4980/// condition of loop latch. If there is a match we assume a true value
4981/// for the condition while building SCEV nodes.
4982class SCEVBackedgeConditionFolder
4983 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4984public:
4985 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4986 ScalarEvolution &SE) {
4987 bool IsPosBECond = false;
4988 Value *BECond = nullptr;
4989 if (BasicBlock *Latch = L->getLoopLatch()) {
4990 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4991 if (BI && BI->isConditional()) {
4992 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4993 "Both outgoing branches should not target same header!");
4994 BECond = BI->getCondition();
4995 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4996 } else {
4997 return S;
4998 }
4999 }
5000 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5001 return Rewriter.visit(S);
5002 }
5003
5004 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5005 const SCEV *Result = Expr;
5006 bool InvariantF = SE.isLoopInvariant(Expr, L);
5007
5008 if (!InvariantF) {
5009 Instruction *I = cast<Instruction>(Expr->getValue());
5010 switch (I->getOpcode()) {
5011 case Instruction::Select: {
5012 SelectInst *SI = cast<SelectInst>(I);
5013 std::optional<const SCEV *> Res =
5014 compareWithBackedgeCondition(SI->getCondition());
5015 if (Res) {
5016 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5017 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5018 }
5019 break;
5020 }
5021 default: {
5022 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5023 if (Res)
5024 Result = *Res;
5025 break;
5026 }
5027 }
5028 }
5029 return Result;
5030 }
5031
5032private:
5033 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5034 bool IsPosBECond, ScalarEvolution &SE)
5035 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5036 IsPositiveBECond(IsPosBECond) {}
5037
5038 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5039
5040 const Loop *L;
5041 /// Loop back condition.
5042 Value *BackedgeCond = nullptr;
5043 /// Set to true if loop back is on positive branch condition.
5044 bool IsPositiveBECond;
5045};
5046
5047std::optional<const SCEV *>
5048SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5049
5050 // If value matches the backedge condition for loop latch,
5051 // then return a constant evolution node based on loopback
5052 // branch taken.
5053 if (BackedgeCond == IC)
5054 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5055 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5056 return std::nullopt;
5057}
5058
5059class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5060public:
5061 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5062 ScalarEvolution &SE) {
5063 SCEVShiftRewriter Rewriter(L, SE);
5064 const SCEV *Result = Rewriter.visit(S);
5065 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5066 }
5067
5068 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5069 // Only allow AddRecExprs for this loop.
5070 if (!SE.isLoopInvariant(Expr, L))
5071 Valid = false;
5072 return Expr;
5073 }
5074
5075 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5076 if (Expr->getLoop() == L && Expr->isAffine())
5077 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5078 Valid = false;
5079 return Expr;
5080 }
5081
5082 bool isValid() { return Valid; }
5083
5084private:
5085 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5086 : SCEVRewriteVisitor(SE), L(L) {}
5087
5088 const Loop *L;
5089 bool Valid = true;
5090};
5091
5092} // end anonymous namespace
5093
5094SCEV::NoWrapFlags
5095ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5096 if (!AR->isAffine())
5097 return SCEV::FlagAnyWrap;
5098
5099 using OBO = OverflowingBinaryOperator;
5100
5101 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5102
5103 if (!AR->hasNoSelfWrap()) {
5104 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5105 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5106 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5107 const APInt &BECountAP = BECountMax->getAPInt();
5108 unsigned NoOverflowBitWidth =
5109 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5110 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5111 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5112 }
5113 }
5114
5115 if (!AR->hasNoSignedWrap()) {
5116 ConstantRange AddRecRange = getSignedRange(AR);
5117 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5118
5119 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5120 Instruction::Add, IncRange, OBO::NoSignedWrap);
5121 if (NSWRegion.contains(AddRecRange))
5122 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5123 }
5124
5125 if (!AR->hasNoUnsignedWrap()) {
5126 ConstantRange AddRecRange = getUnsignedRange(AR);
5127 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5128
5129 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5130 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5131 if (NUWRegion.contains(AddRecRange))
5132 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5133 }
5134
5135 return Result;
5136}
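// Worked example for the self-wrap check above (informal, with assumed
// numbers): for {0,+,3}<i8> in a loop whose constant max backedge-taken count
// is 20, BECountAP.getActiveBits() == 5 and StepCR.getMinSignedBits() == 3,
// so NoOverflowBitWidth == 8 <= 8: the total change over all iterations stays
// below 2^7, the recurrence cannot revisit its start value, and FlagNW is set.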
5137
5138SCEV::NoWrapFlags
5139ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5140 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5141
5142 if (AR->hasNoSignedWrap())
5143 return Result;
5144
5145 if (!AR->isAffine())
5146 return Result;
5147
5148 // This function can be expensive, only try to prove NSW once per AddRec.
5149 if (!SignedWrapViaInductionTried.insert(AR).second)
5150 return Result;
5151
5152 const SCEV *Step = AR->getStepRecurrence(*this);
5153 const Loop *L = AR->getLoop();
5154
5155 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5156 // Note that this serves two purposes: It filters out loops that are
5157 // simply not analyzable, and it covers the case where this code is
5158 // being called from within backedge-taken count analysis, such that
5159 // attempting to ask for the backedge-taken count would likely result
5160 // in infinite recursion. In the latter case, the analysis code will
5161 // cope with a conservative value, and it will take care to purge
5162 // that value once it has finished.
5163 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5164
5165 // Normally, in the cases we can prove no-overflow via a
5166 // backedge guarding condition, we can also compute a backedge
5167 // taken count for the loop. The exceptions are assumptions and
5168 // guards present in the loop -- SCEV is not great at exploiting
5169 // these to compute max backedge taken counts, but can still use
5170 // these to prove lack of overflow. Use this fact to avoid
5171 // doing extra work that may not pay off.
5172
5173 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5174 AC.assumptions().empty())
5175 return Result;
5176
5177 // If the backedge is guarded by a comparison with the pre-inc value the
5178 // addrec is safe. Also, if the entry is guarded by a comparison with the
5179 // start value and the backedge is guarded by a comparison with the post-inc
5180 // value, the addrec is safe.
5181 ICmpInst::Predicate Pred;
5182 const SCEV *OverflowLimit =
5183 getSignedOverflowLimitForStep(Step, &Pred, this);
5184 if (OverflowLimit &&
5185 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5186 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5187 Result = setFlags(Result, SCEV::FlagNSW);
5188 }
5189 return Result;
5190}
5191SCEV::NoWrapFlags
5192ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5193 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5194
5195 if (AR->hasNoUnsignedWrap())
5196 return Result;
5197
5198 if (!AR->isAffine())
5199 return Result;
5200
5201 // This function can be expensive, only try to prove NUW once per AddRec.
5202 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5203 return Result;
5204
5205 const SCEV *Step = AR->getStepRecurrence(*this);
5206 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5207 const Loop *L = AR->getLoop();
5208
5209 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5210 // Note that this serves two purposes: It filters out loops that are
5211 // simply not analyzable, and it covers the case where this code is
5212 // being called from within backedge-taken count analysis, such that
5213 // attempting to ask for the backedge-taken count would likely result
5214 // in infinite recursion. In the latter case, the analysis code will
5215 // cope with a conservative value, and it will take care to purge
5216 // that value once it has finished.
5217 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5218
5219 // Normally, in the cases we can prove no-overflow via a
5220 // backedge guarding condition, we can also compute a backedge
5221 // taken count for the loop. The exceptions are assumptions and
5222 // guards present in the loop -- SCEV is not great at exploiting
5223 // these to compute max backedge taken counts, but can still use
5224 // these to prove lack of overflow. Use this fact to avoid
5225 // doing extra work that may not pay off.
5226
5227 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5228 AC.assumptions().empty())
5229 return Result;
5230
5231 // If the backedge is guarded by a comparison with the pre-inc value the
5232 // addrec is safe. Also, if the entry is guarded by a comparison with the
5233 // start value and the backedge is guarded by a comparison with the post-inc
5234 // value, the addrec is safe.
5235 if (isKnownPositive(Step)) {
5236 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5237 getUnsignedRangeMax(Step));
5238 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5239 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5240 Result = setFlags(Result, SCEV::FlagNUW);
5241 }
5242 }
5243
5244 return Result;
5245}
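// Example of the guard used above (a sketch with assumed ranges): for an i8
// AddRec whose step is known to lie in [1, 4], N = 0 - 4 = 252 (unsigned).
// If the backedge guarantees AR <u 252 on every iteration, adding a step of
// at most 4 cannot wrap past 255, so FlagNUW can be added to the result.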
5246
5247namespace {
5248
5249/// Represents an abstract binary operation. This may exist as a
5250/// normal instruction or constant expression, or may have been
5251/// derived from an expression tree.
5252struct BinaryOp {
5253 unsigned Opcode;
5254 Value *LHS;
5255 Value *RHS;
5256 bool IsNSW = false;
5257 bool IsNUW = false;
5258
5259 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5260 /// constant expression.
5261 Operator *Op = nullptr;
5262
5263 explicit BinaryOp(Operator *Op)
5264 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5265 Op(Op) {
5266 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5267 IsNSW = OBO->hasNoSignedWrap();
5268 IsNUW = OBO->hasNoUnsignedWrap();
5269 }
5270 }
5271
5272 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5273 bool IsNUW = false)
5274 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5275};
5276
5277} // end anonymous namespace
5278
5279/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5280static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5281 AssumptionCache &AC,
5282 const DominatorTree &DT,
5283 const Instruction *CxtI) {
5284 auto *Op = dyn_cast<Operator>(V);
5285 if (!Op)
5286 return std::nullopt;
5287
5288 // Implementation detail: all the cleverness here should happen without
5289 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5290 // SCEV expressions when possible, and we should not break that.
5291
5292 switch (Op->getOpcode()) {
5293 case Instruction::Add:
5294 case Instruction::Sub:
5295 case Instruction::Mul:
5296 case Instruction::UDiv:
5297 case Instruction::URem:
5298 case Instruction::And:
5299 case Instruction::AShr:
5300 case Instruction::Shl:
5301 return BinaryOp(Op);
5302
5303 case Instruction::Or: {
5304 // Convert or disjoint into add nuw nsw.
5305 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5306 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5307 /*IsNSW=*/true, /*IsNUW=*/true);
5308 return BinaryOp(Op);
5309 }
5310
5311 case Instruction::Xor:
5312 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5313 // If the RHS of the xor is a signmask, then this is just an add.
5314 // Instcombine turns add of signmask into xor as a strength reduction step.
5315 if (RHSC->getValue().isSignMask())
5316 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5317 // Binary `xor` is a bit-wise `add`.
5318 if (V->getType()->isIntegerTy(1))
5319 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5320 return BinaryOp(Op);
5321
5322 case Instruction::LShr:
5323 // Turn a logical shift right by a constant into an unsigned divide.
5324 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5325 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5326
5327 // If the shift count is not less than the bitwidth, the result of
5328 // the shift is undefined. Don't try to analyze it, because the
5329 // resolution chosen here may differ from the resolution chosen in
5330 // other parts of the compiler.
5331 if (SA->getValue().ult(BitWidth)) {
5332 Constant *X =
5333 ConstantInt::get(SA->getContext(),
5334 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5335 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5336 }
5337 }
5338 return BinaryOp(Op);
5339
5340 case Instruction::ExtractValue: {
5341 auto *EVI = cast<ExtractValueInst>(Op);
5342 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5343 break;
5344
5345 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5346 if (!WO)
5347 break;
5348
5349 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5350 bool Signed = WO->isSigned();
5351 // TODO: Should add nuw/nsw flags for mul as well.
5352 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5353 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5354
5355 // Now that we know that all uses of the arithmetic-result component of
5356 // CI are guarded by the overflow check, we can go ahead and pretend
5357 // that the arithmetic is non-overflowing.
5358 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5359 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5360 }
5361
5362 default:
5363 break;
5364 }
5365
5366 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5367 // semantics as a Sub, return a binary sub expression.
5368 if (auto *II = dyn_cast<IntrinsicInst>(V))
5369 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5370 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5371
5372 return std::nullopt;
5373}
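// A few illustrative mappings produced by MatchBinaryOp (sketch, assumed i32
// instructions):
//
//   %a = or disjoint i32 %x, %y   --> BinaryOp(Add, %x, %y, NSW, NUW)
//   %b = xor i32 %x, -2147483648  --> BinaryOp(Add, %x, -2147483648)
//   %c = lshr i32 %x, 3           --> BinaryOp(UDiv, %x, 8)
//   %d = lshr i32 %x, 34          --> BinaryOp(LShr, %x, 34), left untouched
//                                     because the shift amount is out of range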
5374
5375/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5376/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5377/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5378/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5379/// follows one of the following patterns:
5380/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5381/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5382/// If the SCEV expression of \p Op conforms with one of the expected patterns
5383/// we return the type of the truncation operation, and indicate whether the
5384/// truncated type should be treated as signed/unsigned by setting
5385/// \p Signed to true/false, respectively.
5386static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5387 bool &Signed, ScalarEvolution &SE) {
5388 // The case where Op == SymbolicPHI (that is, with no type conversions on
5389 // the way) is handled by the regular add recurrence creating logic and
5390 // would have already been triggered in createAddRecFromPHI. Reaching it here
5391 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5392 // because one of the other operands of the SCEVAddExpr updating this PHI is
5393 // not invariant).
5394 //
5395 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5396 // this case predicates that allow us to prove that Op == SymbolicPHI will
5397 // be added.
5398 if (Op == SymbolicPHI)
5399 return nullptr;
5400
5401 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5402 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5403 if (SourceBits != NewBits)
5404 return nullptr;
5405
5406 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5407 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5408 if (!SExt && !ZExt)
5409 return nullptr;
5410 const SCEVTruncateExpr *Trunc =
5411 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5412 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5413 if (!Trunc)
5414 return nullptr;
5415 const SCEV *X = Trunc->getOperand();
5416 if (X != SymbolicPHI)
5417 return nullptr;
5418 Signed = SExt != nullptr;
5419 return Trunc->getType();
5420}
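// Shapes accepted above (informal examples, %X being the phi behind
// SymbolicPHI, of type i64):
//   Op == (sext i32 (trunc i64 %X to i32) to i64)  -> returns i32, Signed = true
//   Op == (zext i16 (trunc i64 %X to i16) to i64)  -> returns i16, Signed = false
// Any other shape, including Op == %X itself, yields nullptr.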
5421
5422static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5423 if (!PN->getType()->isIntegerTy())
5424 return nullptr;
5425 const Loop *L = LI.getLoopFor(PN->getParent());
5426 if (!L || L->getHeader() != PN->getParent())
5427 return nullptr;
5428 return L;
5429}
5430
5431// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5432// computation that updates the phi follows the following pattern:
5433// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5434 // which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5435// If so, try to see if it can be rewritten as an AddRecExpr under some
5436// Predicates. If successful, return them as a pair. Also cache the results
5437// of the analysis.
5438//
5439// Example usage scenario:
5440// Say the Rewriter is called for the following SCEV:
5441// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5442// where:
5443// %X = phi i64 (%Start, %BEValue)
5444// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5445// and call this function with %SymbolicPHI = %X.
5446//
5447// The analysis will find that the value coming around the backedge has
5448// the following SCEV:
5449// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5450// Upon concluding that this matches the desired pattern, the function
5451// will return the pair {NewAddRec, SmallPredsVec} where:
5452// NewAddRec = {%Start,+,%Step}
5453// SmallPredsVec = {P1, P2, P3} as follows:
5454// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5455// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5456// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5457// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5458// under the predicates {P1,P2,P3}.
5459// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5460 // PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5461//
5462// TODO's:
5463//
5464// 1) Extend the Induction descriptor to also support inductions that involve
5465// casts: When needed (namely, when we are called in the context of the
5466// vectorizer induction analysis), a Set of cast instructions will be
5467// populated by this method, and provided back to isInductionPHI. This is
5468// needed to allow the vectorizer to properly record them to be ignored by
5469// the cost model and to avoid vectorizing them (otherwise these casts,
5470// which are redundant under the runtime overflow checks, will be
5471// vectorized, which can be costly).
5472//
5473// 2) Support additional induction/PHISCEV patterns: We also want to support
5474// inductions where the sext-trunc / zext-trunc operations (partly) occur
5475// after the induction update operation (the induction increment):
5476//
5477// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5478// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5479//
5480// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5481// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5482//
5483// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5484std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5485ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5486 SmallVector<const SCEVPredicate *, 3> Predicates;
5487
5488 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5489 // return an AddRec expression under some predicate.
5490
5491 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5492 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5493 assert(L && "Expecting an integer loop header phi");
5494
5495 // The loop may have multiple entrances or multiple exits; we can analyze
5496 // this phi as an addrec if it has a unique entry value and a unique
5497 // backedge value.
5498 Value *BEValueV = nullptr, *StartValueV = nullptr;
5499 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5500 Value *V = PN->getIncomingValue(i);
5501 if (L->contains(PN->getIncomingBlock(i))) {
5502 if (!BEValueV) {
5503 BEValueV = V;
5504 } else if (BEValueV != V) {
5505 BEValueV = nullptr;
5506 break;
5507 }
5508 } else if (!StartValueV) {
5509 StartValueV = V;
5510 } else if (StartValueV != V) {
5511 StartValueV = nullptr;
5512 break;
5513 }
5514 }
5515 if (!BEValueV || !StartValueV)
5516 return std::nullopt;
5517
5518 const SCEV *BEValue = getSCEV(BEValueV);
5519
5520 // If the value coming around the backedge is an add with the symbolic
5521 // value we just inserted, possibly with casts that we can ignore under
5522 // an appropriate runtime guard, then we found a simple induction variable!
5523 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5524 if (!Add)
5525 return std::nullopt;
5526
5527 // If there is a single occurrence of the symbolic value, possibly
5528 // casted, replace it with a recurrence.
5529 unsigned FoundIndex = Add->getNumOperands();
5530 Type *TruncTy = nullptr;
5531 bool Signed;
5532 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5533 if ((TruncTy =
5534 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5535 if (FoundIndex == e) {
5536 FoundIndex = i;
5537 break;
5538 }
5539
5540 if (FoundIndex == Add->getNumOperands())
5541 return std::nullopt;
5542
5543 // Create an add with everything but the specified operand.
5544 SmallVector<const SCEV *, 8> Ops;
5545 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5546 if (i != FoundIndex)
5547 Ops.push_back(Add->getOperand(i));
5548 const SCEV *Accum = getAddExpr(Ops);
5549
5550 // The runtime checks will not be valid if the step amount is
5551 // varying inside the loop.
5552 if (!isLoopInvariant(Accum, L))
5553 return std::nullopt;
5554
5555 // *** Part2: Create the predicates
5556
5557 // Analysis was successful: we have a phi-with-cast pattern for which we
5558 // can return an AddRec expression under the following predicates:
5559 //
5560 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5561 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5562 // P2: An Equal predicate that guarantees that
5563 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5564 // P3: An Equal predicate that guarantees that
5565 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5566 //
5567 // As we next prove, the above predicates guarantee that:
5568 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5569 //
5570 //
5571 // More formally, we want to prove that:
5572 // Expr(i+1) = Start + (i+1) * Accum
5573 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5574 //
5575 // Given that:
5576 // 1) Expr(0) = Start
5577 // 2) Expr(1) = Start + Accum
5578 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5579 // 3) Induction hypothesis (step i):
5580 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5581 //
5582 // Proof:
5583 // Expr(i+1) =
5584 // = Start + (i+1)*Accum
5585 // = (Start + i*Accum) + Accum
5586 // = Expr(i) + Accum
5587 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5588 // :: from step i
5589 //
5590 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5591 //
5592 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5593 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5594 // + Accum :: from P3
5595 //
5596 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5597 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5598 //
5599 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5600 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5601 //
5602 // By induction, the same applies to all iterations 1<=i<n:
5603 //
5604
5605 // Create a truncated addrec for which we will add a no overflow check (P1).
5606 const SCEV *StartVal = getSCEV(StartValueV);
5607 const SCEV *PHISCEV =
5608 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5609 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5610
5611 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5612 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5613 // will be constant.
5614 //
5615 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5616 // add P1.
5617 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5618 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5619 Signed ? SCEVWrapPredicate::IncrementNSSW
5620 : SCEVWrapPredicate::IncrementNUSW;
5621 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5622 Predicates.push_back(AddRecPred);
5623 }
5624
5625 // Create the Equal Predicates P2,P3:
5626
5627 // It is possible that the predicates P2 and/or P3 are computable at
5628 // compile time due to StartVal and/or Accum being constants.
5629 // If either one is, then we can check that now and escape if either P2
5630 // or P3 is false.
5631
5632 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5633 // for each of StartVal and Accum
5634 auto getExtendedExpr = [&](const SCEV *Expr,
5635 bool CreateSignExtend) -> const SCEV * {
5636 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5637 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5638 const SCEV *ExtendedExpr =
5639 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5640 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5641 return ExtendedExpr;
5642 };
5643
5644 // Given:
5645 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5646 // = getExtendedExpr(Expr)
5647 // Determine whether the predicate P: Expr == ExtendedExpr
5648 // is known to be false at compile time
5649 auto PredIsKnownFalse = [&](const SCEV *Expr,
5650 const SCEV *ExtendedExpr) -> bool {
5651 return Expr != ExtendedExpr &&
5652 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5653 };
5654
5655 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5656 if (PredIsKnownFalse(StartVal, StartExtended)) {
5657 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5658 return std::nullopt;
5659 }
5660
5661 // The Step is always Signed (because the overflow checks are either
5662 // NSSW or NUSW)
5663 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5664 if (PredIsKnownFalse(Accum, AccumExtended)) {
5665 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5666 return std::nullopt;
5667 }
5668
5669 auto AppendPredicate = [&](const SCEV *Expr,
5670 const SCEV *ExtendedExpr) -> void {
5671 if (Expr != ExtendedExpr &&
5672 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5673 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5674 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5675 Predicates.push_back(Pred);
5676 }
5677 };
5678
5679 AppendPredicate(StartVal, StartExtended);
5680 AppendPredicate(Accum, AccumExtended);
5681
5682 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5683 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5684 // into NewAR if it will also add the runtime overflow checks specified in
5685 // Predicates.
5686 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5687
5688 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5689 std::make_pair(NewAR, Predicates);
5690 // Remember the result of the analysis for this SCEV at this location.
5691 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5692 return PredRewrite;
5693}
5694
5695std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5696ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5697 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5698 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5699 if (!L)
5700 return std::nullopt;
5701
5702 // Check to see if we already analyzed this PHI.
5703 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5704 if (I != PredicatedSCEVRewrites.end()) {
5705 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5706 I->second;
5707 // Analysis was done before and failed to create an AddRec:
5708 if (Rewrite.first == SymbolicPHI)
5709 return std::nullopt;
5710 // Analysis was done before and succeeded to create an AddRec under
5711 // a predicate:
5712 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5713 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5714 return Rewrite;
5715 }
5716
5717 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5718 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5719
5720 // Record in the cache that the analysis failed
5721 if (!Rewrite) {
5722 SmallVector<const SCEVPredicate *, 3> Predicates;
5723 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5724 return std::nullopt;
5725 }
5726
5727 return Rewrite;
5728}
5729
5730// FIXME: This utility is currently required because the Rewriter currently
5731// does not rewrite this expression:
5732// {0, +, (sext ix (trunc iy to ix) to iy)}
5733// into {0, +, %step},
5734// even when the following Equal predicate exists:
5735// "%step == (sext ix (trunc iy to ix) to iy)".
5736bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5737 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5738 if (AR1 == AR2)
5739 return true;
5740
5741 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5742 if (Expr1 != Expr2 &&
5743 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5744 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5745 return false;
5746 return true;
5747 };
5748
5749 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5750 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5751 return false;
5752 return true;
5753}
5754
5755/// A helper function for createAddRecFromPHI to handle simple cases.
5756///
5757/// This function tries to find an AddRec expression for the simplest (yet most
5758/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5759/// If it fails, createAddRecFromPHI will use a more general, but slow,
5760/// technique for finding the AddRec expression.
5761const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5762 Value *BEValueV,
5763 Value *StartValueV) {
5764 const Loop *L = LI.getLoopFor(PN->getParent());
5765 assert(L && L->getHeader() == PN->getParent());
5766 assert(BEValueV && StartValueV);
5767
5768 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5769 if (!BO)
5770 return nullptr;
5771
5772 if (BO->Opcode != Instruction::Add)
5773 return nullptr;
5774
5775 const SCEV *Accum = nullptr;
5776 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5777 Accum = getSCEV(BO->RHS);
5778 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5779 Accum = getSCEV(BO->LHS);
5780
5781 if (!Accum)
5782 return nullptr;
5783
5784 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5785 if (BO->IsNUW)
5786 Flags = setFlags(Flags, SCEV::FlagNUW);
5787 if (BO->IsNSW)
5788 Flags = setFlags(Flags, SCEV::FlagNSW);
5789
5790 const SCEV *StartVal = getSCEV(StartValueV);
5791 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5792 insertValueToMap(PN, PHISCEV);
5793
5794 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5795 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5796 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5797 proveNoWrapViaConstantRanges(AR)));
5798 }
5799
5800 // We can add Flags to the post-inc expression only if we
5801 // know that it is *undefined behavior* for BEValueV to
5802 // overflow.
5803 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5804 assert(isLoopInvariant(Accum, L) &&
5805 "Accum is defined outside L, but is not invariant?");
5806 if (isAddRecNeverPoison(BEInst, L))
5807 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5808 }
5809
5810 return PHISCEV;
5811}
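// The simple case targeted above, in IR form (an assumed example):
//
//   loop:
//     %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
//     %iv.next = add nuw nsw i64 %iv, 1
//
// MatchBinaryOp sees the backedge value as an add of %iv and the loop-invariant
// constant 1, so the phi is mapped to {0,+,1}<nuw><nsw><%loop> without going
// through the general createAddRecFromPHI machinery.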
5812
5813const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5814 const Loop *L = LI.getLoopFor(PN->getParent());
5815 if (!L || L->getHeader() != PN->getParent())
5816 return nullptr;
5817
5818 // The loop may have multiple entrances or multiple exits; we can analyze
5819 // this phi as an addrec if it has a unique entry value and a unique
5820 // backedge value.
5821 Value *BEValueV = nullptr, *StartValueV = nullptr;
5822 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5823 Value *V = PN->getIncomingValue(i);
5824 if (L->contains(PN->getIncomingBlock(i))) {
5825 if (!BEValueV) {
5826 BEValueV = V;
5827 } else if (BEValueV != V) {
5828 BEValueV = nullptr;
5829 break;
5830 }
5831 } else if (!StartValueV) {
5832 StartValueV = V;
5833 } else if (StartValueV != V) {
5834 StartValueV = nullptr;
5835 break;
5836 }
5837 }
5838 if (!BEValueV || !StartValueV)
5839 return nullptr;
5840
5841 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5842 "PHI node already processed?");
5843
5844 // First, try to find an AddRec expression without creating a fictitious symbolic
5845 // value for PN.
5846 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5847 return S;
5848
5849 // Handle PHI node value symbolically.
5850 const SCEV *SymbolicName = getUnknown(PN);
5851 insertValueToMap(PN, SymbolicName);
5852
5853 // Using this symbolic name for the PHI, analyze the value coming around
5854 // the back-edge.
5855 const SCEV *BEValue = getSCEV(BEValueV);
5856
5857 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5858 // has a special value for the first iteration of the loop.
5859
5860 // If the value coming around the backedge is an add with the symbolic
5861 // value we just inserted, then we found a simple induction variable!
5862 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5863 // If there is a single occurrence of the symbolic value, replace it
5864 // with a recurrence.
5865 unsigned FoundIndex = Add->getNumOperands();
5866 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5867 if (Add->getOperand(i) == SymbolicName)
5868 if (FoundIndex == e) {
5869 FoundIndex = i;
5870 break;
5871 }
5872
5873 if (FoundIndex != Add->getNumOperands()) {
5874 // Create an add with everything but the specified operand.
5875 SmallVector<const SCEV *, 8> Ops;
5876 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5877 if (i != FoundIndex)
5878 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5879 L, *this));
5880 const SCEV *Accum = getAddExpr(Ops);
5881
5882 // This is not a valid addrec if the step amount is varying each
5883 // loop iteration, but is not itself an addrec in this loop.
5884 if (isLoopInvariant(Accum, L) ||
5885 (isa<SCEVAddRecExpr>(Accum) &&
5886 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5887 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5888
5889 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5890 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5891 if (BO->IsNUW)
5892 Flags = setFlags(Flags, SCEV::FlagNUW);
5893 if (BO->IsNSW)
5894 Flags = setFlags(Flags, SCEV::FlagNSW);
5895 }
5896 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5897 if (GEP->getOperand(0) == PN) {
5898 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5899 // If the increment has any nowrap flags, then we know the address
5900 // space cannot be wrapped around.
5901 if (NW != GEPNoWrapFlags::none())
5902 Flags = setFlags(Flags, SCEV::FlagNW);
5903 // If the GEP is nuw or nusw with non-negative offset, we know that
5904 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5905 // offset is treated as signed, while the base is unsigned.
5906 if (NW.hasNoUnsignedWrap() ||
5907 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5908 Flags = setFlags(Flags, SCEV::FlagNUW);
5909 }
5910
5911 // We cannot transfer nuw and nsw flags from subtraction
5912 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5913 // for instance.
5914 }
5915
5916 const SCEV *StartVal = getSCEV(StartValueV);
5917 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5918
5919 // Okay, for the entire analysis of this edge we assumed the PHI
5920 // to be symbolic. We now need to go back and purge all of the
5921 // entries for the scalars that use the symbolic expression.
5922 forgetMemoizedResults(SymbolicName);
5923 insertValueToMap(PN, PHISCEV);
5924
5925 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5926 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5927 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5928 proveNoWrapViaConstantRanges(AR)));
5929 }
5930
5931 // We can add Flags to the post-inc expression only if we
5932 // know that it is *undefined behavior* for BEValueV to
5933 // overflow.
5934 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5935 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5936 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5937
5938 return PHISCEV;
5939 }
5940 }
5941 } else {
5942 // Otherwise, this could be a loop like this:
5943 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5944 // In this case, j = {1,+,1} and BEValue is j.
5945 // Because the other in-value of i (0) fits the evolution of BEValue,
5946 // i really is an addrec evolution.
5947 //
5948 // We can generalize this saying that i is the shifted value of BEValue
5949 // by one iteration:
5950 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5951
5952 // Do not allow refinement in rewriting of BEValue.
5953 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5954 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5955 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5956 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5957 const SCEV *StartVal = getSCEV(StartValueV);
5958 if (Start == StartVal) {
5959 // Okay, for the entire analysis of this edge we assumed the PHI
5960 // to be symbolic. We now need to go back and purge all of the
5961 // entries for the scalars that use the symbolic expression.
5962 forgetMemoizedResults(SymbolicName);
5963 insertValueToMap(PN, Shifted);
5964 return Shifted;
5965 }
5966 }
5967 }
5968
5969 // Remove the temporary PHI node SCEV that has been inserted while intending
5970 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5971 // as it would prevent later (possibly simpler) SCEV expressions from being
5972 // added to the ValueExprMap.
5973 eraseValueFromMap(PN);
5974
5975 return nullptr;
5976}
5977
5978// Try to match a control flow sequence that branches out at BI and merges back
5979// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5980// match.
5981static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5982 Value *&C, Value *&LHS, Value *&RHS) {
5983 C = BI->getCondition();
5984
5985 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5986 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5987
5988 if (!LeftEdge.isSingleEdge())
5989 return false;
5990
5991 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5992
5993 Use &LeftUse = Merge->getOperandUse(0);
5994 Use &RightUse = Merge->getOperandUse(1);
5995
5996 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5997 LHS = LeftUse;
5998 RHS = RightUse;
5999 return true;
6000 }
6001
6002 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6003 LHS = RightUse;
6004 RHS = LeftUse;
6005 return true;
6006 }
6007
6008 return false;
6009}
6010
6011const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6012 auto IsReachable =
6013 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6014 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6015 // Try to match
6016 //
6017 // br %cond, label %left, label %right
6018 // left:
6019 // br label %merge
6020 // right:
6021 // br label %merge
6022 // merge:
6023 // V = phi [ %x, %left ], [ %y, %right ]
6024 //
6025 // as "select %cond, %x, %y"
6026
6027 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6028 assert(IDom && "At least the entry block should dominate PN");
6029
6030 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6031 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6032
6033 if (BI && BI->isConditional() &&
6034 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6035 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6036 properlyDominates(getSCEV(RHS), PN->getParent()))
6037 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6038 }
6039
6040 return nullptr;
6041}
6042
6043/// Returns SCEV for the first operand of a phi if all phi operands have
6044/// identical opcodes and operands
6045/// e.g.:
6046/// a: %add = %a + %b
6047/// br %c
6048/// b: %add1 = %a + %b
6049/// br %c
6050/// c: %phi = phi [%add, a], [%add1, b]
6051/// scev(%phi) => scev(%add)
6052const SCEV *
6053ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6054 BinaryOperator *CommonInst = nullptr;
6055 // Check if instructions are identical.
6056 for (Value *Incoming : PN->incoming_values()) {
6057 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6058 if (!IncomingInst)
6059 return nullptr;
6060 if (CommonInst) {
6061 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6062 return nullptr; // Not identical, give up
6063 } else {
6064 // Remember binary operator
6065 CommonInst = IncomingInst;
6066 }
6067 }
6068 if (!CommonInst)
6069 return nullptr;
6070
6071 // Check if SCEV exprs for instructions are identical.
6072 const SCEV *CommonSCEV = getSCEV(CommonInst);
6073 bool SCEVExprsIdentical =
6074 all_of(drop_begin(PN->incoming_values()),
6075 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6076 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6077}
6078
6079const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6080 if (const SCEV *S = createAddRecFromPHI(PN))
6081 return S;
6082
6083 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6084 // phi node for X.
6085 if (Value *V = simplifyInstruction(
6086 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6087 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6088 return getSCEV(V);
6089
6090 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6091 return S;
6092
6093 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6094 return S;
6095
6096 // If it's not a loop phi, we can't handle it yet.
6097 return getUnknown(PN);
6098}
6099
6100bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6101 SCEVTypes RootKind) {
6102 struct FindClosure {
6103 const SCEV *OperandToFind;
6104 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6105 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6106
6107 bool Found = false;
6108
6109 bool canRecurseInto(SCEVTypes Kind) const {
6110 // We can only recurse into the SCEV expression of the same effective type
6111 // as the type of our root SCEV expression, and into zero-extensions.
6112 return RootKind == Kind || NonSequentialRootKind == Kind ||
6113 scZeroExtend == Kind;
6114 };
6115
6116 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6117 : OperandToFind(OperandToFind), RootKind(RootKind),
6118 NonSequentialRootKind(
6119 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6120 RootKind)) {}
6121
6122 bool follow(const SCEV *S) {
6123 Found = S == OperandToFind;
6124
6125 return !isDone() && canRecurseInto(S->getSCEVType());
6126 }
6127
6128 bool isDone() const { return Found; }
6129 };
6130
6131 FindClosure FC(OperandToFind, RootKind);
6132 visitAll(Root, FC);
6133 return FC.Found;
6134}
6135
6136std::optional<const SCEV *>
6137ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6138 ICmpInst *Cond,
6139 Value *TrueVal,
6140 Value *FalseVal) {
6141 // Try to match some simple smax or umax patterns.
6142 auto *ICI = Cond;
6143
6144 Value *LHS = ICI->getOperand(0);
6145 Value *RHS = ICI->getOperand(1);
6146
6147 switch (ICI->getPredicate()) {
6148 case ICmpInst::ICMP_SLT:
6149 case ICmpInst::ICMP_SLE:
6150 case ICmpInst::ICMP_ULT:
6151 case ICmpInst::ICMP_ULE:
6152 std::swap(LHS, RHS);
6153 [[fallthrough]];
6154 case ICmpInst::ICMP_SGT:
6155 case ICmpInst::ICMP_SGE:
6156 case ICmpInst::ICMP_UGT:
6157 case ICmpInst::ICMP_UGE:
6158 // a > b ? a+x : b+x -> max(a, b)+x
6159 // a > b ? b+x : a+x -> min(a, b)+x
6160 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6161 bool Signed = ICI->isSigned();
6162 const SCEV *LA = getSCEV(TrueVal);
6163 const SCEV *RA = getSCEV(FalseVal);
6164 const SCEV *LS = getSCEV(LHS);
6165 const SCEV *RS = getSCEV(RHS);
6166 if (LA->getType()->isPointerTy()) {
6167 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6168 // Need to make sure we can't produce weird expressions involving
6169 // negated pointers.
6170 if (LA == LS && RA == RS)
6171 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6172 if (LA == RS && RA == LS)
6173 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6174 }
6175 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6176 if (Op->getType()->isPointerTy()) {
6177 Op = getLosslessPtrToIntExpr(Op);
6178 if (isa<SCEVCouldNotCompute>(Op))
6179 return Op;
6180 }
6181 if (Signed)
6182 Op = getNoopOrSignExtend(Op, Ty);
6183 else
6184 Op = getNoopOrZeroExtend(Op, Ty);
6185 return Op;
6186 };
6187 LS = CoerceOperand(LS);
6188 RS = CoerceOperand(RS);
6189 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6190 break;
6191 const SCEV *LDiff = getMinusSCEV(LA, LS);
6192 const SCEV *RDiff = getMinusSCEV(RA, RS);
6193 if (LDiff == RDiff)
6194 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6195 LDiff);
6196 LDiff = getMinusSCEV(LA, RS);
6197 RDiff = getMinusSCEV(RA, LS);
6198 if (LDiff == RDiff)
6199 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6200 LDiff);
6201 }
6202 break;
6203 case ICmpInst::ICMP_NE:
6204 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6205 std::swap(TrueVal, FalseVal);
6206 [[fallthrough]];
6207 case ICmpInst::ICMP_EQ:
6208 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6209 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6210 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6211 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6212 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6213 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6214 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6215 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6216 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6217 return getAddExpr(getUMaxExpr(X, C), Y);
6218 }
6219 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6220 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6221 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6222 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6223 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6224 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6225 const SCEV *X = getSCEV(LHS);
6226 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6227 X = ZExt->getOperand();
6228 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6229 const SCEV *FalseValExpr = getSCEV(FalseVal);
6230 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6231 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6232 /*Sequential=*/true);
6233 }
6234 }
6235 break;
6236 default:
6237 break;
6238 }
6239
6240 return std::nullopt;
6241}
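// Two of the rewrites implemented above, as assumed IR examples:
//
//   %c = icmp sgt i32 %a, %b
//   %s = select i1 %c, i32 %a, i32 %b        ; becomes (%a smax %b)
//
//   %c = icmp eq i32 %x, 0
//   %s = select i1 %c, i32 1, i32 %x         ; becomes (%x umax 1)
//
// In both cases the common offset ("+x"/"+y" in the comments above) is zero,
// so the LDiff/RDiff and Y terms fold away.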
6242
6243static std::optional<const SCEV *>
6245 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6246 assert(CondExpr->getType()->isIntegerTy(1) &&
6247 TrueExpr->getType() == FalseExpr->getType() &&
6248 TrueExpr->getType()->isIntegerTy(1) &&
6249 "Unexpected operands of a select.");
6250
6251 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6252 // --> C + (umin_seq cond, x - C)
6253 //
6254 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6255 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6256 // --> C + (umin_seq ~cond, x - C)
6257
6258 // FIXME: while we can't legally model the case where both of the hands
6259 // are fully variable, we only require that the *difference* is constant.
6260 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6261 return std::nullopt;
6262
6263 const SCEV *X, *C;
6264 if (isa<SCEVConstant>(TrueExpr)) {
6265 CondExpr = SE->getNotSCEV(CondExpr);
6266 X = FalseExpr;
6267 C = TrueExpr;
6268 } else {
6269 X = TrueExpr;
6270 C = FalseExpr;
6271 }
6272 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6273 /*Sequential=*/true));
6274}
6275
6276static std::optional<const SCEV *>
6277createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6278 Value *FalseVal) {
6279 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6280 return std::nullopt;
6281
6282 const auto *SECond = SE->getSCEV(Cond);
6283 const auto *SETrue = SE->getSCEV(TrueVal);
6284 const auto *SEFalse = SE->getSCEV(FalseVal);
6285 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6286}
6287
6288const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6289 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6290 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6291 assert(TrueVal->getType() == FalseVal->getType() &&
6292 V->getType() == TrueVal->getType() &&
6293 "Types of select hands and of the result must match.");
6294
6295 // For now, only deal with i1-typed `select`s.
6296 if (!V->getType()->isIntegerTy(1))
6297 return getUnknown(V);
6298
6299 if (std::optional<const SCEV *> S =
6300 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6301 return *S;
6302
6303 return getUnknown(V);
6304}
6305
6306const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6307 Value *TrueVal,
6308 Value *FalseVal) {
6309 // Handle "constant" branch or select. This can occur for instance when a
6310 // loop pass transforms an inner loop and moves on to process the outer loop.
6311 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6312 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6313
6314 if (auto *I = dyn_cast<Instruction>(V)) {
6315 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6316 if (std::optional<const SCEV *> S =
6317 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6318 TrueVal, FalseVal))
6319 return *S;
6320 }
6321 }
6322
6323 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6324}
6325
6326/// Expand GEP instructions into add and multiply operations. This allows them
6327/// to be analyzed by regular SCEV code.
6328const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6329 assert(GEP->getSourceElementType()->isSized() &&
6330 "GEP source element type must be sized");
6331
6332 SmallVector<const SCEV *, 4> IndexExprs;
6333 for (Value *Index : GEP->indices())
6334 IndexExprs.push_back(getSCEV(Index));
6335 return getGEPExpr(GEP, IndexExprs);
6336}
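// For instance (a sketch): %gep = getelementptr i32, ptr %p, i64 %i expands,
// via getGEPExpr, into an add of %p and (4 * %i), where 4 is the allocation
// size of i32 under the module's DataLayout.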
6337
6338APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6339 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6340 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6341 return TrailingZeros >= BitWidth
6342 ? APInt::getZero(BitWidth)
6343 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6344 };
6345 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6346 // The result is GCD of all operands results.
6347 APInt Res = getConstantMultiple(N->getOperand(0));
6348 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6349 Res = APIntOps::GreatestCommonDivisor(
6350 Res, getConstantMultiple(N->getOperand(I)));
6351 return Res;
6352 };
6353
6354 switch (S->getSCEVType()) {
6355 case scConstant:
6356 return cast<SCEVConstant>(S)->getAPInt();
6357 case scPtrToInt:
6358 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6359 case scUDivExpr:
6360 case scVScale:
6361 return APInt(BitWidth, 1);
6362 case scTruncate: {
6363 // Only multiples that are a power of 2 will hold after truncation.
6364 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6365 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6366 return GetShiftedByZeros(TZ);
6367 }
6368 case scZeroExtend: {
6369 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6370 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6371 }
6372 case scSignExtend: {
6373 // Only multiples that are a power of 2 will hold after sext.
6374 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6375 uint32_t TZ = getMinTrailingZeros(E->getOperand());
6376 return GetShiftedByZeros(TZ);
6377 }
6378 case scMulExpr: {
6379 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6380 if (M->hasNoUnsignedWrap()) {
6381 // The result is the product of all operand results.
6382 APInt Res = getConstantMultiple(M->getOperand(0));
6383 for (const SCEV *Operand : M->operands().drop_front())
6384 Res = Res * getConstantMultiple(Operand);
6385 return Res;
6386 }
6387
6388 // If there are no wrap guarantees, find the trailing zeros, which is the
6389 // sum of trailing zeros for all its operands.
6390 uint32_t TZ = 0;
6391 for (const SCEV *Operand : M->operands())
6392 TZ += getMinTrailingZeros(Operand);
6393 return GetShiftedByZeros(TZ);
6394 }
6395 case scAddExpr:
6396 case scAddRecExpr: {
6397 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6398 if (N->hasNoUnsignedWrap())
6399 return GetGCDMultiple(N);
6400 // Find the trailing zeros, which is the minimum over all operands.
6401 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6402 for (const SCEV *Operand : N->operands().drop_front())
6403 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6404 return GetShiftedByZeros(TZ);
6405 }
6406 case scUMaxExpr:
6407 case scSMaxExpr:
6408 case scUMinExpr:
6409 case scSMinExpr:
6410 case scSequentialUMinExpr:
6411 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6412 case scUnknown: {
6413 // Ask ValueTracking for known bits.
6414 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6415 unsigned Known =
6416 computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
6417 .countMinTrailingZeros();
6418 return GetShiftedByZeros(Known);
6419 }
6420 case scCouldNotCompute:
6421 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6422 }
6423 llvm_unreachable("Unknown SCEV kind!");
6424}
6425
6426APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6427 auto I = ConstantMultipleCache.find(S);
6428 if (I != ConstantMultipleCache.end())
6429 return I->second;
6430
6431 APInt Result = getConstantMultipleImpl(S);
6432 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6433 assert(InsertPair.second && "Should insert a new key");
6434 return InsertPair.first->second;
6435}
6436
6437APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6438 APInt Multiple = getConstantMultiple(S);
6439 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6440}
6441
6442uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6443 return std::min(getConstantMultiple(S).countTrailingZeros(),
6444 (unsigned)getTypeSizeInBits(S->getType()));
6445}
6446
6447/// Helper method to assign a range to V from metadata present in the IR.
6448static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6449 if (Instruction *I = dyn_cast<Instruction>(V)) {
6450 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6451 return getConstantRangeFromMetadata(*MD);
6452 if (const auto *CB = dyn_cast<CallBase>(V))
6453 if (std::optional<ConstantRange> Range = CB->getRange())
6454 return Range;
6455 }
6456 if (auto *A = dyn_cast<Argument>(V))
6457 if (std::optional<ConstantRange> Range = A->getRange())
6458 return Range;
6459
6460 return std::nullopt;
6461}
6462
6463void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6464 SCEV::NoWrapFlags Flags) {
6465 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6466 AddRec->setNoWrapFlags(Flags);
6467 UnsignedRanges.erase(AddRec);
6468 SignedRanges.erase(AddRec);
6469 ConstantMultipleCache.erase(AddRec);
6470 }
6471}
6472
6473ConstantRange ScalarEvolution::
6474getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6475 const DataLayout &DL = getDataLayout();
6476
6477 unsigned BitWidth = getTypeSizeInBits(U->getType());
6478 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6479
6480 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6481 // use information about the trip count to improve our available range. Note
6482 // that the trip count independent cases are already handled by known bits.
6483 // WARNING: The definition of recurrence used here is subtly different than
6484 // the one used by AddRec (and thus most of this file). Step is allowed to
6485 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6486 // and other addrecs in the same loop (for non-affine addrecs). The code
6487 // below intentionally handles the case where step is not loop invariant.
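// Example of the kind of IR this matches (illustrative only):
//   %iv = phi i32 [ %start, %entry ], [ %iv.next, %loop ]
//   %iv.next = lshr i32 %iv, %amt   ; %amt may vary from iteration to iteration
// A bound on the trip count lets us bound how far %iv can move away from
// %start, which known-bits reasoning alone cannot do.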
6488 auto *P = dyn_cast<PHINode>(U->getValue());
6489 if (!P)
6490 return FullSet;
6491
6492 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6493 // even the values that are not available in these blocks may come from them,
6494 // and this leads to a false-positive recurrence test.
6495 for (auto *Pred : predecessors(P->getParent()))
6496 if (!DT.isReachableFromEntry(Pred))
6497 return FullSet;
6498
6499 BinaryOperator *BO;
6500 Value *Start, *Step;
6501 if (!matchSimpleRecurrence(P, BO, Start, Step))
6502 return FullSet;
6503
6504 // If we found a recurrence in reachable code, we must be in a loop. Note
6505 // that BO might be in some subloop of L, and that's completely okay.
6506 auto *L = LI.getLoopFor(P->getParent());
6507 assert(L && L->getHeader() == P->getParent());
6508 if (!L->contains(BO->getParent()))
6509 // NOTE: This bailout should be an assert instead. However, asserting
6510 // the condition here exposes a case where LoopFusion is querying SCEV
6511 // with malformed loop information in the midst of the transform.
6512 // There doesn't appear to be an obvious fix, so for the moment bail out
6513 // until the caller issue can be fixed. PR49566 tracks the bug.
6514 return FullSet;
6515
6516 // TODO: Extend to other opcodes such as mul, and div
6517 switch (BO->getOpcode()) {
6518 default:
6519 return FullSet;
6520 case Instruction::AShr:
6521 case Instruction::LShr:
6522 case Instruction::Shl:
6523 break;
6524 };
6525
6526 if (BO->getOperand(0) != P)
6527 // TODO: Handle the power function forms some day.
6528 return FullSet;
6529
6530 unsigned TC = getSmallConstantMaxTripCount(L);
6531 if (!TC || TC >= BitWidth)
6532 return FullSet;
6533
6534 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6535 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6536 assert(KnownStart.getBitWidth() == BitWidth &&
6537 KnownStep.getBitWidth() == BitWidth);
6538
6539 // Compute total shift amount, being careful of overflow and bitwidths.
6540 auto MaxShiftAmt = KnownStep.getMaxValue();
6541 APInt TCAP(BitWidth, TC-1);
6542 bool Overflow = false;
6543 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6544 if (Overflow)
6545 return FullSet;
6546
6547 switch (BO->getOpcode()) {
6548 default:
6549 llvm_unreachable("filtered out above");
6550 case Instruction::AShr: {
6551 // For each ashr, three cases:
6552 // shift = 0 => unchanged value
6553 // saturation => 0 or -1
6554 // other => a value closer to zero (of the same sign)
6555 // Thus, the end value is closer to zero than the start.
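// Worked example (illustrative only): a non-negative start in [16, 64)
// shifted right (arithmetically) by at most 2 in total can only move toward
// zero, so every value of the recurrence stays in [16 >> 2, 64) = [4, 64).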
6556 auto KnownEnd = KnownBits::ashr(KnownStart,
6557 KnownBits::makeConstant(TotalShift));
6558 if (KnownStart.isNonNegative())
6559 // Analogous to lshr (simply not yet canonicalized)
6560 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6561 KnownStart.getMaxValue() + 1);
6562 if (KnownStart.isNegative())
6563 // End >=u Start && End <=s Start
6564 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6565 KnownEnd.getMaxValue() + 1);
6566 break;
6567 }
6568 case Instruction::LShr: {
6569 // For each lshr, three cases:
6570 // shift = 0 => unchanged value
6571 // saturation => 0
6572 // other => a smaller positive number
6573 // Thus, the low end of the unsigned range is the last value produced.
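// Worked example (illustrative only): with a start value in [100, 200) and a
// total shift of at most 3, the recurrence never drops below 100 >> 3 = 12
// and never exceeds its start, giving the range [12, 200).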
6574 auto KnownEnd = KnownBits::lshr(KnownStart,
6575 KnownBits::makeConstant(TotalShift));
6576 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6577 KnownStart.getMaxValue() + 1);
6578 }
6579 case Instruction::Shl: {
6580 // Iff no bits are shifted out, value increases on every shift.
6581 auto KnownEnd = KnownBits::shl(KnownStart,
6582 KnownBits::makeConstant(TotalShift));
6583 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6584 return ConstantRange(KnownStart.getMinValue(),
6585 KnownEnd.getMaxValue() + 1);
6586 break;
6587 }
6588 };
6589 return FullSet;
6590}
6591
6592const ConstantRange &
6593ScalarEvolution::getRangeRefIter(const SCEV *S,
6594 ScalarEvolution::RangeSignHint SignHint) {
6595 DenseMap<const SCEV *, ConstantRange> &Cache =
6596 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6597 : SignedRanges;
6598 SmallVector<const SCEV *> WorkList;
6599 SmallPtrSet<const SCEV *, 8> Seen;
6600
6601 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6602 // SCEVUnknown PHI node.
6603 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6604 if (!Seen.insert(Expr).second)
6605 return;
6606 if (Cache.contains(Expr))
6607 return;
6608 switch (Expr->getSCEVType()) {
6609 case scUnknown:
6610 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6611 break;
6612 [[fallthrough]];
6613 case scConstant:
6614 case scVScale:
6615 case scTruncate:
6616 case scZeroExtend:
6617 case scSignExtend:
6618 case scPtrToInt:
6619 case scAddExpr:
6620 case scMulExpr:
6621 case scUDivExpr:
6622 case scAddRecExpr:
6623 case scUMaxExpr:
6624 case scSMaxExpr:
6625 case scUMinExpr:
6626 case scSMinExpr:
6627 case scSequentialUMinExpr:
6628 WorkList.push_back(Expr);
6629 break;
6630 case scCouldNotCompute:
6631 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6632 }
6633 };
6634 AddToWorklist(S);
6635
6636 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6637 for (unsigned I = 0; I != WorkList.size(); ++I) {
6638 const SCEV *P = WorkList[I];
6639 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6640 // If it is not a `SCEVUnknown`, just recurse into operands.
6641 if (!UnknownS) {
6642 for (const SCEV *Op : P->operands())
6643 AddToWorklist(Op);
6644 continue;
6645 }
6646 // `SCEVUnknown`s require special treatment.
6647 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6648 if (!PendingPhiRangesIter.insert(P).second)
6649 continue;
6650 for (auto &Op : reverse(P->operands()))
6651 AddToWorklist(getSCEV(Op));
6652 }
6653 }
6654
6655 if (!WorkList.empty()) {
6656 // Use getRangeRef to compute ranges for items in the worklist in reverse
6657 // order. This will force ranges for earlier operands to be computed before
6658 // their users in most cases.
6659 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6660 getRangeRef(P, SignHint);
6661
6662 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6663 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6664 PendingPhiRangesIter.erase(P);
6665 }
6666 }
6667
6668 return getRangeRef(S, SignHint, 0);
6669}
6670
6671/// Determine the range for a particular SCEV. If SignHint is
6672/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6673/// with a "cleaner" unsigned (resp. signed) representation.
6674const ConstantRange &ScalarEvolution::getRangeRef(
6675 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6676 DenseMap<const SCEV *, ConstantRange> &Cache =
6677 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6678 : SignedRanges;
6679 ConstantRange::PreferredRangeType RangeType =
6680 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6681 : ConstantRange::Signed;
6682
6683 // See if we've computed this range already.
6684 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6685 if (I != Cache.end())
6686 return I->second;
6687
6688 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6689 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6690
6691 // Switch to iteratively computing the range for S, if it is part of a deeply
6692 // nested expression.
6693 if (Depth > RangeIterThreshold)
6694 return getRangeRefIter(S, SignHint);
6695
6696 unsigned BitWidth = getTypeSizeInBits(S->getType());
6697 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6698 using OBO = OverflowingBinaryOperator;
6699
6700 // If the value has known zeros, the maximum value will have those known zeros
6701 // as well.
6702 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6703 APInt Multiple = getNonZeroConstantMultiple(S);
6704 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6705 if (!Remainder.isZero())
6706 ConservativeResult =
6707 ConstantRange(APInt::getMinValue(BitWidth),
6708 APInt::getMaxValue(BitWidth) - Remainder + 1);
6709 }
6710 else {
6711 uint32_t TZ = getMinTrailingZeros(S);
6712 if (TZ != 0) {
6713 ConservativeResult = ConstantRange(
6714 APInt::getSignedMinValue(BitWidth),
6715 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6716 }
6717 }
6718
6719 switch (S->getSCEVType()) {
6720 case scConstant:
6721 llvm_unreachable("Already handled above.");
6722 case scVScale:
6723 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6724 case scTruncate: {
6725 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6726 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6727 return setRange(
6728 Trunc, SignHint,
6729 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6730 }
6731 case scZeroExtend: {
6732 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6733 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6734 return setRange(
6735 ZExt, SignHint,
6736 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6737 }
6738 case scSignExtend: {
6739 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6740 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6741 return setRange(
6742 SExt, SignHint,
6743 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6744 }
6745 case scPtrToInt: {
6746 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6747 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6748 return setRange(PtrToInt, SignHint, X);
6749 }
6750 case scAddExpr: {
6751 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6752 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6753 unsigned WrapType = OBO::AnyWrap;
6754 if (Add->hasNoSignedWrap())
6755 WrapType |= OBO::NoSignedWrap;
6756 if (Add->hasNoUnsignedWrap())
6757 WrapType |= OBO::NoUnsignedWrap;
6758 for (const SCEV *Op : drop_begin(Add->operands()))
6759 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6760 RangeType);
6761 return setRange(Add, SignHint,
6762 ConservativeResult.intersectWith(X, RangeType));
6763 }
6764 case scMulExpr: {
6765 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6766 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6767 for (const SCEV *Op : drop_begin(Mul->operands()))
6768 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6769 return setRange(Mul, SignHint,
6770 ConservativeResult.intersectWith(X, RangeType));
6771 }
6772 case scUDivExpr: {
6773 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6774 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6775 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6776 return setRange(UDiv, SignHint,
6777 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6778 }
6779 case scAddRecExpr: {
6780 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6781 // If there's no unsigned wrap, the value will never be less than its
6782 // initial value.
6783 if (AddRec->hasNoUnsignedWrap()) {
6784 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6785 if (!UnsignedMinValue.isZero())
6786 ConservativeResult = ConservativeResult.intersectWith(
6787 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6788 }
6789
6790 // If there's no signed wrap, and all the operands except initial value have
6791 // the same sign or zero, the value won't ever be:
6792 // 1: smaller than initial value if operands are non-negative,
6793 // 2: bigger than initial value if operands are non-positive.
6794 // In both cases, the value cannot cross the signed min/max boundary.
6795 if (AddRec->hasNoSignedWrap()) {
6796 bool AllNonNeg = true;
6797 bool AllNonPos = true;
6798 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6799 if (!isKnownNonNegative(AddRec->getOperand(i)))
6800 AllNonNeg = false;
6801 if (!isKnownNonPositive(AddRec->getOperand(i)))
6802 AllNonPos = false;
6803 }
6804 if (AllNonNeg)
6805 ConservativeResult = ConservativeResult.intersectWith(
6806 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6807 APInt::getSignedMinValue(BitWidth)),
6808 RangeType);
6809 else if (AllNonPos)
6810 ConservativeResult = ConservativeResult.intersectWith(
6811 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6812 getSignedRangeMax(AddRec->getStart()) +
6813 1),
6814 RangeType);
6815 }
6816
6817 // TODO: non-affine addrec
6818 if (AddRec->isAffine()) {
6819 const SCEV *MaxBEScev =
6820 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6821 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6822 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6823
6824 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6825 // MaxBECount's active bits are all <= AddRec's bit width.
6826 if (MaxBECount.getBitWidth() > BitWidth &&
6827 MaxBECount.getActiveBits() <= BitWidth)
6828 MaxBECount = MaxBECount.trunc(BitWidth);
6829 else if (MaxBECount.getBitWidth() < BitWidth)
6830 MaxBECount = MaxBECount.zext(BitWidth);
6831
6832 if (MaxBECount.getBitWidth() == BitWidth) {
6833 auto RangeFromAffine = getRangeForAffineAR(
6834 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6835 ConservativeResult =
6836 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6837
6838 auto RangeFromFactoring = getRangeViaFactoring(
6839 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6840 ConservativeResult =
6841 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6842 }
6843 }
6844
6845 // Now try symbolic BE count and more powerful methods.
6846 if (UseExpensiveRangeSharpening) {
6847 const SCEV *SymbolicMaxBECount =
6848 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6849 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6850 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6851 AddRec->hasNoSelfWrap()) {
6852 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6853 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6854 ConservativeResult =
6855 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6856 }
6857 }
6858 }
6859
6860 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6861 }
6862 case scUMaxExpr:
6863 case scSMaxExpr:
6864 case scUMinExpr:
6865 case scSMinExpr:
6866 case scSequentialUMinExpr: {
6867 Intrinsic::ID ID;
6868 switch (S->getSCEVType()) {
6869 case scUMaxExpr:
6870 ID = Intrinsic::umax;
6871 break;
6872 case scSMaxExpr:
6873 ID = Intrinsic::smax;
6874 break;
6875 case scUMinExpr:
6876 case scSequentialUMinExpr:
6877 ID = Intrinsic::umin;
6878 break;
6879 case scSMinExpr:
6880 ID = Intrinsic::smin;
6881 break;
6882 default:
6883 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6884 }
6885
6886 const auto *NAry = cast<SCEVNAryExpr>(S);
6887 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6888 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6889 X = X.intrinsic(
6890 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6891 return setRange(S, SignHint,
6892 ConservativeResult.intersectWith(X, RangeType));
6893 }
6894 case scUnknown: {
6895 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6896 Value *V = U->getValue();
6897
6898 // Check if the IR explicitly contains !range metadata.
6899 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6900 if (MDRange)
6901 ConservativeResult =
6902 ConservativeResult.intersectWith(*MDRange, RangeType);
6903
6904 // Use facts about recurrences in the underlying IR. Note that add
6905 // recurrences are AddRecExprs and thus don't hit this path. This
6906 // primarily handles shift recurrences.
6907 auto CR = getRangeForUnknownRecurrence(U);
6908 ConservativeResult = ConservativeResult.intersectWith(CR);
6909
6910 // See if ValueTracking can give us a useful range.
6911 const DataLayout &DL = getDataLayout();
6912 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6913 if (Known.getBitWidth() != BitWidth)
6914 Known = Known.zextOrTrunc(BitWidth);
6915
6916 // ValueTracking may be able to compute a tighter result for the number of
6917 // sign bits than for the value of those sign bits.
6918 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6919 if (U->getType()->isPointerTy()) {
6920 // If the pointer size is larger than the index size type, this can cause
6921 // NS to be larger than BitWidth. So compensate for this.
6922 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6923 int ptrIdxDiff = ptrSize - BitWidth;
6924 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6925 NS -= ptrIdxDiff;
6926 }
6927
6928 if (NS > 1) {
6929 // If we know any of the sign bits, we know all of the sign bits.
6930 if (!Known.Zero.getHiBits(NS).isZero())
6931 Known.Zero.setHighBits(NS);
6932 if (!Known.One.getHiBits(NS).isZero())
6933 Known.One.setHighBits(NS);
6934 }
6935
6936 if (Known.getMinValue() != Known.getMaxValue() + 1)
6937 ConservativeResult = ConservativeResult.intersectWith(
6938 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6939 RangeType);
6940 if (NS > 1)
6941 ConservativeResult = ConservativeResult.intersectWith(
6942 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6943 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6944 RangeType);
6945
6946 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6947 // Strengthen the range if the underlying IR value is a
6948 // global/alloca/heap allocation using the size of the object.
6949 bool CanBeNull, CanBeFreed;
6950 uint64_t DerefBytes =
6951 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6952 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6953 // The highest address the object can start is DerefBytes bytes before
6954 // the end (unsigned max value). If this value is not a multiple of the
6955 // alignment, the last possible start value is the next lowest multiple
6956 // of the alignment. Note: The computations below cannot overflow,
6957 // because if they would there's no possible start address for the
6958 // object.
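// Worked example (illustrative only): with a 32-bit index type, an object
// that must be dereferenceable for 16 bytes and aligned to 8 can start no
// higher than (2^32 - 1 - 16) rounded down to a multiple of 8, and no lower
// than 8 if the pointer is known non-null.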
6959 APInt MaxVal =
6960 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6961 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6962 uint64_t Rem = MaxVal.urem(Align);
6963 MaxVal -= APInt(BitWidth, Rem);
6964 APInt MinVal = APInt::getZero(BitWidth);
6965 if (llvm::isKnownNonZero(V, DL))
6966 MinVal = Align;
6967 ConservativeResult = ConservativeResult.intersectWith(
6968 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6969 }
6970 }
6971
6972 // The range of a Phi is a subset of the union of the ranges of its inputs.
6973 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6974 // Make sure that we do not run over cycled Phis.
6975 if (PendingPhiRanges.insert(Phi).second) {
6976 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6977
6978 for (const auto &Op : Phi->operands()) {
6979 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6980 RangeFromOps = RangeFromOps.unionWith(OpRange);
6981 // No point to continue if we already have a full set.
6982 if (RangeFromOps.isFullSet())
6983 break;
6984 }
6985 ConservativeResult =
6986 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6987 bool Erased = PendingPhiRanges.erase(Phi);
6988 assert(Erased && "Failed to erase Phi properly?");
6989 (void)Erased;
6990 }
6991 }
6992
6993 // vscale can't be equal to zero
6994 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6995 if (II->getIntrinsicID() == Intrinsic::vscale) {
6996 ConstantRange Disallowed = APInt::getZero(BitWidth);
6997 ConservativeResult = ConservativeResult.difference(Disallowed);
6998 }
6999
7000 return setRange(U, SignHint, std::move(ConservativeResult));
7001 }
7002 case scCouldNotCompute:
7003 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7004 }
7005
7006 return setRange(S, SignHint, std::move(ConservativeResult));
7007}
7008
7009// Given a StartRange, Step and MaxBECount for an expression compute a range of
7010// values that the expression can take. Initially, the expression has a value
7011// from StartRange and then is changed by Step up to MaxBECount times. Signed
7012// argument defines if we treat Step as signed or unsigned.
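// Worked example (illustrative only): StartRange = [0, 10), Step = 3,
// MaxBECount = 5. The value can grow by at most 3 * 5 = 15, so the result
// is [0, 9 + 15 + 1) = [0, 25), provided none of the overflow checks below
// fire.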
7013static ConstantRange getRangeForAffineARHelper(APInt Step,
7014 const ConstantRange &StartRange,
7015 const APInt &MaxBECount,
7016 bool Signed) {
7017 unsigned BitWidth = Step.getBitWidth();
7018 assert(BitWidth == StartRange.getBitWidth() &&
7019 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7020 // If either Step or MaxBECount is 0, then the expression won't change, and we
7021 // just need to return the initial range.
7022 if (Step == 0 || MaxBECount == 0)
7023 return StartRange;
7024
7025 // If we don't know anything about the initial value (i.e. StartRange is
7026 // FullRange), then we don't know anything about the final range either.
7027 // Return FullRange.
7028 if (StartRange.isFullSet())
7029 return ConstantRange::getFull(BitWidth);
7030
7031 // If Step is signed and negative, then we use its absolute value, but we also
7032 // note that we're moving in the opposite direction.
7033 bool Descending = Signed && Step.isNegative();
7034
7035 if (Signed)
7036 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7037 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7038 // This equation holds true due to the well-defined wrap-around behavior of
7039 // APInt.
7040 Step = Step.abs();
7041
7042 // Check if Offset is more than full span of BitWidth. If it is, the
7043 // expression is guaranteed to overflow.
7044 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7045 return ConstantRange::getFull(BitWidth);
7046
7047 // Offset is by how much the expression can change. Checks above guarantee no
7048 // overflow here.
7049 APInt Offset = Step * MaxBECount;
7050
7051 // Minimum value of the final range will match the minimal value of StartRange
7052 // if the expression is increasing and will be decreased by Offset otherwise.
7053 // Maximum value of the final range will match the maximal value of StartRange
7054 // if the expression is decreasing and will be increased by Offset otherwise.
7055 APInt StartLower = StartRange.getLower();
7056 APInt StartUpper = StartRange.getUpper() - 1;
7057 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7058 : (StartUpper + std::move(Offset));
7059
7060 // It's possible that the new minimum/maximum value will fall into the initial
7061 // range (due to wrap around). This means that the expression can take any
7062 // value in this bitwidth, and we have to return full range.
7063 if (StartRange.contains(MovedBoundary))
7064 return ConstantRange::getFull(BitWidth);
7065
7066 APInt NewLower =
7067 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7068 APInt NewUpper =
7069 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7070 NewUpper += 1;
7071
7072 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7073 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7074}
7075
7076ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7077 const SCEV *Step,
7078 const APInt &MaxBECount) {
7079 assert(getTypeSizeInBits(Start->getType()) ==
7080 getTypeSizeInBits(Step->getType()) &&
7081 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7082 "mismatched bit widths");
7083
7084 // First, consider step signed.
7085 ConstantRange StartSRange = getSignedRange(Start);
7086 ConstantRange StepSRange = getSignedRange(Step);
7087
7088 // If Step can be both positive and negative, we need to find ranges for the
7089 // maximum absolute step values in both directions and union them.
7090 ConstantRange SR = getRangeForAffineARHelper(
7091 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7092 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7093 StartSRange, MaxBECount,
7094 /* Signed = */ true));
7095
7096 // Next, consider step unsigned.
7097 ConstantRange UR = getRangeForAffineARHelper(
7098 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7099 /* Signed = */ false);
7100
7101 // Finally, intersect signed and unsigned ranges.
7102 return SR.intersectWith(UR, ConstantRange::Smallest);
7103}
7104
7105ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7106 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7107 ScalarEvolution::RangeSignHint SignHint) {
7108 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7109 assert(AddRec->hasNoSelfWrap() &&
7110 "This only works for non-self-wrapping AddRecs!");
7111 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7112 const SCEV *Step = AddRec->getStepRecurrence(*this);
7113 // Only deal with constant step to save compile time.
7114 if (!isa<SCEVConstant>(Step))
7115 return ConstantRange::getFull(BitWidth);
7116 // Let's make sure that we can prove that we do not self-wrap during
7117 // MaxBECount iterations. We need this because MaxBECount is a maximum
7118 // iteration count estimate, and we might infer nw from some exit for which we
7119 // do not know max exit count (or any other side reasoning).
7120 // TODO: Turn into assert at some point.
7121 if (getTypeSizeInBits(MaxBECount->getType()) >
7122 getTypeSizeInBits(AddRec->getType()))
7123 return ConstantRange::getFull(BitWidth);
7124 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7125 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7126 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7127 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7128 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7129 MaxItersWithoutWrap))
7130 return ConstantRange::getFull(BitWidth);
7131
7132 ICmpInst::Predicate LEPred =
7133 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7134 ICmpInst::Predicate GEPred =
7135 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7136 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7137
7138 // We know that there is no self-wrap. Let's take Start and End values and
7139 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7140 // the iteration. They either lie inside the range [Min(Start, End),
7141 // Max(Start, End)] or outside it:
7142 //
7143 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7144 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7145 //
7146 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7147 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7148 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7149 // Start <= End and step is positive, or Start >= End and step is negative.
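// Worked example (illustrative only): for {10,+,2} with a symbolic max
// backedge-taken count of 10, End evaluates to 30; the step is positive and
// Start <= End, so we are in Case 1 and every intermediate value lies in the
// union of the ranges of Start and End, here [10, 31).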
7150 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7151 ConstantRange StartRange = getRangeRef(Start, SignHint);
7152 ConstantRange EndRange = getRangeRef(End, SignHint);
7153 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7154 // If they already cover full iteration space, we will know nothing useful
7155 // even if we prove what we want to prove.
7156 if (RangeBetween.isFullSet())
7157 return RangeBetween;
7158 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7159 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7160 : RangeBetween.isWrappedSet();
7161 if (IsWrappedSet)
7162 return ConstantRange::getFull(BitWidth);
7163
7164 if (isKnownPositive(Step) &&
7165 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7166 return RangeBetween;
7167 if (isKnownNegative(Step) &&
7168 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7169 return RangeBetween;
7170 return ConstantRange::getFull(BitWidth);
7171}
7172
7173ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7174 const SCEV *Step,
7175 const APInt &MaxBECount) {
7176 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7177 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
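// Worked example (illustrative only): if Start is (%c ? 0 : 100) and Step is
// (%c ? 1 : -1), the addrec is either {0,+,1} or {100,+,-1}, so its range is
// the union of the two affine ranges computed below.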
7178
7179 unsigned BitWidth = MaxBECount.getBitWidth();
7180 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7181 getTypeSizeInBits(Step->getType()) == BitWidth &&
7182 "mismatched bit widths");
7183
7184 struct SelectPattern {
7185 Value *Condition = nullptr;
7186 APInt TrueValue;
7187 APInt FalseValue;
7188
7189 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7190 const SCEV *S) {
7191 std::optional<unsigned> CastOp;
7192 APInt Offset(BitWidth, 0);
7193
7195 "Should be!");
7196
7197 // Peel off a constant offset. In the future we could consider being
7198 // smarter here and handle {Start+Step,+,Step} too.
7199 const APInt *Off;
7200 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7201 Offset = *Off;
7202
7203 // Peel off a cast operation
7204 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7205 CastOp = SCast->getSCEVType();
7206 S = SCast->getOperand();
7207 }
7208
7209 using namespace llvm::PatternMatch;
7210
7211 auto *SU = dyn_cast<SCEVUnknown>(S);
7212 const APInt *TrueVal, *FalseVal;
7213 if (!SU ||
7214 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7215 m_APInt(FalseVal)))) {
7216 Condition = nullptr;
7217 return;
7218 }
7219
7220 TrueValue = *TrueVal;
7221 FalseValue = *FalseVal;
7222
7223 // Re-apply the cast we peeled off earlier
7224 if (CastOp)
7225 switch (*CastOp) {
7226 default:
7227 llvm_unreachable("Unknown SCEV cast type!");
7228
7229 case scTruncate:
7230 TrueValue = TrueValue.trunc(BitWidth);
7231 FalseValue = FalseValue.trunc(BitWidth);
7232 break;
7233 case scZeroExtend:
7234 TrueValue = TrueValue.zext(BitWidth);
7235 FalseValue = FalseValue.zext(BitWidth);
7236 break;
7237 case scSignExtend:
7238 TrueValue = TrueValue.sext(BitWidth);
7239 FalseValue = FalseValue.sext(BitWidth);
7240 break;
7241 }
7242
7243 // Re-apply the constant offset we peeled off earlier
7244 TrueValue += Offset;
7245 FalseValue += Offset;
7246 }
7247
7248 bool isRecognized() { return Condition != nullptr; }
7249 };
7250
7251 SelectPattern StartPattern(*this, BitWidth, Start);
7252 if (!StartPattern.isRecognized())
7253 return ConstantRange::getFull(BitWidth);
7254
7255 SelectPattern StepPattern(*this, BitWidth, Step);
7256 if (!StepPattern.isRecognized())
7257 return ConstantRange::getFull(BitWidth);
7258
7259 if (StartPattern.Condition != StepPattern.Condition) {
7260 // We don't handle this case today; but we could, by considering four
7261 // possibilities below instead of two. I'm not sure if there are cases where
7262 // that will help over what getRange already does, though.
7263 return ConstantRange::getFull(BitWidth);
7264 }
7265
7266 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7267 // construct arbitrary general SCEV expressions here. This function is called
7268 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7269 // say) can end up caching a suboptimal value.
7270
7271 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7272 // C2352 and C2512 (otherwise it isn't needed).
7273
7274 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7275 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7276 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7277 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7278
7279 ConstantRange TrueRange =
7280 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7281 ConstantRange FalseRange =
7282 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7283
7284 return TrueRange.unionWith(FalseRange);
7285}
7286
7287SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7288 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7289 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7290
7291 // Return early if there are no flags to propagate to the SCEV.
7292 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7293 if (BinOp->hasNoUnsignedWrap())
7294 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7295 if (BinOp->hasNoSignedWrap())
7296 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7297 if (Flags == SCEV::FlagAnyWrap)
7298 return SCEV::FlagAnyWrap;
7299
7300 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7301}
7302
7303const Instruction *
7304ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7305 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7306 return &*AddRec->getLoop()->getHeader()->begin();
7307 if (auto *U = dyn_cast<SCEVUnknown>(S))
7308 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7309 return I;
7310 return nullptr;
7311}
7312
7313const Instruction *
7314ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7315 bool &Precise) {
7316 Precise = true;
7317 // Do a bounded search of the def relation of the requested SCEVs.
7318 SmallPtrSet<const SCEV *, 16> Visited;
7319 SmallVector<const SCEV *> Worklist;
7320 auto pushOp = [&](const SCEV *S) {
7321 if (!Visited.insert(S).second)
7322 return;
7323 // Threshold of 30 here is arbitrary.
7324 if (Visited.size() > 30) {
7325 Precise = false;
7326 return;
7327 }
7328 Worklist.push_back(S);
7329 };
7330
7331 for (const auto *S : Ops)
7332 pushOp(S);
7333
7334 const Instruction *Bound = nullptr;
7335 while (!Worklist.empty()) {
7336 auto *S = Worklist.pop_back_val();
7337 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7338 if (!Bound || DT.dominates(Bound, DefI))
7339 Bound = DefI;
7340 } else {
7341 for (const auto *Op : S->operands())
7342 pushOp(Op);
7343 }
7344 }
7345 return Bound ? Bound : &*F.getEntryBlock().begin();
7346}
7347
7348const Instruction *
7349ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7350 bool Discard;
7351 return getDefiningScopeBound(Ops, Discard);
7352}
7353
7354bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7355 const Instruction *B) {
7356 if (A->getParent() == B->getParent() &&
7357 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7358 B->getIterator()))
7359 return true;
7360
7361 auto *BLoop = LI.getLoopFor(B->getParent());
7362 if (BLoop && BLoop->getHeader() == B->getParent() &&
7363 BLoop->getLoopPreheader() == A->getParent() &&
7364 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7365 A->getParent()->end()) &&
7366 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7367 B->getIterator()))
7368 return true;
7369 return false;
7370}
7371
7372bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7373 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7374 visitAll(Op, PC);
7375 return PC.MaybePoison.empty();
7376}
7377
7378bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7379 return !SCEVExprContains(Op, [this](const SCEV *S) {
7380 const SCEV *Op1;
7381 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7382 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7383 // is a non-zero constant, we have to assume the UDiv may be UB.
7384 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7385 });
7386}
7387
7388bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7389 // Only proceed if we can prove that I does not yield poison.
7390 if (!programUndefinedIfPoison(I))
7391 return false;
7392
7393 // At this point we know that if I is executed, then it does not wrap
7394 // according to at least one of NSW or NUW. If I is not executed, then we do
7395 // not know if the calculation that I represents would wrap. Multiple
7396 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7397 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7398 // derived from other instructions that map to the same SCEV. We cannot make
7399 // that guarantee for cases where I is not executed. So we need to find an
7400 // upper bound on the defining scope for the SCEV, and prove that I is
7401 // executed every time we enter that scope. When the bounding scope is a
7402 // loop (the common case), this is equivalent to proving I executes on every
7403 // iteration of that loop.
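// Illustrative scenario (not tied to any specific IR): an "add nsw %a, %b"
// in a guarded block and a flag-free "add %a, %b" elsewhere map to the same
// SCEV, so the nsw fact may only be transferred if the flagged instruction
// is known to execute whenever its defining scope is entered.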
7404 SmallVector<const SCEV *> SCEVOps;
7405 for (const Use &Op : I->operands()) {
7406 // I could be an extractvalue from a call to an overflow intrinsic.
7407 // TODO: We can do better here in some cases.
7408 if (isSCEVable(Op->getType()))
7409 SCEVOps.push_back(getSCEV(Op));
7410 }
7411 auto *DefI = getDefiningScopeBound(SCEVOps);
7412 return isGuaranteedToTransferExecutionTo(DefI, I);
7413}
7414
7415bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7416 // If we know that \c I can never be poison period, then that's enough.
7417 if (isSCEVExprNeverPoison(I))
7418 return true;
7419
7420 // If the loop only has one exit, then we know that, if the loop is entered,
7421 // any instruction dominating that exit will be executed. If any such
7422 // instruction would result in UB, the addrec cannot be poison.
7423 //
7424 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7425 // also handles uses outside the loop header (they just need to dominate the
7426 // single exit).
7427
7428 auto *ExitingBB = L->getExitingBlock();
7429 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7430 return false;
7431
7432 SmallPtrSet<const Value *, 16> KnownPoison;
7433 SmallVector<const Instruction *, 8> Worklist;
7434
7435 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7436 // things that are known to be poison under that assumption go on the
7437 // Worklist.
7438 KnownPoison.insert(I);
7439 Worklist.push_back(I);
7440
7441 while (!Worklist.empty()) {
7442 const Instruction *Poison = Worklist.pop_back_val();
7443
7444 for (const Use &U : Poison->uses()) {
7445 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7446 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7447 DT.dominates(PoisonUser->getParent(), ExitingBB))
7448 return true;
7449
7450 if (propagatesPoison(U) && L->contains(PoisonUser))
7451 if (KnownPoison.insert(PoisonUser).second)
7452 Worklist.push_back(PoisonUser);
7453 }
7454 }
7455
7456 return false;
7457}
7458
7459ScalarEvolution::LoopProperties
7460ScalarEvolution::getLoopProperties(const Loop *L) {
7461 using LoopProperties = ScalarEvolution::LoopProperties;
7462
7463 auto Itr = LoopPropertiesCache.find(L);
7464 if (Itr == LoopPropertiesCache.end()) {
7465 auto HasSideEffects = [](Instruction *I) {
7466 if (auto *SI = dyn_cast<StoreInst>(I))
7467 return !SI->isSimple();
7468
7469 if (I->mayThrow())
7470 return true;
7471
7472 // Non-volatile memset / memcpy do not count as side-effect for forward
7473 // progress.
7474 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7475 return false;
7476
7477 return I->mayWriteToMemory();
7478 };
7479
7480 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7481 /*HasNoSideEffects*/ true};
7482
7483 for (auto *BB : L->getBlocks())
7484 for (auto &I : *BB) {
7485 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7486 LP.HasNoAbnormalExits = false;
7487 if (HasSideEffects(&I))
7488 LP.HasNoSideEffects = false;
7489 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7490 break; // We're already as pessimistic as we can get.
7491 }
7492
7493 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7494 assert(InsertPair.second && "We just checked!");
7495 Itr = InsertPair.first;
7496 }
7497
7498 return Itr->second;
7499}
7500
7501bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7502 // A mustprogress loop without side effects must be finite.
7503 // TODO: The check used here is very conservative. It's only *specific*
7504 // side effects which are well defined in infinite loops.
7505 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7506}
7507
7508const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7509 // Worklist item with a Value and a bool indicating whether all operands have
7510 // been visited already.
7511 using PointerTy = PointerIntPair<Value *, 1, bool>;
7512 SmallVector<PointerTy> Stack;
7513
7514 Stack.emplace_back(V, true);
7515 Stack.emplace_back(V, false);
7516 while (!Stack.empty()) {
7517 auto E = Stack.pop_back_val();
7518 Value *CurV = E.getPointer();
7519
7520 if (getExistingSCEV(CurV))
7521 continue;
7522
7523 SmallVector<Value *> Ops;
7524 const SCEV *CreatedSCEV = nullptr;
7525 // If all operands have been visited already, create the SCEV.
7526 if (E.getInt()) {
7527 CreatedSCEV = createSCEV(CurV);
7528 } else {
7529 // Otherwise get the operands we need to create SCEVs for before creating
7530 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7531 // just use it.
7532 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7533 }
7534
7535 if (CreatedSCEV) {
7536 insertValueToMap(CurV, CreatedSCEV);
7537 } else {
7538 // Queue CurV for SCEV creation, followed by its operands which need to
7539 // be constructed first.
7540 Stack.emplace_back(CurV, true);
7541 for (Value *Op : Ops)
7542 Stack.emplace_back(Op, false);
7543 }
7544 }
7545
7546 return getExistingSCEV(V);
7547}
7548
7549const SCEV *
7550ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7551 if (!isSCEVable(V->getType()))
7552 return getUnknown(V);
7553
7554 if (Instruction *I = dyn_cast<Instruction>(V)) {
7555 // Don't attempt to analyze instructions in blocks that aren't
7556 // reachable. Such instructions don't matter, and they aren't required
7557 // to obey basic rules for definitions dominating uses which this
7558 // analysis depends on.
7559 if (!DT.isReachableFromEntry(I->getParent()))
7560 return getUnknown(PoisonValue::get(V->getType()));
7561 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7562 return getConstant(CI);
7563 else if (isa<GlobalAlias>(V))
7564 return getUnknown(V);
7565 else if (!isa<ConstantExpr>(V))
7566 return getUnknown(V);
7567
7568 Operator *U = cast<Operator>(V);
7569 if (auto BO =
7570 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7571 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7572 switch (BO->Opcode) {
7573 case Instruction::Add:
7574 case Instruction::Mul: {
7575 // For additions and multiplications, traverse add/mul chains for which we
7576 // can potentially create a single SCEV, to reduce the number of
7577 // get{Add,Mul}Expr calls.
7578 do {
7579 if (BO->Op) {
7580 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7581 Ops.push_back(BO->Op);
7582 break;
7583 }
7584 }
7585 Ops.push_back(BO->RHS);
7586 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7587 dyn_cast<Instruction>(V));
7588 if (!NewBO ||
7589 (BO->Opcode == Instruction::Add &&
7590 (NewBO->Opcode != Instruction::Add &&
7591 NewBO->Opcode != Instruction::Sub)) ||
7592 (BO->Opcode == Instruction::Mul &&
7593 NewBO->Opcode != Instruction::Mul)) {
7594 Ops.push_back(BO->LHS);
7595 break;
7596 }
7597 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7598 // requires a SCEV for the LHS.
7599 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7600 auto *I = dyn_cast<Instruction>(BO->Op);
7601 if (I && programUndefinedIfPoison(I)) {
7602 Ops.push_back(BO->LHS);
7603 break;
7604 }
7605 }
7606 BO = NewBO;
7607 } while (true);
7608 return nullptr;
7609 }
7610 case Instruction::Sub:
7611 case Instruction::UDiv:
7612 case Instruction::URem:
7613 break;
7614 case Instruction::AShr:
7615 case Instruction::Shl:
7616 case Instruction::Xor:
7617 if (!IsConstArg)
7618 return nullptr;
7619 break;
7620 case Instruction::And:
7621 case Instruction::Or:
7622 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7623 return nullptr;
7624 break;
7625 case Instruction::LShr:
7626 return getUnknown(V);
7627 default:
7628 llvm_unreachable("Unhandled binop");
7629 break;
7630 }
7631
7632 Ops.push_back(BO->LHS);
7633 Ops.push_back(BO->RHS);
7634 return nullptr;
7635 }
7636
7637 switch (U->getOpcode()) {
7638 case Instruction::Trunc:
7639 case Instruction::ZExt:
7640 case Instruction::SExt:
7641 case Instruction::PtrToInt:
7642 Ops.push_back(U->getOperand(0));
7643 return nullptr;
7644
7645 case Instruction::BitCast:
7646 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7647 Ops.push_back(U->getOperand(0));
7648 return nullptr;
7649 }
7650 return getUnknown(V);
7651
7652 case Instruction::SDiv:
7653 case Instruction::SRem:
7654 Ops.push_back(U->getOperand(0));
7655 Ops.push_back(U->getOperand(1));
7656 return nullptr;
7657
7658 case Instruction::GetElementPtr:
7659 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7660 "GEP source element type must be sized");
7661 llvm::append_range(Ops, U->operands());
7662 return nullptr;
7663
7664 case Instruction::IntToPtr:
7665 return getUnknown(V);
7666
7667 case Instruction::PHI:
7668 // Keep constructing SCEVs for phis recursively for now.
7669 return nullptr;
7670
7671 case Instruction::Select: {
7672 // Check if U is a select that can be simplified to a SCEVUnknown.
7673 auto CanSimplifyToUnknown = [this, U]() {
7674 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7675 return false;
7676
7677 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7678 if (!ICI)
7679 return false;
7680 Value *LHS = ICI->getOperand(0);
7681 Value *RHS = ICI->getOperand(1);
7682 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7683 ICI->getPredicate() == CmpInst::ICMP_NE) {
7684 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7685 return true;
7686 } else if (getTypeSizeInBits(LHS->getType()) >
7687 getTypeSizeInBits(U->getType()))
7688 return true;
7689 return false;
7690 };
7691 if (CanSimplifyToUnknown())
7692 return getUnknown(U);
7693
7694 llvm::append_range(Ops, U->operands());
7695 return nullptr;
7696 break;
7697 }
7698 case Instruction::Call:
7699 case Instruction::Invoke:
7700 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7701 Ops.push_back(RV);
7702 return nullptr;
7703 }
7704
7705 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7706 switch (II->getIntrinsicID()) {
7707 case Intrinsic::abs:
7708 Ops.push_back(II->getArgOperand(0));
7709 return nullptr;
7710 case Intrinsic::umax:
7711 case Intrinsic::umin:
7712 case Intrinsic::smax:
7713 case Intrinsic::smin:
7714 case Intrinsic::usub_sat:
7715 case Intrinsic::uadd_sat:
7716 Ops.push_back(II->getArgOperand(0));
7717 Ops.push_back(II->getArgOperand(1));
7718 return nullptr;
7719 case Intrinsic::start_loop_iterations:
7720 case Intrinsic::annotation:
7721 case Intrinsic::ptr_annotation:
7722 Ops.push_back(II->getArgOperand(0));
7723 return nullptr;
7724 default:
7725 break;
7726 }
7727 }
7728 break;
7729 }
7730
7731 return nullptr;
7732}
7733
7734const SCEV *ScalarEvolution::createSCEV(Value *V) {
7735 if (!isSCEVable(V->getType()))
7736 return getUnknown(V);
7737
7738 if (Instruction *I = dyn_cast<Instruction>(V)) {
7739 // Don't attempt to analyze instructions in blocks that aren't
7740 // reachable. Such instructions don't matter, and they aren't required
7741 // to obey basic rules for definitions dominating uses which this
7742 // analysis depends on.
7743 if (!DT.isReachableFromEntry(I->getParent()))
7744 return getUnknown(PoisonValue::get(V->getType()));
7745 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7746 return getConstant(CI);
7747 else if (isa<GlobalAlias>(V))
7748 return getUnknown(V);
7749 else if (!isa<ConstantExpr>(V))
7750 return getUnknown(V);
7751
7752 const SCEV *LHS;
7753 const SCEV *RHS;
7754
7755 Operator *U = cast<Operator>(V);
7756 if (auto BO =
7757 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7758 switch (BO->Opcode) {
7759 case Instruction::Add: {
7760 // The simple thing to do would be to just call getSCEV on both operands
7761 // and call getAddExpr with the result. However if we're looking at a
7762 // bunch of things all added together, this can be quite inefficient,
7763 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7764 // Instead, gather up all the operands and make a single getAddExpr call.
7765 // LLVM IR canonical form means we need only traverse the left operands.
7766 SmallVector<const SCEV *, 4> AddOps;
7767 do {
7768 if (BO->Op) {
7769 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7770 AddOps.push_back(OpSCEV);
7771 break;
7772 }
7773
7774 // If a NUW or NSW flag can be applied to the SCEV for this
7775 // addition, then compute the SCEV for this addition by itself
7776 // with a separate call to getAddExpr. We need to do that
7777 // instead of pushing the operands of the addition onto AddOps,
7778 // since the flags are only known to apply to this particular
7779 // addition - they may not apply to other additions that can be
7780 // formed with operands from AddOps.
7781 const SCEV *RHS = getSCEV(BO->RHS);
7782 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7783 if (Flags != SCEV::FlagAnyWrap) {
7784 const SCEV *LHS = getSCEV(BO->LHS);
7785 if (BO->Opcode == Instruction::Sub)
7786 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7787 else
7788 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7789 break;
7790 }
7791 }
7792
7793 if (BO->Opcode == Instruction::Sub)
7794 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7795 else
7796 AddOps.push_back(getSCEV(BO->RHS));
7797
7798 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7799 dyn_cast<Instruction>(V));
7800 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7801 NewBO->Opcode != Instruction::Sub)) {
7802 AddOps.push_back(getSCEV(BO->LHS));
7803 break;
7804 }
7805 BO = NewBO;
7806 } while (true);
7807
7808 return getAddExpr(AddOps);
7809 }
7810
7811 case Instruction::Mul: {
7812 SmallVector<const SCEV *, 4> MulOps;
7813 do {
7814 if (BO->Op) {
7815 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7816 MulOps.push_back(OpSCEV);
7817 break;
7818 }
7819
7820 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7821 if (Flags != SCEV::FlagAnyWrap) {
7822 LHS = getSCEV(BO->LHS);
7823 RHS = getSCEV(BO->RHS);
7824 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7825 break;
7826 }
7827 }
7828
7829 MulOps.push_back(getSCEV(BO->RHS));
7830 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7831 dyn_cast<Instruction>(V));
7832 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7833 MulOps.push_back(getSCEV(BO->LHS));
7834 break;
7835 }
7836 BO = NewBO;
7837 } while (true);
7838
7839 return getMulExpr(MulOps);
7840 }
7841 case Instruction::UDiv:
7842 LHS = getSCEV(BO->LHS);
7843 RHS = getSCEV(BO->RHS);
7844 return getUDivExpr(LHS, RHS);
7845 case Instruction::URem:
7846 LHS = getSCEV(BO->LHS);
7847 RHS = getSCEV(BO->RHS);
7848 return getURemExpr(LHS, RHS);
7849 case Instruction::Sub: {
7850 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7851 if (BO->Op)
7852 Flags = getNoWrapFlagsFromUB(BO->Op);
7853 LHS = getSCEV(BO->LHS);
7854 RHS = getSCEV(BO->RHS);
7855 return getMinusSCEV(LHS, RHS, Flags);
7856 }
7857 case Instruction::And:
7858 // For an expression like x&255 that merely masks off the high bits,
7859 // use zext(trunc(x)) as the SCEV expression.
7860 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7861 if (CI->isZero())
7862 return getSCEV(BO->RHS);
7863 if (CI->isMinusOne())
7864 return getSCEV(BO->LHS);
7865 const APInt &A = CI->getValue();
7866
7867 // Instcombine's ShrinkDemandedConstant may strip bits out of
7868 // constants, obscuring what would otherwise be a low-bits mask.
7869 // Use computeKnownBits to compute what ShrinkDemandedConstant
7870 // knew about to reconstruct a low-bits mask value.
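// Worked example (illustrative only): for %x & 0xF0 on i8, LZ = 0 and
// TZ = 4, so the expression is modeled as (zext(trunc(%x /u 16) to i4) to
// i8) * 16, i.e. the middle bits are extracted with a divide and truncate
// and then scaled back up by the power of two that was masked off.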
7871 unsigned LZ = A.countl_zero();
7872 unsigned TZ = A.countr_zero();
7873 unsigned BitWidth = A.getBitWidth();
7874 KnownBits Known(BitWidth);
7875 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7876
7877 APInt EffectiveMask =
7878 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7879 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7880 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7881 const SCEV *LHS = getSCEV(BO->LHS);
7882 const SCEV *ShiftedLHS = nullptr;
7883 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7884 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7885 // For an expression like (x * 8) & 8, simplify the multiply.
7886 unsigned MulZeros = OpC->getAPInt().countr_zero();
7887 unsigned GCD = std::min(MulZeros, TZ);
7888 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7889 SmallVector<const SCEV *, 4> MulOps;
7890 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7891 append_range(MulOps, LHSMul->operands().drop_front());
7892 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7893 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7894 }
7895 }
7896 if (!ShiftedLHS)
7897 ShiftedLHS = getUDivExpr(LHS, MulCount);
7898 return getMulExpr(
7899 getZeroExtendExpr(
7900 getTruncateExpr(ShiftedLHS,
7901 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7902 BO->LHS->getType()),
7903 MulCount);
7904 }
7905 }
7906 // Binary `and` is a bit-wise `umin`.
7907 if (BO->LHS->getType()->isIntegerTy(1)) {
7908 LHS = getSCEV(BO->LHS);
7909 RHS = getSCEV(BO->RHS);
7910 return getUMinExpr(LHS, RHS);
7911 }
7912 break;
7913
7914 case Instruction::Or:
7915 // Binary `or` is a bit-wise `umax`.
7916 if (BO->LHS->getType()->isIntegerTy(1)) {
7917 LHS = getSCEV(BO->LHS);
7918 RHS = getSCEV(BO->RHS);
7919 return getUMaxExpr(LHS, RHS);
7920 }
7921 break;
7922
7923 case Instruction::Xor:
7924 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7925 // If the RHS of xor is -1, then this is a not operation.
7926 if (CI->isMinusOne())
7927 return getNotSCEV(getSCEV(BO->LHS));
7928
7929 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7930 // This is a variant of the check for xor with -1, and it handles
7931 // the case where instcombine has trimmed non-demanded bits out
7932 // of an xor with -1.
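// Worked example (illustrative only): given %y = and i32 %x, 15 followed by
// xor i32 %y, 15, where getSCEV(%y) is zext(trunc %x to i4), the xor flips
// exactly the low four bits of a value whose high bits are zero, so it is
// modeled as zext(not(trunc %x to i4)).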
7933 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7934 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7935 if (LBO->getOpcode() == Instruction::And &&
7936 LCI->getValue() == CI->getValue())
7937 if (const SCEVZeroExtendExpr *Z =
7938 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7939 Type *UTy = BO->LHS->getType();
7940 const SCEV *Z0 = Z->getOperand();
7941 Type *Z0Ty = Z0->getType();
7942 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7943
7944 // If C is a low-bits mask, the zero extend is serving to
7945 // mask off the high bits. Complement the operand and
7946 // re-apply the zext.
7947 if (CI->getValue().isMask(Z0TySize))
7948 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7949
7950 // If C is a single bit, it may be in the sign-bit position
7951 // before the zero-extend. In this case, represent the xor
7952 // using an add, which is equivalent, and re-apply the zext.
7953 APInt Trunc = CI->getValue().trunc(Z0TySize);
7954 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7955 Trunc.isSignMask())
7956 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7957 UTy);
7958 }
7959 }
7960 break;
7961
7962 case Instruction::Shl:
7963 // Turn shift left of a constant amount into a multiply.
7964 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7965 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7966
7967 // If the shift count is not less than the bitwidth, the result of
7968 // the shift is undefined. Don't try to analyze it, because the
7969 // resolution chosen here may differ from the resolution chosen in
7970 // other parts of the compiler.
7971 if (SA->getValue().uge(BitWidth))
7972 break;
7973
7974 // We can safely preserve the nuw flag in all cases. It's also safe to
7975 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7976 // requires special handling. It can be preserved as long as we're not
7977 // left shifting by bitwidth - 1.
7978 auto Flags = SCEV::FlagAnyWrap;
7979 if (BO->Op) {
7980 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7981 if ((MulFlags & SCEV::FlagNSW) &&
7982             ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7983           Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
7984         if (MulFlags & SCEV::FlagNUW)
7985           Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
7986       }
7987
7988 ConstantInt *X = ConstantInt::get(
7989 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7990 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7991 }
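    // For example, (shl i8 %x, 3) becomes (%x * 8). The nuw flag always carries
    // over, but nsw alone cannot survive a shift by bitwidth - 1: e.g.
    // (shl nsw i8 -1, 7) is a well-defined -128, while the equivalent multiply
    // uses the constant 1 << 7 == i8 -128 and (-1 * -128) overflows i8.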
7992 break;
7993
7994 case Instruction::AShr:
7995 // AShr X, C, where C is a constant.
7996 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7997 if (!CI)
7998 break;
7999
8000 Type *OuterTy = BO->LHS->getType();
8001 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8002 // If the shift count is not less than the bitwidth, the result of
8003 // the shift is undefined. Don't try to analyze it, because the
8004 // resolution chosen here may differ from the resolution chosen in
8005 // other parts of the compiler.
8006 if (CI->getValue().uge(BitWidth))
8007 break;
8008
8009 if (CI->isZero())
8010 return getSCEV(BO->LHS); // shift by zero --> noop
8011
8012 uint64_t AShrAmt = CI->getZExtValue();
8013 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8014
8015 Operator *L = dyn_cast<Operator>(BO->LHS);
8016 const SCEV *AddTruncateExpr = nullptr;
8017 ConstantInt *ShlAmtCI = nullptr;
8018 const SCEV *AddConstant = nullptr;
8019
8020 if (L && L->getOpcode() == Instruction::Add) {
8021 // X = Shl A, n
8022 // Y = Add X, c
8023 // Z = AShr Y, m
8024 // n, c and m are constants.
8025
8026 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8027 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8028 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8029 if (AddOperandCI) {
8030 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8031 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8032         // Since we truncate to TruncTy, the AddConstant should have the
8033         // same type, so create a new constant with the same type as TruncTy.
8034         // Also, the add constant should be shifted right by the AShr amount.
8035 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8036 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8037         // We model the expression as sext(add(trunc(A), c << n)); since the
8038         // sext(trunc) part is already handled below, we create an
8039         // AddExpr(TruncExp) which will be used later.
8040 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8041 }
8042 }
8043 } else if (L && L->getOpcode() == Instruction::Shl) {
8044 // X = Shl A, n
8045 // Y = AShr X, m
8046 // Both n and m are constant.
8047
8048 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8049 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8050 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8051 }
8052
8053 if (AddTruncateExpr && ShlAmtCI) {
8054       // We can merge the two given cases into a single SCEV statement;
8055       // in case n = m, the mul expression will be 2^0, so it gets resolved
8056       // to a simpler case. The following code handles the two cases:
8057 //
8058 // 1) For a two-shift sext-inreg, i.e. n = m,
8059 // use sext(trunc(x)) as the SCEV expression.
8060 //
8061       //    2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
8062 // expression. We already checked that ShlAmt < BitWidth, so
8063 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8064 // ShlAmt - AShrAmt < Amt.
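      //
      // For example, (ashr i32 (shl i32 %x, 24), 24) is the classic in-register
      // sign extend and is modeled as (sext i8 (trunc %x to i8) to i32); with
      // n = 4 and m = 2 the result is (sext i30 (4 * (trunc %x to i30)) to i32).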
8065 const APInt &ShlAmt = ShlAmtCI->getValue();
8066 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8067 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8068 ShlAmtCI->getZExtValue() - AShrAmt);
8069 const SCEV *CompositeExpr =
8070 getMulExpr(AddTruncateExpr, getConstant(Mul));
8071 if (L->getOpcode() != Instruction::Shl)
8072 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8073
8074 return getSignExtendExpr(CompositeExpr, OuterTy);
8075 }
8076 }
8077 break;
8078 }
8079 }
8080
8081 switch (U->getOpcode()) {
8082 case Instruction::Trunc:
8083 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8084
8085 case Instruction::ZExt:
8086 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8087
8088 case Instruction::SExt:
8089     if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8090                                 dyn_cast<Instruction>(V))) {
8091 // The NSW flag of a subtract does not always survive the conversion to
8092 // A + (-1)*B. By pushing sign extension onto its operands we are much
8093 // more likely to preserve NSW and allow later AddRec optimisations.
8094 //
8095 // NOTE: This is effectively duplicating this logic from getSignExtend:
8096 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8097 // but by that point the NSW information has potentially been lost.
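      //
      // For example, (sext i32 (sub nsw %a, %b) to i64) becomes
      // (sext %a to i64) - (sext %b to i64) with nsw preserved, which stays
      // recognizable as an add recurrence when %a is an induction variable and
      // %b is loop-invariant.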
8098 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8099 Type *Ty = U->getType();
8100 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8101 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8102 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8103 }
8104 }
8105 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8106
8107 case Instruction::BitCast:
8108 // BitCasts are no-op casts so we just eliminate the cast.
8109 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8110 return getSCEV(U->getOperand(0));
8111 break;
8112
8113 case Instruction::PtrToInt: {
8114 // Pointer to integer cast is straight-forward, so do model it.
8115 const SCEV *Op = getSCEV(U->getOperand(0));
8116 Type *DstIntTy = U->getType();
8117 // But only if effective SCEV (integer) type is wide enough to represent
8118 // all possible pointer values.
8119 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8120 if (isa<SCEVCouldNotCompute>(IntOp))
8121 return getUnknown(V);
8122 return IntOp;
8123 }
8124 case Instruction::IntToPtr:
8125 // Just don't deal with inttoptr casts.
8126 return getUnknown(V);
8127
8128 case Instruction::SDiv:
8129 // If both operands are non-negative, this is just an udiv.
8130 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8131 isKnownNonNegative(getSCEV(U->getOperand(1))))
8132 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8133 break;
8134
8135 case Instruction::SRem:
8136 // If both operands are non-negative, this is just an urem.
8137 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8138 isKnownNonNegative(getSCEV(U->getOperand(1))))
8139 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8140 break;
8141
8142 case Instruction::GetElementPtr:
8143 return createNodeForGEP(cast<GEPOperator>(U));
8144
8145 case Instruction::PHI:
8146 return createNodeForPHI(cast<PHINode>(U));
8147
8148 case Instruction::Select:
8149 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8150 U->getOperand(2));
8151
8152 case Instruction::Call:
8153 case Instruction::Invoke:
8154 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8155 return getSCEV(RV);
8156
8157 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8158 switch (II->getIntrinsicID()) {
8159 case Intrinsic::abs:
8160 return getAbsExpr(
8161 getSCEV(II->getArgOperand(0)),
8162 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8163 case Intrinsic::umax:
8164 LHS = getSCEV(II->getArgOperand(0));
8165 RHS = getSCEV(II->getArgOperand(1));
8166 return getUMaxExpr(LHS, RHS);
8167 case Intrinsic::umin:
8168 LHS = getSCEV(II->getArgOperand(0));
8169 RHS = getSCEV(II->getArgOperand(1));
8170 return getUMinExpr(LHS, RHS);
8171 case Intrinsic::smax:
8172 LHS = getSCEV(II->getArgOperand(0));
8173 RHS = getSCEV(II->getArgOperand(1));
8174 return getSMaxExpr(LHS, RHS);
8175 case Intrinsic::smin:
8176 LHS = getSCEV(II->getArgOperand(0));
8177 RHS = getSCEV(II->getArgOperand(1));
8178 return getSMinExpr(LHS, RHS);
8179 case Intrinsic::usub_sat: {
8180 const SCEV *X = getSCEV(II->getArgOperand(0));
8181 const SCEV *Y = getSCEV(II->getArgOperand(1));
8182 const SCEV *ClampedY = getUMinExpr(X, Y);
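      // If Y <= X this yields X - Y; otherwise ClampedY == X and the result is
      // 0, matching the saturating semantics, so the subtraction cannot wrap.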
8183 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8184 }
8185 case Intrinsic::uadd_sat: {
8186 const SCEV *X = getSCEV(II->getArgOperand(0));
8187 const SCEV *Y = getSCEV(II->getArgOperand(1));
8188 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
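      // If X <= ~Y the sum cannot wrap and ClampedX == X, giving X + Y;
      // otherwise ClampedX == ~Y and (~Y + Y) is the all-ones saturated value.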
8189 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8190 }
8191 case Intrinsic::start_loop_iterations:
8192 case Intrinsic::annotation:
8193 case Intrinsic::ptr_annotation:
8194       // A start_loop_iterations, llvm.annotation, or llvm.ptr.annotation is
8195       // just equivalent to the first operand for SCEV purposes.
8196 return getSCEV(II->getArgOperand(0));
8197 case Intrinsic::vscale:
8198 return getVScale(II->getType());
8199 default:
8200 break;
8201 }
8202 }
8203 break;
8204 }
8205
8206 return getUnknown(V);
8207}
8208
8209//===----------------------------------------------------------------------===//
8210// Iteration Count Computation Code
8211//
8212
8213 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8214   if (isa<SCEVCouldNotCompute>(ExitCount))
8215 return getCouldNotCompute();
8216
8217 auto *ExitCountType = ExitCount->getType();
8218 assert(ExitCountType->isIntegerTy());
8219 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8220 1 + ExitCountType->getScalarSizeInBits());
8221 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8222}
8223
8224 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8225                                                        Type *EvalTy,
8226 const Loop *L) {
8227 if (isa<SCEVCouldNotCompute>(ExitCount))
8228 return getCouldNotCompute();
8229
8230 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8231 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8232
8233 auto CanAddOneWithoutOverflow = [&]() {
8234 ConstantRange ExitCountRange =
8235 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8236 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8237 return true;
8238
8239 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8240 getMinusOne(ExitCount->getType()));
8241 };
8242
8243 // If we need to zero extend the backedge count, check if we can add one to
8244 // it prior to zero extending without overflow. Provided this is safe, it
8245 // allows better simplification of the +1.
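  // For example, an i32 count of 7 evaluated in i64 becomes
  // (zext i32 (7 + 1) to i64) == 8; if the count may be all-ones (UINT32_MAX),
  // the fallback below zero-extends first, so the trip count is the correct
  // 2^32 rather than wrapping to 0 in the narrow type.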
8246 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8247 return getZeroExtendExpr(
8248 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8249
8250 // Get the total trip count from the count by adding 1. This may wrap.
8251 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8252}
8253
8254static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8255 if (!ExitCount)
8256 return 0;
8257
8258 ConstantInt *ExitConst = ExitCount->getValue();
8259
8260 // Guard against huge trip counts.
8261 if (ExitConst->getValue().getActiveBits() > 32)
8262 return 0;
8263
8264 // In case of integer overflow, this returns 0, which is correct.
8265 return ((unsigned)ExitConst->getZExtValue()) + 1;
8266}
8267
8268 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8269   auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8270 return getConstantTripCount(ExitCount);
8271}
8272
8273unsigned
8274 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8275                                            const BasicBlock *ExitingBlock) {
8276 assert(ExitingBlock && "Must pass a non-null exiting block!");
8277 assert(L->isLoopExiting(ExitingBlock) &&
8278 "Exiting block must actually branch out of the loop!");
8279 const SCEVConstant *ExitCount =
8280 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8281 return getConstantTripCount(ExitCount);
8282}
8283
8284 unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8285     const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8286
8287 const auto *MaxExitCount =
8288 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8289                  : getConstantMaxBackedgeTakenCount(L);
8290   return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8291}
8292
8293 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8294   SmallVector<BasicBlock *, 8> ExitingBlocks;
8295 L->getExitingBlocks(ExitingBlocks);
8296
8297 std::optional<unsigned> Res;
8298 for (auto *ExitingBB : ExitingBlocks) {
8299 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8300 if (!Res)
8301 Res = Multiple;
8302 Res = std::gcd(*Res, Multiple);
8303 }
8304 return Res.value_or(1);
8305}
8306
8307 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8308                                                        const SCEV *ExitCount) {
8309 if (isa<SCEVCouldNotCompute>(ExitCount))
8310 return 1;
8311
8312 // Get the trip count
8313 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8314
8315 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8316 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8317 // the greatest power of 2 divisor less than 2^32.
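  // For example, if the backedge-taken count is (3 + 4*n), TCExpr is
  // (4 + 4*n) and the returned multiple is 4.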
8318 return Multiple.getActiveBits() > 32
8319 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8320 : (unsigned)Multiple.getZExtValue();
8321}
8322
8323/// Returns the largest constant divisor of the trip count of this loop as a
8324/// normal unsigned value, if possible. This means that the actual trip count is
8325/// always a multiple of the returned value (don't forget the trip count could
8326/// very well be zero as well!).
8327///
8328 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8329 /// multiple of a constant (which is also the case if the trip count is simply
8330 /// constant; use getSmallConstantTripCount for that case). It will also
8331 /// return 1 if the trip count is very large (>= 2^32).
8332///
8333/// As explained in the comments for getSmallConstantTripCount, this assumes
8334/// that control exits the loop via ExitingBlock.
8335unsigned
8336 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8337                                               const BasicBlock *ExitingBlock) {
8338 assert(ExitingBlock && "Must pass a non-null exiting block!");
8339 assert(L->isLoopExiting(ExitingBlock) &&
8340 "Exiting block must actually branch out of the loop!");
8341 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8342 return getSmallConstantTripMultiple(L, ExitCount);
8343}
8344
8345 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8346                                           const BasicBlock *ExitingBlock,
8347 ExitCountKind Kind) {
8348 switch (Kind) {
8349 case Exact:
8350 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8351 case SymbolicMaximum:
8352 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8353 case ConstantMaximum:
8354 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8355 };
8356 llvm_unreachable("Invalid ExitCountKind!");
8357}
8358
8359 const SCEV *ScalarEvolution::getPredicatedExitCount(
8360     const Loop *L, const BasicBlock *ExitingBlock,
8361     SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8362   switch (Kind) {
8363 case Exact:
8364 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8365 Predicates);
8366 case SymbolicMaximum:
8367 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8368 Predicates);
8369 case ConstantMaximum:
8370 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8371 Predicates);
8372 };
8373 llvm_unreachable("Invalid ExitCountKind!");
8374}
8375
8376 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8377     const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8378   return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8379}
8380
8381 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8382                                                    ExitCountKind Kind) {
8383 switch (Kind) {
8384 case Exact:
8385 return getBackedgeTakenInfo(L).getExact(L, this);
8386 case ConstantMaximum:
8387 return getBackedgeTakenInfo(L).getConstantMax(this);
8388 case SymbolicMaximum:
8389 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8390 };
8391 llvm_unreachable("Invalid ExitCountKind!");
8392}
8393
8394 const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8395     const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8396   return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8397}
8398
8399 const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8400     const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8401   return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8402}
8403
8404 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8405   return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8406}
8407
8408/// Push PHI nodes in the header of the given loop onto the given Worklist.
8409static void PushLoopPHIs(const Loop *L,
8410                          SmallVectorImpl<Instruction *> &Worklist,
8411                          SmallPtrSetImpl<Instruction *> &Visited) {
8412   BasicBlock *Header = L->getHeader();
8413
8414 // Push all Loop-header PHIs onto the Worklist stack.
8415 for (PHINode &PN : Header->phis())
8416 if (Visited.insert(&PN).second)
8417 Worklist.push_back(&PN);
8418}
8419
8420ScalarEvolution::BackedgeTakenInfo &
8421ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8422 auto &BTI = getBackedgeTakenInfo(L);
8423 if (BTI.hasFullInfo())
8424 return BTI;
8425
8426 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8427
8428 if (!Pair.second)
8429 return Pair.first->second;
8430
8431 BackedgeTakenInfo Result =
8432 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8433
8434 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8435}
8436
8437ScalarEvolution::BackedgeTakenInfo &
8438ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8439 // Initially insert an invalid entry for this loop. If the insertion
8440 // succeeds, proceed to actually compute a backedge-taken count and
8441 // update the value. The temporary CouldNotCompute value tells SCEV
8442 // code elsewhere that it shouldn't attempt to request a new
8443 // backedge-taken count, which could result in infinite recursion.
8444 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8445 BackedgeTakenCounts.try_emplace(L);
8446 if (!Pair.second)
8447 return Pair.first->second;
8448
8449 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8450 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8451 // must be cleared in this scope.
8452 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8453
8454 // Now that we know more about the trip count for this loop, forget any
8455 // existing SCEV values for PHI nodes in this loop since they are only
8456 // conservative estimates made without the benefit of trip count
8457 // information. This invalidation is not necessary for correctness, and is
8458 // only done to produce more precise results.
8459 if (Result.hasAnyInfo()) {
8460 // Invalidate any expression using an addrec in this loop.
8461     SmallVector<const SCEV *, 8> ToForget;
8462     auto LoopUsersIt = LoopUsers.find(L);
8463 if (LoopUsersIt != LoopUsers.end())
8464 append_range(ToForget, LoopUsersIt->second);
8465 forgetMemoizedResults(ToForget);
8466
8467 // Invalidate constant-evolved loop header phis.
8468 for (PHINode &PN : L->getHeader()->phis())
8469 ConstantEvolutionLoopExitValue.erase(&PN);
8470 }
8471
8472 // Re-lookup the insert position, since the call to
8473 // computeBackedgeTakenCount above could result in a
8474   // recursive call to getBackedgeTakenInfo (on a different
8475 // loop), which would invalidate the iterator computed
8476 // earlier.
8477 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8478}
8479
8481 // This method is intended to forget all info about loops. It should
8482 // invalidate caches as if the following happened:
8483 // - The trip counts of all loops have changed arbitrarily
8484 // - Every llvm::Value has been updated in place to produce a different
8485 // result.
8486 BackedgeTakenCounts.clear();
8487 PredicatedBackedgeTakenCounts.clear();
8488 BECountUsers.clear();
8489 LoopPropertiesCache.clear();
8490 ConstantEvolutionLoopExitValue.clear();
8491 ValueExprMap.clear();
8492 ValuesAtScopes.clear();
8493 ValuesAtScopesUsers.clear();
8494 LoopDispositions.clear();
8495 BlockDispositions.clear();
8496 UnsignedRanges.clear();
8497 SignedRanges.clear();
8498 ExprValueMap.clear();
8499 HasRecMap.clear();
8500 ConstantMultipleCache.clear();
8501 PredicatedSCEVRewrites.clear();
8502 FoldCache.clear();
8503 FoldCacheUser.clear();
8504}
8505void ScalarEvolution::visitAndClearUsers(
8506     SmallVectorImpl<Instruction *> &Worklist,
8507     SmallPtrSetImpl<Instruction *> &Visited,
8508     SmallVectorImpl<const SCEV *> &ToForget) {
8509   while (!Worklist.empty()) {
8510 Instruction *I = Worklist.pop_back_val();
8511 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8512 continue;
8513
8514     ValueExprMapType::iterator It =
8515         ValueExprMap.find_as(static_cast<Value *>(I));
8516 if (It != ValueExprMap.end()) {
8517 eraseValueFromMap(It->first);
8518 ToForget.push_back(It->second);
8519 if (PHINode *PN = dyn_cast<PHINode>(I))
8520 ConstantEvolutionLoopExitValue.erase(PN);
8521 }
8522
8523 PushDefUseChildren(I, Worklist, Visited);
8524 }
8525}
8526
8527 void ScalarEvolution::forgetLoop(const Loop *L) {
8528   SmallVector<const Loop *, 16> LoopWorklist(1, L);
8529   SmallVector<Instruction *, 32> Worklist;
8530   SmallPtrSet<Instruction *, 16> Visited;
8531   SmallVector<const SCEV *, 16> ToForget;
8532
8533 // Iterate over all the loops and sub-loops to drop SCEV information.
8534 while (!LoopWorklist.empty()) {
8535 auto *CurrL = LoopWorklist.pop_back_val();
8536
8537 // Drop any stored trip count value.
8538 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8539 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8540
8541 // Drop information about predicated SCEV rewrites for this loop.
8542 for (auto I = PredicatedSCEVRewrites.begin();
8543 I != PredicatedSCEVRewrites.end();) {
8544 std::pair<const SCEV *, const Loop *> Entry = I->first;
8545 if (Entry.second == CurrL)
8546 PredicatedSCEVRewrites.erase(I++);
8547 else
8548 ++I;
8549 }
8550
8551 auto LoopUsersItr = LoopUsers.find(CurrL);
8552 if (LoopUsersItr != LoopUsers.end())
8553 llvm::append_range(ToForget, LoopUsersItr->second);
8554
8555 // Drop information about expressions based on loop-header PHIs.
8556 PushLoopPHIs(CurrL, Worklist, Visited);
8557 visitAndClearUsers(Worklist, Visited, ToForget);
8558
8559 LoopPropertiesCache.erase(CurrL);
8560 // Forget all contained loops too, to avoid dangling entries in the
8561 // ValuesAtScopes map.
8562 LoopWorklist.append(CurrL->begin(), CurrL->end());
8563 }
8564 forgetMemoizedResults(ToForget);
8565}
8566
8567 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8568   forgetLoop(L->getOutermostLoop());
8569}
8570
8571 void ScalarEvolution::forgetValue(Value *V) {
8572   Instruction *I = dyn_cast<Instruction>(V);
8573   if (!I) return;
8574
8575 // Drop information about expressions based on loop-header PHIs.
8576   SmallVector<Instruction *, 16> Worklist;
8577   SmallPtrSet<Instruction *, 8> Visited;
8578   SmallVector<const SCEV *, 8> ToForget;
8579   Worklist.push_back(I);
8580 Visited.insert(I);
8581 visitAndClearUsers(Worklist, Visited, ToForget);
8582
8583 forgetMemoizedResults(ToForget);
8584}
8585
8586 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8587   if (!isSCEVable(V->getType()))
8588 return;
8589
8590 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8591 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8592 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8593 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8594 if (const SCEV *S = getExistingSCEV(V)) {
8595 struct InvalidationRootCollector {
8596 Loop *L;
8597       SmallVector<const SCEV *, 8> Roots;
8598
8599 InvalidationRootCollector(Loop *L) : L(L) {}
8600
8601 bool follow(const SCEV *S) {
8602 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8603 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8604 if (L->contains(I))
8605 Roots.push_back(S);
8606 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8607 if (L->contains(AddRec->getLoop()))
8608 Roots.push_back(S);
8609 }
8610 return true;
8611 }
8612 bool isDone() const { return false; }
8613 };
8614
8615 InvalidationRootCollector C(L);
8616 visitAll(S, C);
8617 forgetMemoizedResults(C.Roots);
8618 }
8619
8620 // Also perform the normal invalidation.
8621 forgetValue(V);
8622}
8623
8624void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8625
8626 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8627   // Unless a specific value is passed to invalidation, completely clear both
8628 // caches.
8629 if (!V) {
8630 BlockDispositions.clear();
8631 LoopDispositions.clear();
8632 return;
8633 }
8634
8635 if (!isSCEVable(V->getType()))
8636 return;
8637
8638 const SCEV *S = getExistingSCEV(V);
8639 if (!S)
8640 return;
8641
8642 // Invalidate the block and loop dispositions cached for S. Dispositions of
8643 // S's users may change if S's disposition changes (i.e. a user may change to
8644 // loop-invariant, if S changes to loop invariant), so also invalidate
8645 // dispositions of S's users recursively.
8646 SmallVector<const SCEV *, 8> Worklist = {S};
8647   SmallPtrSet<const SCEV *, 8> Seen;
8648   while (!Worklist.empty()) {
8649 const SCEV *Curr = Worklist.pop_back_val();
8650 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8651 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8652 if (!LoopDispoRemoved && !BlockDispoRemoved)
8653 continue;
8654 auto Users = SCEVUsers.find(Curr);
8655 if (Users != SCEVUsers.end())
8656 for (const auto *User : Users->second)
8657 if (Seen.insert(User).second)
8658 Worklist.push_back(User);
8659 }
8660}
8661
8662/// Get the exact loop backedge taken count considering all loop exits. A
8663/// computable result can only be returned for loops with all exiting blocks
8664/// dominating the latch. howFarToZero assumes that the limit of each loop test
8665/// is never skipped. This is a valid assumption as long as the loop exits via
8666/// that test. For precise results, it is the caller's responsibility to specify
8667/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8668const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8669 const Loop *L, ScalarEvolution *SE,
8671 // If any exits were not computable, the loop is not computable.
8672 if (!isComplete() || ExitNotTaken.empty())
8673 return SE->getCouldNotCompute();
8674
8675 const BasicBlock *Latch = L->getLoopLatch();
8676 // All exiting blocks we have collected must dominate the only backedge.
8677 if (!Latch)
8678 return SE->getCouldNotCompute();
8679
8680 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8681 // count is simply a minimum out of all these calculated exit counts.
8682   SmallVector<const SCEV *, 2> Ops;
8683   for (const auto &ENT : ExitNotTaken) {
8684 const SCEV *BECount = ENT.ExactNotTaken;
8685 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8686 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8687 "We should only have known counts for exiting blocks that dominate "
8688 "latch!");
8689
8690 Ops.push_back(BECount);
8691
8692 if (Preds)
8693 append_range(*Preds, ENT.Predicates);
8694
8695 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8696 "Predicate should be always true!");
8697 }
8698
8699 // If an earlier exit exits on the first iteration (exit count zero), then
8700   // a later poison exit count should not propagate into the result. These are
8701 // exactly the semantics provided by umin_seq.
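  // For example, umin_seq(0, X) is 0 even when X is poison, whereas a plain
  // umin(0, X) would itself be poison in that case.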
8702 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8703}
8704
8705const ScalarEvolution::ExitNotTakenInfo *
8706ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8707 const BasicBlock *ExitingBlock,
8708 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8709 for (const auto &ENT : ExitNotTaken)
8710 if (ENT.ExitingBlock == ExitingBlock) {
8711 if (ENT.hasAlwaysTruePredicate())
8712 return &ENT;
8713 else if (Predicates) {
8714 append_range(*Predicates, ENT.Predicates);
8715 return &ENT;
8716 }
8717 }
8718
8719 return nullptr;
8720}
8721
8722/// getConstantMax - Get the constant max backedge taken count for the loop.
8723const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8724 ScalarEvolution *SE,
8725 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8726 if (!getConstantMax())
8727 return SE->getCouldNotCompute();
8728
8729 for (const auto &ENT : ExitNotTaken)
8730 if (!ENT.hasAlwaysTruePredicate()) {
8731 if (!Predicates)
8732 return SE->getCouldNotCompute();
8733 append_range(*Predicates, ENT.Predicates);
8734 }
8735
8736 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8737 isa<SCEVConstant>(getConstantMax())) &&
8738 "No point in having a non-constant max backedge taken count!");
8739 return getConstantMax();
8740}
8741
8742const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8743 const Loop *L, ScalarEvolution *SE,
8744 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8745 if (!SymbolicMax) {
8746 // Form an expression for the maximum exit count possible for this loop. We
8747 // merge the max and exact information to approximate a version of
8748 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8749 // constants.
8750     SmallVector<const SCEV *, 4> ExitCounts;
8751
8752 for (const auto &ENT : ExitNotTaken) {
8753 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8754 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8755 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8756 "We should only have known counts for exiting blocks that "
8757 "dominate latch!");
8758 ExitCounts.push_back(ExitCount);
8759 if (Predicates)
8760 append_range(*Predicates, ENT.Predicates);
8761
8762 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8763 "Predicate should be always true!");
8764 }
8765 }
8766 if (ExitCounts.empty())
8767 SymbolicMax = SE->getCouldNotCompute();
8768 else
8769 SymbolicMax =
8770 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8771 }
8772 return SymbolicMax;
8773}
8774
8775bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8776 ScalarEvolution *SE) const {
8777 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8778 return !ENT.hasAlwaysTruePredicate();
8779 };
8780 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8781}
8782
8785
8787 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8788 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8792 // If we prove the max count is zero, so is the symbolic bound. This happens
8793 // in practice due to differences in a) how context sensitive we've chosen
8794 // to be and b) how we reason about bounds implied by UB.
8795 if (ConstantMaxNotTaken->isZero()) {
8796 this->ExactNotTaken = E = ConstantMaxNotTaken;
8797 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8798 }
8799
8802 "Exact is not allowed to be less precise than Constant Max");
8805 "Exact is not allowed to be less precise than Symbolic Max");
8808 "Symbolic Max is not allowed to be less precise than Constant Max");
8811 "No point in having a non-constant max backedge taken count!");
8813 for (const auto PredList : PredLists)
8814 for (const auto *P : PredList) {
8815 if (SeenPreds.contains(P))
8816 continue;
8817 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8818 SeenPreds.insert(P);
8819 Predicates.push_back(P);
8820 }
8821 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8822 "Backedge count should be int");
8824 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8825 "Max backedge count should be int");
8826}
8827
8835
8836/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8837/// computable exit into a persistent ExitNotTakenInfo array.
8838ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8840 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8841 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8842 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8843
8844 ExitNotTaken.reserve(ExitCounts.size());
8845 std::transform(ExitCounts.begin(), ExitCounts.end(),
8846 std::back_inserter(ExitNotTaken),
8847 [&](const EdgeExitInfo &EEI) {
8848 BasicBlock *ExitBB = EEI.first;
8849 const ExitLimit &EL = EEI.second;
8850 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8851 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8852 EL.Predicates);
8853 });
8854 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8855 isa<SCEVConstant>(ConstantMax)) &&
8856 "No point in having a non-constant max backedge taken count!");
8857}
8858
8859/// Compute the number of times the backedge of the specified loop will execute.
8860ScalarEvolution::BackedgeTakenInfo
8861ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8862 bool AllowPredicates) {
8863 SmallVector<BasicBlock *, 8> ExitingBlocks;
8864 L->getExitingBlocks(ExitingBlocks);
8865
8866 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8867
8868   SmallVector<EdgeExitInfo, 4> ExitCounts;
8869   bool CouldComputeBECount = true;
8870 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8871 const SCEV *MustExitMaxBECount = nullptr;
8872 const SCEV *MayExitMaxBECount = nullptr;
8873 bool MustExitMaxOrZero = false;
8874 bool IsOnlyExit = ExitingBlocks.size() == 1;
8875
8876 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8877 // and compute maxBECount.
8878 // Do a union of all the predicates here.
8879 for (BasicBlock *ExitBB : ExitingBlocks) {
8880 // We canonicalize untaken exits to br (constant), ignore them so that
8881 // proving an exit untaken doesn't negatively impact our ability to reason
8882     // about the loop as a whole.
8883 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8884 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8885 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8886 if (ExitIfTrue == CI->isZero())
8887 continue;
8888 }
8889
8890 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8891
8892 assert((AllowPredicates || EL.Predicates.empty()) &&
8893 "Predicated exit limit when predicates are not allowed!");
8894
8895 // 1. For each exit that can be computed, add an entry to ExitCounts.
8896 // CouldComputeBECount is true only if all exits can be computed.
8897 if (EL.ExactNotTaken != getCouldNotCompute())
8898 ++NumExitCountsComputed;
8899 else
8900 // We couldn't compute an exact value for this exit, so
8901 // we won't be able to compute an exact value for the loop.
8902 CouldComputeBECount = false;
8903 // Remember exit count if either exact or symbolic is known. Because
8904 // Exact always implies symbolic, only check symbolic.
8905 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8906 ExitCounts.emplace_back(ExitBB, EL);
8907 else {
8908 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8909 "Exact is known but symbolic isn't?");
8910 ++NumExitCountsNotComputed;
8911 }
8912
8913 // 2. Derive the loop's MaxBECount from each exit's max number of
8914 // non-exiting iterations. Partition the loop exits into two kinds:
8915 // LoopMustExits and LoopMayExits.
8916 //
8917 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8918 // is a LoopMayExit. If any computable LoopMustExit is found, then
8919 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8920 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8921 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8922 // any
8923 // computable EL.ConstantMaxNotTaken.
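    //
    // For example, if two exits both dominate the latch with constant maxima
    // of 50 and 100, the loop's constant max backedge-taken count is
    // umin(50, 100) = 50; if only non-dominating exits are computable, the
    // conservative bound is the umax of their maxima instead.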
8924 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8925 DT.dominates(ExitBB, Latch)) {
8926 if (!MustExitMaxBECount) {
8927 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8928 MustExitMaxOrZero = EL.MaxOrZero;
8929 } else {
8930 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8931 EL.ConstantMaxNotTaken);
8932 }
8933 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8934 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8935 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8936 else {
8937 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8938 EL.ConstantMaxNotTaken);
8939 }
8940 }
8941 }
8942 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8943 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8944 // The loop backedge will be taken the maximum or zero times if there's
8945 // a single exit that must be taken the maximum or zero times.
8946 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8947
8948 // Remember which SCEVs are used in exit limits for invalidation purposes.
8949 // We only care about non-constant SCEVs here, so we can ignore
8950 // EL.ConstantMaxNotTaken
8951 // and MaxBECount, which must be SCEVConstant.
8952 for (const auto &Pair : ExitCounts) {
8953 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8954 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8955 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8956 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8957 {L, AllowPredicates});
8958 }
8959 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8960 MaxBECount, MaxOrZero);
8961}
8962
8963ScalarEvolution::ExitLimit
8964ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8965 bool IsOnlyExit, bool AllowPredicates) {
8966 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8967 // If our exiting block does not dominate the latch, then its connection with
8968 // loop's exit limit may be far from trivial.
8969   // the loop's exit limit may be far from trivial.
8970 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8971 return getCouldNotCompute();
8972
8973 Instruction *Term = ExitingBlock->getTerminator();
8974 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8975 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8976 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8977 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8978 "It should have one successor in loop and one exit block!");
8979 // Proceed to the next level to examine the exit condition expression.
8980 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8981 /*ControlsOnlyExit=*/IsOnlyExit,
8982 AllowPredicates);
8983 }
8984
8985 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8986 // For switch, make sure that there is a single exit from the loop.
8987 BasicBlock *Exit = nullptr;
8988 for (auto *SBB : successors(ExitingBlock))
8989 if (!L->contains(SBB)) {
8990 if (Exit) // Multiple exit successors.
8991 return getCouldNotCompute();
8992 Exit = SBB;
8993 }
8994 assert(Exit && "Exiting block must have at least one exit");
8995 return computeExitLimitFromSingleExitSwitch(
8996 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8997 }
8998
8999 return getCouldNotCompute();
9000}
9001
9002 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
9003     const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9004 bool AllowPredicates) {
9005 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9006 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9007 ControlsOnlyExit, AllowPredicates);
9008}
9009
9010std::optional<ScalarEvolution::ExitLimit>
9011ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9012 bool ExitIfTrue, bool ControlsOnlyExit,
9013 bool AllowPredicates) {
9014 (void)this->L;
9015 (void)this->ExitIfTrue;
9016 (void)this->AllowPredicates;
9017
9018 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9019 this->AllowPredicates == AllowPredicates &&
9020 "Variance in assumed invariant key components!");
9021 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9022 if (Itr == TripCountMap.end())
9023 return std::nullopt;
9024 return Itr->second;
9025}
9026
9027void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9028 bool ExitIfTrue,
9029 bool ControlsOnlyExit,
9030 bool AllowPredicates,
9031 const ExitLimit &EL) {
9032 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9033 this->AllowPredicates == AllowPredicates &&
9034 "Variance in assumed invariant key components!");
9035
9036 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9037 assert(InsertResult.second && "Expected successful insertion!");
9038 (void)InsertResult;
9039 (void)ExitIfTrue;
9040}
9041
9042ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9043 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9044 bool ControlsOnlyExit, bool AllowPredicates) {
9045
9046 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9047 AllowPredicates))
9048 return *MaybeEL;
9049
9050 ExitLimit EL = computeExitLimitFromCondImpl(
9051 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9052 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9053 return EL;
9054}
9055
9056ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9057 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9058 bool ControlsOnlyExit, bool AllowPredicates) {
9059 // Handle BinOp conditions (And, Or).
9060 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9061 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9062 return *LimitFromBinOp;
9063
9064 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9065 // Proceed to the next level to examine the icmp.
9066 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9067 ExitLimit EL =
9068 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9069 if (EL.hasFullInfo() || !AllowPredicates)
9070 return EL;
9071
9072 // Try again, but use SCEV predicates this time.
9073 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9074 ControlsOnlyExit,
9075 /*AllowPredicates=*/true);
9076 }
9077
9078 // Check for a constant condition. These are normally stripped out by
9079 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9080 // preserve the CFG and is temporarily leaving constant conditions
9081 // in place.
9082 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9083 if (ExitIfTrue == !CI->getZExtValue())
9084 // The backedge is always taken.
9085 return getCouldNotCompute();
9086 // The backedge is never taken.
9087 return getZero(CI->getType());
9088 }
9089
9090 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9091 // with a constant step, we can form an equivalent icmp predicate and figure
9092 // out how many iterations will be taken before we exit.
9093 const WithOverflowInst *WO;
9094 const APInt *C;
9095 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9096 match(WO->getRHS(), m_APInt(C))) {
9097 ConstantRange NWR =
9098         ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9099                                              WO->getNoWrapKind());
9100 CmpInst::Predicate Pred;
9101 APInt NewRHSC, Offset;
9102 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9103 if (!ExitIfTrue)
9104 Pred = ICmpInst::getInversePredicate(Pred);
9105 auto *LHS = getSCEV(WO->getLHS());
9106 if (Offset != 0)
9107       LHS = getAddExpr(LHS, getConstant(Offset));
9108     auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9109 ControlsOnlyExit, AllowPredicates);
9110 if (EL.hasAnyInfo())
9111 return EL;
9112 }
9113
9114 // If it's not an integer or pointer comparison then compute it the hard way.
9115 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9116}
9117
9118std::optional<ScalarEvolution::ExitLimit>
9119ScalarEvolution::computeExitLimitFromCondFromBinOp(
9120 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9121 bool ControlsOnlyExit, bool AllowPredicates) {
9122 // Check if the controlling expression for this loop is an And or Or.
9123 Value *Op0, *Op1;
9124 bool IsAnd = false;
9125 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9126 IsAnd = true;
9127 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9128 IsAnd = false;
9129 else
9130 return std::nullopt;
9131
9132 // EitherMayExit is true in these two cases:
9133 // br (and Op0 Op1), loop, exit
9134 // br (or Op0 Op1), exit, loop
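  //
  // For example, if Op0 would allow only 10 more iterations and Op1 would
  // allow 20, either test can terminate the loop, so the backedge is taken
  // umin(10, 20) = 10 times.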
9135 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9136 ExitLimit EL0 = computeExitLimitFromCondCached(
9137 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9138 AllowPredicates);
9139 ExitLimit EL1 = computeExitLimitFromCondCached(
9140 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9141 AllowPredicates);
9142
9143 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9144 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9145 if (isa<ConstantInt>(Op1))
9146 return Op1 == NeutralElement ? EL0 : EL1;
9147 if (isa<ConstantInt>(Op0))
9148 return Op0 == NeutralElement ? EL1 : EL0;
9149
9150 const SCEV *BECount = getCouldNotCompute();
9151 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9152 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9153 if (EitherMayExit) {
9154 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9155     // Both conditions must hold for the loop to continue executing.
9156 // Choose the less conservative count.
9157 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9158 EL1.ExactNotTaken != getCouldNotCompute()) {
9159 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9160 UseSequentialUMin);
9161 }
9162 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9163 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9164 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9165 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9166 else
9167 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9168 EL1.ConstantMaxNotTaken);
9169 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9170 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9171 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9172 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9173 else
9174 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9175 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9176 } else {
9177     // Both conditions must trigger the exit at the same time for the loop to exit.
9178 // For now, be conservative.
9179 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9180 BECount = EL0.ExactNotTaken;
9181 }
9182
9183 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9184 // to be more aggressive when computing BECount than when computing
9185 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9186 // and
9187 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9188 // EL1.ConstantMaxNotTaken to not.
9189 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9190 !isa<SCEVCouldNotCompute>(BECount))
9191 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9192 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9193 SymbolicMaxBECount =
9194 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9195 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9196 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9197}
9198
9199ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9200 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9201 bool AllowPredicates) {
9202 // If the condition was exit on true, convert the condition to exit on false
9203 CmpPredicate Pred;
9204 if (!ExitIfTrue)
9205 Pred = ExitCond->getCmpPredicate();
9206 else
9207 Pred = ExitCond->getInverseCmpPredicate();
9208 const ICmpInst::Predicate OriginalPred = Pred;
9209
9210 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9211 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9212
9213 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9214 AllowPredicates);
9215 if (EL.hasAnyInfo())
9216 return EL;
9217
9218 auto *ExhaustiveCount =
9219 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9220
9221 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9222 return ExhaustiveCount;
9223
9224 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9225 ExitCond->getOperand(1), L, OriginalPred);
9226}
9227ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9228 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9229 bool ControlsOnlyExit, bool AllowPredicates) {
9230
9231 // Try to evaluate any dependencies out of the loop.
9232 LHS = getSCEVAtScope(LHS, L);
9233 RHS = getSCEVAtScope(RHS, L);
9234
9235 // At this point, we would like to compute how many iterations of the
9236 // loop the predicate will return true for these inputs.
9237 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9238 // If there is a loop-invariant, force it into the RHS.
9239 std::swap(LHS, RHS);
9241 }
9242
9243 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9244                                loopIsFiniteByAssumption(L);
9245   // Simplify the operands before analyzing them.
9246 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9247
9248 // If we have a comparison of a chrec against a constant, try to use value
9249 // ranges to answer this query.
9250 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9251 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9252 if (AddRec->getLoop() == L) {
9253 // Form the constant range.
9254 ConstantRange CompRange =
9255 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9256
9257 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9258 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9259 }
9260
9261 // If this loop must exit based on this condition (or execute undefined
9262 // behaviour), see if we can improve wrap flags. This is essentially
9263 // a must execute style proof.
9264 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9265 // If we can prove the test sequence produced must repeat the same values
9266 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9267 // because if it did, we'd have an infinite (undefined) loop.
9268 // TODO: We can peel off any functions which are invertible *in L*. Loop
9269 // invariant terms are effectively constants for our purposes here.
9270 auto *InnerLHS = LHS;
9271 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9272 InnerLHS = ZExt->getOperand();
9273 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9274 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9275 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9276 /*OrNegative=*/true)) {
9277 auto Flags = AR->getNoWrapFlags();
9278 Flags = setFlags(Flags, SCEV::FlagNW);
9281 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9282 }
9283
9284 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9285 // From no-self-wrap, this follows trivially from the fact that every
9286 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9287 // last value before (un)signed wrap. Since we know that last value
9288 // didn't exit, nor will any smaller one.
9289 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9290 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9291 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9292 AR && AR->getLoop() == L && AR->isAffine() &&
9293 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9294 isKnownPositive(AR->getStepRecurrence(*this))) {
9295 auto Flags = AR->getNoWrapFlags();
9296 Flags = setFlags(Flags, WrapType);
9299 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9300 }
9301 }
9302 }
9303
9304 switch (Pred) {
9305 case ICmpInst::ICMP_NE: { // while (X != Y)
9306 // Convert to: while (X-Y != 0)
9307 if (LHS->getType()->isPointerTy()) {
9308       LHS = getLosslessPtrToIntExpr(LHS);
9309       if (isa<SCEVCouldNotCompute>(LHS))
9310         return LHS;
9311 }
9312 if (RHS->getType()->isPointerTy()) {
9313       RHS = getLosslessPtrToIntExpr(RHS);
9314       if (isa<SCEVCouldNotCompute>(RHS))
9315         return RHS;
9316 }
9317 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9318 AllowPredicates);
9319 if (EL.hasAnyInfo())
9320 return EL;
9321 break;
9322 }
9323 case ICmpInst::ICMP_EQ: { // while (X == Y)
9324 // Convert to: while (X-Y == 0)
9325 if (LHS->getType()->isPointerTy()) {
9326       LHS = getLosslessPtrToIntExpr(LHS);
9327       if (isa<SCEVCouldNotCompute>(LHS))
9328         return LHS;
9329 }
9330 if (RHS->getType()->isPointerTy()) {
9331       RHS = getLosslessPtrToIntExpr(RHS);
9332       if (isa<SCEVCouldNotCompute>(RHS))
9333         return RHS;
9334 }
9335 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9336 if (EL.hasAnyInfo()) return EL;
9337 break;
9338 }
9339 case ICmpInst::ICMP_SLE:
9340 case ICmpInst::ICMP_ULE:
9341 // Since the loop is finite, an invariant RHS cannot include the boundary
9342 // value, otherwise it would loop forever.
9343 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9344 !isLoopInvariant(RHS, L)) {
9345 // Otherwise, perform the addition in a wider type, to avoid overflow.
9346 // If the LHS is an addrec with the appropriate nowrap flag, the
9347 // extension will be sunk into it and the exit count can be analyzed.
9348 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9349 if (!OldType)
9350 break;
9351 // Prefer doubling the bitwidth over adding a single bit to make it more
9352 // likely that we use a legal type.
9353 auto *NewType =
9354 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9355 if (ICmpInst::isSigned(Pred)) {
9356 LHS = getSignExtendExpr(LHS, NewType);
9357 RHS = getSignExtendExpr(RHS, NewType);
9358 } else {
9359 LHS = getZeroExtendExpr(LHS, NewType);
9360 RHS = getZeroExtendExpr(RHS, NewType);
9361 }
9362 }
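    // For example, an i32 (X ule Y) test is evaluated over i64 operands, so
    // that treating it as (X ult Y + 1) cannot wrap even when Y is UINT32_MAX.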
9363     RHS = getAddExpr(getOne(RHS->getType()), RHS);
9364     [[fallthrough]];
9365 case ICmpInst::ICMP_SLT:
9366 case ICmpInst::ICMP_ULT: { // while (X < Y)
9367 bool IsSigned = ICmpInst::isSigned(Pred);
9368 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9369 AllowPredicates);
9370 if (EL.hasAnyInfo())
9371 return EL;
9372 break;
9373 }
9374 case ICmpInst::ICMP_SGE:
9375 case ICmpInst::ICMP_UGE:
9376 // Since the loop is finite, an invariant RHS cannot include the boundary
9377 // value, otherwise it would loop forever.
9378 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9379 !isLoopInvariant(RHS, L))
9380 break;
9381     RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9382     [[fallthrough]];
9383 case ICmpInst::ICMP_SGT:
9384 case ICmpInst::ICMP_UGT: { // while (X > Y)
9385 bool IsSigned = ICmpInst::isSigned(Pred);
9386 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9387 AllowPredicates);
9388 if (EL.hasAnyInfo())
9389 return EL;
9390 break;
9391 }
9392 default:
9393 break;
9394 }
9395
9396 return getCouldNotCompute();
9397}
9398
9399ScalarEvolution::ExitLimit
9400ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9401 SwitchInst *Switch,
9402 BasicBlock *ExitingBlock,
9403 bool ControlsOnlyExit) {
9404 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9405
9406 // Give up if the exit is the default dest of a switch.
9407 if (Switch->getDefaultDest() == ExitingBlock)
9408 return getCouldNotCompute();
9409
9410 assert(L->contains(Switch->getDefaultDest()) &&
9411 "Default case must not exit the loop!");
9412 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9413 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9414
9415 // while (X != Y) --> while (X-Y != 0)
9416 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9417 if (EL.hasAnyInfo())
9418 return EL;
9419
9420 return getCouldNotCompute();
9421}
9422
9423static ConstantInt *
9425 ScalarEvolution &SE) {
9426 const SCEV *InVal = SE.getConstant(C);
9427 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9428   assert(isa<SCEVConstant>(Val) &&
9429          "Evaluation of SCEV at constant didn't fold correctly?");
9430 return cast<SCEVConstant>(Val)->getValue();
9431}
9432
9433ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9434 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9435 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9436 if (!RHS)
9437 return getCouldNotCompute();
9438
9439 const BasicBlock *Latch = L->getLoopLatch();
9440 if (!Latch)
9441 return getCouldNotCompute();
9442
9443 const BasicBlock *Predecessor = L->getLoopPredecessor();
9444 if (!Predecessor)
9445 return getCouldNotCompute();
9446
9447 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9448   // Return LHS in OutLHS and shift_op in OutOpCode.
9449 auto MatchPositiveShift =
9450 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9451
9452 using namespace PatternMatch;
9453
9454 ConstantInt *ShiftAmt;
9455 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9456 OutOpCode = Instruction::LShr;
9457 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9458 OutOpCode = Instruction::AShr;
9459 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9460 OutOpCode = Instruction::Shl;
9461 else
9462 return false;
9463
9464 return ShiftAmt->getValue().isStrictlyPositive();
9465 };
9466
9467 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9468 //
9469 // loop:
9470 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9471 // %iv.shifted = lshr i32 %iv, <positive constant>
9472 //
9473 // Return true on a successful match. Return the corresponding PHI node (%iv
9474 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9475 auto MatchShiftRecurrence =
9476 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9477 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9478
9479 {
9480       Instruction::BinaryOps OpC;
9481       Value *V;
9482
9483 // If we encounter a shift instruction, "peel off" the shift operation,
9484 // and remember that we did so. Later when we inspect %iv's backedge
9485 // value, we will make sure that the backedge value uses the same
9486 // operation.
9487 //
9488 // Note: the peeled shift operation does not have to be the same
9489 // instruction as the one feeding into the PHI's backedge value. We only
9490 // really care about it being the same *kind* of shift instruction --
9491 // that's all that is required for our later inferences to hold.
9492 if (MatchPositiveShift(LHS, V, OpC)) {
9493 PostShiftOpCode = OpC;
9494 LHS = V;
9495 }
9496 }
9497
9498 PNOut = dyn_cast<PHINode>(LHS);
9499 if (!PNOut || PNOut->getParent() != L->getHeader())
9500 return false;
9501
9502 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9503 Value *OpLHS;
9504
9505 return
9506 // The backedge value for the PHI node must be a shift by a positive
9507 // amount
9508 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9509
9510 // of the PHI node itself
9511 OpLHS == PNOut &&
9512
9513         // and the kind of shift should match the kind of shift we peeled
9514 // off, if any.
9515 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9516 };
9517
9518   PHINode *PN;
9519   Instruction::BinaryOps OpCode;
9520 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9521 return getCouldNotCompute();
9522
9523 const DataLayout &DL = getDataLayout();
9524
9525 // The key rationale for this optimization is that for some kinds of shift
9526 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9527 // within a finite number of iterations. If the condition guarding the
9528 // backedge (in the sense that the backedge is taken if the condition is true)
9529 // is false for the value the shift recurrence stabilizes to, then we know
9530 // that the backedge is taken only a finite number of times.
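  // As a small worked example of the stabilization argument: with i8 values,
  // {100,ashr,1} evolves as 100, 50, 25, 12, 6, 3, 1, 0, ... and
  // {-100,ashr,1} as -100, -50, -25, -13, -7, -4, -2, -1, -1, ...; both reach
  // their stable value within bitwidth(i8) = 8 iterations.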
9531
9532 ConstantInt *StableValue = nullptr;
9533 switch (OpCode) {
9534 default:
9535 llvm_unreachable("Impossible case!");
9536
9537 case Instruction::AShr: {
9538 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9539 // bitwidth(K) iterations.
9540 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9541 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9542 Predecessor->getTerminator(), &DT);
9543 auto *Ty = cast<IntegerType>(RHS->getType());
9544 if (Known.isNonNegative())
9545 StableValue = ConstantInt::get(Ty, 0);
9546 else if (Known.isNegative())
9547 StableValue = ConstantInt::get(Ty, -1, true);
9548 else
9549 return getCouldNotCompute();
9550
9551 break;
9552 }
9553 case Instruction::LShr:
9554 case Instruction::Shl:
9555 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9556 // stabilize to 0 in at most bitwidth(K) iterations.
9557 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9558 break;
9559 }
9560
9561 auto *Result =
9562 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9563 assert(Result->getType()->isIntegerTy(1) &&
9564 "Otherwise cannot be an operand to a branch instruction");
9565
9566 if (Result->isZeroValue()) {
9567 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9568 const SCEV *UpperBound =
9569         getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9570     return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9571 }
9572
9573 return getCouldNotCompute();
9574}
9575
9576/// Return true if we can constant fold an instruction of the specified type,
9577/// assuming that all operands were constants.
9578static bool CanConstantFold(const Instruction *I) {
9579   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9580       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9581       isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9582     return true;
9583
9584 if (const CallInst *CI = dyn_cast<CallInst>(I))
9585 if (const Function *F = CI->getCalledFunction())
9586 return canConstantFoldCallTo(CI, F);
9587 return false;
9588}
9589
9590/// Determine whether this instruction can constant evolve within this loop
9591/// assuming its operands can all constant evolve.
9592static bool canConstantEvolve(Instruction *I, const Loop *L) {
9593 // An instruction outside of the loop can't be derived from a loop PHI.
9594 if (!L->contains(I)) return false;
9595
9596 if (isa<PHINode>(I)) {
9597 // We don't currently keep track of the control flow needed to evaluate
9598 // PHIs, so we cannot handle PHIs inside of loops.
9599 return L->getHeader() == I->getParent();
9600 }
9601
9602 // If we won't be able to constant fold this expression even if the operands
9603 // are constants, bail early.
9604 return CanConstantFold(I);
9605}
9606
9607/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9608/// recursing through each instruction operand until reaching a loop header phi.
9609static PHINode *
9610 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9611                                DenseMap<Instruction *, PHINode *> &PHIMap,
9612                                unsigned Depth) {
9613   if (Depth > MaxConstantEvolvingDepth)
9614     return nullptr;
9615
9616 // Otherwise, we can evaluate this instruction if all of its operands are
9617 // constant or derived from a PHI node themselves.
9618 PHINode *PHI = nullptr;
9619 for (Value *Op : UseInst->operands()) {
9620 if (isa<Constant>(Op)) continue;
9621
9622     Instruction *OpInst = dyn_cast<Instruction>(Op);
9623     if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9624
9625 PHINode *P = dyn_cast<PHINode>(OpInst);
9626 if (!P)
9627 // If this operand is already visited, reuse the prior result.
9628 // We may have P != PHI if this is the deepest point at which the
9629 // inconsistent paths meet.
9630 P = PHIMap.lookup(OpInst);
9631 if (!P) {
9632 // Recurse and memoize the results, whether a phi is found or not.
9633 // This recursive call invalidates pointers into PHIMap.
9634 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9635 PHIMap[OpInst] = P;
9636 }
9637 if (!P)
9638 return nullptr; // Not evolving from PHI
9639 if (PHI && PHI != P)
9640 return nullptr; // Evolving from multiple different PHIs.
9641 PHI = P;
9642 }
9643   // This is an expression evolving from a constant PHI!
9644 return PHI;
9645}
9646
9647/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9648/// in the loop that V is derived from. We allow arbitrary operations along the
9649/// way, but the operands of an operation must either be constants or a value
9650/// derived from a constant PHI. If this expression does not fit with these
9651/// constraints, return null.
9652 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9653   Instruction *I = dyn_cast<Instruction>(V);
9654   if (!I || !canConstantEvolve(I, L)) return nullptr;
9655
9656 if (PHINode *PN = dyn_cast<PHINode>(I))
9657 return PN;
9658
9659 // Record non-constant instructions contained by the loop.
9660   DenseMap<Instruction *, PHINode *> PHIMap;
9661   return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9662}
9663
9664/// EvaluateExpression - Given an expression that passes the
9665/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9666/// in the loop has the value PHIVal. If we can't fold this expression for some
9667/// reason, return null.
9668 static Constant *EvaluateExpression(Value *V, const Loop *L,
9669                                     DenseMap<Instruction *, Constant *> &Vals,
9670                                     const DataLayout &DL,
9671 const TargetLibraryInfo *TLI) {
9672 // Convenient constant check, but redundant for recursive calls.
9673 if (Constant *C = dyn_cast<Constant>(V)) return C;
9674   Instruction *I = dyn_cast<Instruction>(V);
9675   if (!I) return nullptr;
9676
9677 if (Constant *C = Vals.lookup(I)) return C;
9678
9679 // An instruction inside the loop depends on a value outside the loop that we
9680 // weren't given a mapping for, or a value such as a call inside the loop.
9681 if (!canConstantEvolve(I, L)) return nullptr;
9682
9683 // An unmapped PHI can be due to a branch or another loop inside this loop,
9684 // or due to this not being the initial iteration through a loop where we
9685 // couldn't compute the evolution of this particular PHI last time.
9686 if (isa<PHINode>(I)) return nullptr;
9687
9688 std::vector<Constant*> Operands(I->getNumOperands());
9689
9690 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9691 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9692 if (!Operand) {
9693 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9694 if (!Operands[i]) return nullptr;
9695 continue;
9696 }
9697 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9698 Vals[Operand] = C;
9699 if (!C) return nullptr;
9700 Operands[i] = C;
9701 }
9702
9703 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9704 /*AllowNonDeterministic=*/false);
9705}
9706
9707
9708// If every incoming value to PN except the one for BB is a specific Constant,
9709// return that, else return nullptr.
9710 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9711   Constant *IncomingVal = nullptr;
9712
9713 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9714 if (PN->getIncomingBlock(i) == BB)
9715 continue;
9716
9717 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9718 if (!CurrentVal)
9719 return nullptr;
9720
9721 if (IncomingVal != CurrentVal) {
9722 if (IncomingVal)
9723 return nullptr;
9724 IncomingVal = CurrentVal;
9725 }
9726 }
9727
9728 return IncomingVal;
9729}
9730
9731/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9732/// in the header of its containing loop, we know the loop executes a
9733/// constant number of times, and the PHI node is just a recurrence
9734/// involving constants, fold it.
9735Constant *
9736ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9737 const APInt &BEs,
9738 const Loop *L) {
9739 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9740 if (!Inserted)
9741 return I->second;
9742
9743   if (BEs.ugt(MaxBruteForceIterations))
9744     return nullptr; // Not going to evaluate it.
9745
9746 Constant *&RetVal = I->second;
9747
9748 DenseMap<Instruction *, Constant *> CurrentIterVals;
9749 BasicBlock *Header = L->getHeader();
9750 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9751
9752 BasicBlock *Latch = L->getLoopLatch();
9753 if (!Latch)
9754 return nullptr;
9755
9756 for (PHINode &PHI : Header->phis()) {
9757 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9758 CurrentIterVals[&PHI] = StartCST;
9759 }
9760 if (!CurrentIterVals.count(PN))
9761 return RetVal = nullptr;
9762
9763 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9764
9765 // Execute the loop symbolically to determine the exit value.
9766 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9767 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9768
9769 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9770 unsigned IterationNum = 0;
9771 const DataLayout &DL = getDataLayout();
9772 for (; ; ++IterationNum) {
9773 if (IterationNum == NumIterations)
9774 return RetVal = CurrentIterVals[PN]; // Got exit value!
9775
9776 // Compute the value of the PHIs for the next iteration.
9777 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9778 DenseMap<Instruction *, Constant *> NextIterVals;
9779 Constant *NextPHI =
9780 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9781 if (!NextPHI)
9782 return nullptr; // Couldn't evaluate!
9783 NextIterVals[PN] = NextPHI;
9784
9785 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9786
9787 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9788 // cease to be able to evaluate one of them or if they stop evolving,
9789 // because that doesn't necessarily prevent us from computing PN.
9790     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9791     for (const auto &I : CurrentIterVals) {
9792 PHINode *PHI = dyn_cast<PHINode>(I.first);
9793 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9794 PHIsToCompute.emplace_back(PHI, I.second);
9795 }
9796 // We use two distinct loops because EvaluateExpression may invalidate any
9797 // iterators into CurrentIterVals.
9798 for (const auto &I : PHIsToCompute) {
9799 PHINode *PHI = I.first;
9800 Constant *&NextPHI = NextIterVals[PHI];
9801 if (!NextPHI) { // Not already computed.
9802 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9803 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9804 }
9805 if (NextPHI != I.second)
9806 StoppedEvolving = false;
9807 }
9808
9809 // If all entries in CurrentIterVals == NextIterVals then we can stop
9810 // iterating, the loop can't continue to change.
9811 if (StoppedEvolving)
9812 return RetVal = CurrentIterVals[PN];
9813
9814 CurrentIterVals.swap(NextIterVals);
9815 }
9816}
9817
9818const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9819 Value *Cond,
9820 bool ExitWhen) {
9821 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9822 if (!PN) return getCouldNotCompute();
9823
9824 // If the loop is canonicalized, the PHI will have exactly two entries.
9825 // That's the only form we support here.
9826 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9827
9828 DenseMap<Instruction *, Constant *> CurrentIterVals;
9829 BasicBlock *Header = L->getHeader();
9830 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9831
9832 BasicBlock *Latch = L->getLoopLatch();
9833 assert(Latch && "Should follow from NumIncomingValues == 2!");
9834
9835 for (PHINode &PHI : Header->phis()) {
9836 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9837 CurrentIterVals[&PHI] = StartCST;
9838 }
9839 if (!CurrentIterVals.count(PN))
9840 return getCouldNotCompute();
9841
9842   // Okay, we found a PHI node that defines the trip count of this loop. Execute
9843 // the loop symbolically to determine when the condition gets a value of
9844 // "ExitWhen".
9845 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9846 const DataLayout &DL = getDataLayout();
9847 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9848 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9849 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9850
9851 // Couldn't symbolically evaluate.
9852 if (!CondVal) return getCouldNotCompute();
9853
9854 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9855 ++NumBruteForceTripCountsComputed;
9856 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9857 }
9858
9859 // Update all the PHI nodes for the next iteration.
9860 DenseMap<Instruction *, Constant *> NextIterVals;
9861
9862 // Create a list of which PHIs we need to compute. We want to do this before
9863 // calling EvaluateExpression on them because that may invalidate iterators
9864 // into CurrentIterVals.
9865 SmallVector<PHINode *, 8> PHIsToCompute;
9866 for (const auto &I : CurrentIterVals) {
9867 PHINode *PHI = dyn_cast<PHINode>(I.first);
9868 if (!PHI || PHI->getParent() != Header) continue;
9869 PHIsToCompute.push_back(PHI);
9870 }
9871 for (PHINode *PHI : PHIsToCompute) {
9872 Constant *&NextPHI = NextIterVals[PHI];
9873 if (NextPHI) continue; // Already computed!
9874
9875 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9876 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9877 }
9878 CurrentIterVals.swap(NextIterVals);
9879 }
9880
9881 // Too many iterations were needed to evaluate.
9882 return getCouldNotCompute();
9883}
9884
9885const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9886   SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9887       ValuesAtScopes[V];
9888 // Check to see if we've folded this expression at this loop before.
9889 for (auto &LS : Values)
9890 if (LS.first == L)
9891 return LS.second ? LS.second : V;
9892
9893 Values.emplace_back(L, nullptr);
9894
9895 // Otherwise compute it.
9896 const SCEV *C = computeSCEVAtScope(V, L);
9897 for (auto &LS : reverse(ValuesAtScopes[V]))
9898 if (LS.first == L) {
9899 LS.second = C;
9900 if (!isa<SCEVConstant>(C))
9901 ValuesAtScopesUsers[C].push_back({L, V});
9902 break;
9903 }
9904 return C;
9905}
9906
9907/// This builds up a Constant using the ConstantExpr interface. That way, we
9908/// will return Constants for objects which aren't represented by a
9909/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9910/// Returns NULL if the SCEV isn't representable as a Constant.
9911 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9912   switch (V->getSCEVType()) {
9913 case scCouldNotCompute:
9914 case scAddRecExpr:
9915 case scVScale:
9916 return nullptr;
9917 case scConstant:
9918 return cast<SCEVConstant>(V)->getValue();
9919 case scUnknown:
9920 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9921 case scPtrToInt: {
9922     const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9923     if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9924 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9925
9926 return nullptr;
9927 }
9928 case scTruncate: {
9929     const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9930     if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9931 return ConstantExpr::getTrunc(CastOp, ST->getType());
9932 return nullptr;
9933 }
9934 case scAddExpr: {
9935 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9936 Constant *C = nullptr;
9937 for (const SCEV *Op : SA->operands()) {
9938       Constant *OpC = BuildConstantFromSCEV(Op);
9939       if (!OpC)
9940 return nullptr;
9941 if (!C) {
9942 C = OpC;
9943 continue;
9944 }
9945 assert(!C->getType()->isPointerTy() &&
9946 "Can only have one pointer, and it must be last");
9947 if (OpC->getType()->isPointerTy()) {
9948 // The offsets have been converted to bytes. We can add bytes using
9949 // an i8 GEP.
9950         C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9951                                            OpC, C);
9952 } else {
9953 C = ConstantExpr::getAdd(C, OpC);
9954 }
9955 }
9956 return C;
9957 }
9958 case scMulExpr:
9959 case scSignExtend:
9960 case scZeroExtend:
9961 case scUDivExpr:
9962 case scSMaxExpr:
9963 case scUMaxExpr:
9964 case scSMinExpr:
9965 case scUMinExpr:
9966   case scSequentialUMinExpr:
9967     return nullptr;
9968 }
9969 llvm_unreachable("Unknown SCEV kind!");
9970}
9971
9972const SCEV *
9973ScalarEvolution::getWithOperands(const SCEV *S,
9974 SmallVectorImpl<const SCEV *> &NewOps) {
9975 switch (S->getSCEVType()) {
9976 case scTruncate:
9977 case scZeroExtend:
9978 case scSignExtend:
9979 case scPtrToInt:
9980 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9981 case scAddRecExpr: {
9982 auto *AddRec = cast<SCEVAddRecExpr>(S);
9983 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9984 }
9985 case scAddExpr:
9986 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9987 case scMulExpr:
9988 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9989 case scUDivExpr:
9990 return getUDivExpr(NewOps[0], NewOps[1]);
9991 case scUMaxExpr:
9992 case scSMaxExpr:
9993 case scUMinExpr:
9994 case scSMinExpr:
9995 return getMinMaxExpr(S->getSCEVType(), NewOps);
9996   case scSequentialUMinExpr:
9997     return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9998 case scConstant:
9999 case scVScale:
10000 case scUnknown:
10001 return S;
10002 case scCouldNotCompute:
10003 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10004 }
10005 llvm_unreachable("Unknown SCEV kind!");
10006}
10007
10008const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10009 switch (V->getSCEVType()) {
10010 case scConstant:
10011 case scVScale:
10012 return V;
10013 case scAddRecExpr: {
10014 // If this is a loop recurrence for a loop that does not contain L, then we
10015 // are dealing with the final value computed by the loop.
10016 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10017 // First, attempt to evaluate each operand.
10018 // Avoid performing the look-up in the common case where the specified
10019 // expression has no loop-variant portions.
10020 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10021 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10022 if (OpAtScope == AddRec->getOperand(i))
10023 continue;
10024
10025 // Okay, at least one of these operands is loop variant but might be
10026 // foldable. Build a new instance of the folded commutative expression.
10027       SmallVector<const SCEV *, 8> NewOps;
10028       NewOps.reserve(AddRec->getNumOperands());
10029 append_range(NewOps, AddRec->operands().take_front(i));
10030 NewOps.push_back(OpAtScope);
10031 for (++i; i != e; ++i)
10032 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10033
10034 const SCEV *FoldedRec = getAddRecExpr(
10035 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10036 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10037 // The addrec may be folded to a nonrecurrence, for example, if the
10038 // induction variable is multiplied by zero after constant folding. Go
10039 // ahead and return the folded value.
10040 if (!AddRec)
10041 return FoldedRec;
10042 break;
10043 }
10044
10045 // If the scope is outside the addrec's loop, evaluate it by using the
10046 // loop exit value of the addrec.
10047 if (!AddRec->getLoop()->contains(L)) {
10048 // To evaluate this recurrence, we need to know how many times the AddRec
10049 // loop iterates. Compute this now.
10050 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10051 if (BackedgeTakenCount == getCouldNotCompute())
10052 return AddRec;
10053
10054 // Then, evaluate the AddRec.
10055 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10056 }
10057
10058 return AddRec;
10059 }
10060 case scTruncate:
10061 case scZeroExtend:
10062 case scSignExtend:
10063 case scPtrToInt:
10064 case scAddExpr:
10065 case scMulExpr:
10066 case scUDivExpr:
10067 case scUMaxExpr:
10068 case scSMaxExpr:
10069 case scUMinExpr:
10070 case scSMinExpr:
10071 case scSequentialUMinExpr: {
10072 ArrayRef<const SCEV *> Ops = V->operands();
10073 // Avoid performing the look-up in the common case where the specified
10074 // expression has no loop-variant portions.
10075 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10076 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10077 if (OpAtScope != Ops[i]) {
10078 // Okay, at least one of these operands is loop variant but might be
10079 // foldable. Build a new instance of the folded commutative expression.
10080         SmallVector<const SCEV *, 8> NewOps;
10081         NewOps.reserve(Ops.size());
10082 append_range(NewOps, Ops.take_front(i));
10083 NewOps.push_back(OpAtScope);
10084
10085 for (++i; i != e; ++i) {
10086 OpAtScope = getSCEVAtScope(Ops[i], L);
10087 NewOps.push_back(OpAtScope);
10088 }
10089
10090 return getWithOperands(V, NewOps);
10091 }
10092 }
10093 // If we got here, all operands are loop invariant.
10094 return V;
10095 }
10096 case scUnknown: {
10097 // If this instruction is evolved from a constant-evolving PHI, compute the
10098 // exit value from the loop without using SCEVs.
10099 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10100     Instruction *I = dyn_cast<Instruction>(SU->getValue());
10101     if (!I)
10102 return V; // This is some other type of SCEVUnknown, just return it.
10103
10104 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10105 const Loop *CurrLoop = this->LI[I->getParent()];
10106 // Looking for loop exit value.
10107 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10108 PN->getParent() == CurrLoop->getHeader()) {
10109 // Okay, there is no closed form solution for the PHI node. Check
10110 // to see if the loop that contains it has a known backedge-taken
10111 // count. If so, we may be able to force computation of the exit
10112 // value.
10113 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10114 // This trivial case can show up in some degenerate cases where
10115 // the incoming IR has not yet been fully simplified.
10116 if (BackedgeTakenCount->isZero()) {
10117 Value *InitValue = nullptr;
10118 bool MultipleInitValues = false;
10119 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10120 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10121 if (!InitValue)
10122 InitValue = PN->getIncomingValue(i);
10123 else if (InitValue != PN->getIncomingValue(i)) {
10124 MultipleInitValues = true;
10125 break;
10126 }
10127 }
10128 }
10129 if (!MultipleInitValues && InitValue)
10130 return getSCEV(InitValue);
10131 }
10132 // Do we have a loop invariant value flowing around the backedge
10133 // for a loop which must execute the backedge?
10134 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10135 isKnownNonZero(BackedgeTakenCount) &&
10136 PN->getNumIncomingValues() == 2) {
10137
10138 unsigned InLoopPred =
10139 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10140 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10141 if (CurrLoop->isLoopInvariant(BackedgeVal))
10142 return getSCEV(BackedgeVal);
10143 }
10144 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10145 // Okay, we know how many times the containing loop executes. If
10146 // this is a constant evolving PHI node, get the final value at
10147 // the specified iteration number.
10148 Constant *RV =
10149 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10150 if (RV)
10151 return getSCEV(RV);
10152 }
10153 }
10154 }
10155
10156 // Okay, this is an expression that we cannot symbolically evaluate
10157 // into a SCEV. Check to see if it's possible to symbolically evaluate
10158 // the arguments into constants, and if so, try to constant propagate the
10159 // result. This is particularly useful for computing loop exit values.
10160 if (!CanConstantFold(I))
10161 return V; // This is some other type of SCEVUnknown, just return it.
10162
10163     SmallVector<Constant *, 4> Operands;
10164     Operands.reserve(I->getNumOperands());
10165 bool MadeImprovement = false;
10166 for (Value *Op : I->operands()) {
10167 if (Constant *C = dyn_cast<Constant>(Op)) {
10168 Operands.push_back(C);
10169 continue;
10170 }
10171
10172 // If any of the operands is non-constant and if they are
10173 // non-integer and non-pointer, don't even try to analyze them
10174 // with scev techniques.
10175 if (!isSCEVable(Op->getType()))
10176 return V;
10177
10178 const SCEV *OrigV = getSCEV(Op);
10179 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10180 MadeImprovement |= OrigV != OpV;
10181
10182       Constant *C = BuildConstantFromSCEV(OpV);
10183       if (!C)
10184 return V;
10185 assert(C->getType() == Op->getType() && "Type mismatch");
10186 Operands.push_back(C);
10187 }
10188
10189 // Check to see if getSCEVAtScope actually made an improvement.
10190 if (!MadeImprovement)
10191 return V; // This is some other type of SCEVUnknown, just return it.
10192
10193 Constant *C = nullptr;
10194 const DataLayout &DL = getDataLayout();
10195 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10196 /*AllowNonDeterministic=*/false);
10197 if (!C)
10198 return V;
10199 return getSCEV(C);
10200 }
10201 case scCouldNotCompute:
10202 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10203 }
10204 llvm_unreachable("Unknown SCEV type!");
10205}
10206
10207 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10208   return getSCEVAtScope(getSCEV(V), L);
10209}
10210
10211const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10212   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10213     return stripInjectiveFunctions(ZExt->getOperand());
10214   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10215     return stripInjectiveFunctions(SExt->getOperand());
10216 return S;
10217}
10218
10219/// Finds the minimum unsigned root of the following equation:
10220///
10221/// A * X = B (mod N)
10222///
10223/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10224/// A and B isn't important.
10225///
10226/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
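/// As a small worked example (BW = 4, so N = 16): solving 6 * X = 4 (mod 16)
/// gives D = gcd(6, 16) = 2, A/D = 3 with multiplicative inverse 3 (mod 8),
/// and the minimum root X = (3 * 4 mod 16) / 2 = 6; indeed 6 * 6 = 36 = 4
/// (mod 16).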
10227static const SCEV *
10228 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10229                              SmallVectorImpl<const SCEVPredicate *> *Predicates,
10230 
10231 ScalarEvolution &SE) {
10232 uint32_t BW = A.getBitWidth();
10233 assert(BW == SE.getTypeSizeInBits(B->getType()));
10234 assert(A != 0 && "A must be non-zero.");
10235
10236 // 1. D = gcd(A, N)
10237 //
10238 // The gcd of A and N may have only one prime factor: 2. The number of
10239 // trailing zeros in A is its multiplicity
10240 uint32_t Mult2 = A.countr_zero();
10241 // D = 2^Mult2
10242
10243 // 2. Check if B is divisible by D.
10244 //
10245 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10246 // is not less than multiplicity of this prime factor for D.
10247 if (SE.getMinTrailingZeros(B) < Mult2) {
10248 // Check if we can prove there's no remainder using URem.
10249 const SCEV *URem =
10250 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10251 const SCEV *Zero = SE.getZero(B->getType());
10252 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10253 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10254 if (!Predicates)
10255 return SE.getCouldNotCompute();
10256
10257 // Avoid adding a predicate that is known to be false.
10258 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10259 return SE.getCouldNotCompute();
10260 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10261 }
10262 }
10263
10264 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10265 // modulo (N / D).
10266 //
10267 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10268 // (N / D) in general. The inverse itself always fits into BW bits, though,
10269 // so we immediately truncate it.
10270 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10271 APInt I = AD.multiplicativeInverse().zext(BW);
10272
10273 // 4. Compute the minimum unsigned root of the equation:
10274 // I * (B / D) mod (N / D)
10275 // To simplify the computation, we factor out the divide by D:
10276 // (I * B mod N) / D
10277 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10278 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10279}
10280
10281/// For a given quadratic addrec, generate coefficients of the corresponding
10282/// quadratic equation, multiplied by a common value to ensure that they are
10283/// integers.
10284/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10285/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10286/// were multiplied by, and BitWidth is the bit width of the original addrec
10287/// coefficients.
10288/// This function returns std::nullopt if the addrec coefficients are not
10289/// compile- time constants.
10290static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10291 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10292   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10293 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10294 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10295 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10296 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10297 << *AddRec << '\n');
10298
10299 // We currently can only solve this if the coefficients are constants.
10300 if (!LC || !MC || !NC) {
10301 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10302 return std::nullopt;
10303 }
10304
10305 APInt L = LC->getAPInt();
10306 APInt M = MC->getAPInt();
10307 APInt N = NC->getAPInt();
10308 assert(!N.isZero() && "This is not a quadratic addrec");
10309
10310 unsigned BitWidth = LC->getAPInt().getBitWidth();
10311 unsigned NewWidth = BitWidth + 1;
10312 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10313 << BitWidth << '\n');
10314 // The sign-extension (as opposed to a zero-extension) here matches the
10315 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10316 N = N.sext(NewWidth);
10317 M = M.sext(NewWidth);
10318 L = L.sext(NewWidth);
10319
10320 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10321 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10322 // L+M, L+2M+N, L+3M+3N, ...
10323 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10324 //
10325 // The equation Acc = 0 is then
10326 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10327 // In a quadratic form it becomes:
10328 // N n^2 + (2M-N) n + 2L = 0.
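  // For example, for the addrec {-4,+,1,+,2} (L = -4, M = 1, N = 2) the
  // accumulated values are -4, -3, 0, ..., and the resulting equation
  // 2n^2 + 0n - 8 = 0 yields the expected root n = 2.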
10329
10330 APInt A = N;
10331 APInt B = 2 * M - A;
10332 APInt C = 2 * L;
10333 APInt T = APInt(NewWidth, 2);
10334 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10335 << "x + " << C << ", coeff bw: " << NewWidth
10336 << ", multiplied by " << T << '\n');
10337 return std::make_tuple(A, B, C, T, BitWidth);
10338}
10339
10340/// Helper function to compare optional APInts:
10341/// (a) if X and Y both exist, return min(X, Y),
10342/// (b) if neither X nor Y exist, return std::nullopt,
10343/// (c) if exactly one of X and Y exists, return that value.
10344static std::optional<APInt> MinOptional(std::optional<APInt> X,
10345 std::optional<APInt> Y) {
10346 if (X && Y) {
10347 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10348 APInt XW = X->sext(W);
10349 APInt YW = Y->sext(W);
10350 return XW.slt(YW) ? *X : *Y;
10351 }
10352 if (!X && !Y)
10353 return std::nullopt;
10354 return X ? *X : *Y;
10355}
10356
10357/// Helper function to truncate an optional APInt to a given BitWidth.
10358/// When solving addrec-related equations, it is preferable to return a value
10359/// that has the same bit width as the original addrec's coefficients. If the
10360/// solution fits in the original bit width, truncate it (except for i1).
10361/// Returning a value of a different bit width may inhibit some optimizations.
10362///
10363/// In general, a solution to a quadratic equation generated from an addrec
10364/// may require BW+1 bits, where BW is the bit width of the addrec's
10365/// coefficients. The reason is that the coefficients of the quadratic
10366/// equation are BW+1 bits wide (to avoid truncation when converting from
10367/// the addrec to the equation).
10368static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10369 unsigned BitWidth) {
10370 if (!X)
10371 return std::nullopt;
10372 unsigned W = X->getBitWidth();
10373   if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10374     return X->trunc(BitWidth);
10375 return X;
10376}
10377
10378/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10379/// iterations. The values L, M, N are assumed to be signed, and they
10380/// should all have the same bit widths.
10381/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10382/// where BW is the bit width of the addrec's coefficients.
10383/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10384/// returned as such, otherwise the bit width of the returned value may
10385/// be greater than BW.
10386///
10387/// This function returns std::nullopt if
10388/// (a) the addrec coefficients are not constant, or
10389/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10390/// like x^2 = 5, no integer solutions exist, in other cases an integer
10391/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10392static std::optional<APInt>
10393 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10394   APInt A, B, C, M;
10395 unsigned BitWidth;
10396 auto T = GetQuadraticEquation(AddRec);
10397 if (!T)
10398 return std::nullopt;
10399
10400 std::tie(A, B, C, M, BitWidth) = *T;
10401 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10402 std::optional<APInt> X =
10403       APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10404   if (!X)
10405 return std::nullopt;
10406
10407 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10408 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10409 if (!V->isZero())
10410 return std::nullopt;
10411
10412 return TruncIfPossible(X, BitWidth);
10413}
10414
10415/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10416/// iterations. The values M, N are assumed to be signed, and they
10417/// should all have the same bit widths.
10418/// Find the least n such that c(n) does not belong to the given range,
10419/// while c(n-1) does.
10420///
10421/// This function returns std::nullopt if
10422/// (a) the addrec coefficients are not constant, or
10423/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10424/// bounds of the range.
10425static std::optional<APInt>
10426 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10427                           const ConstantRange &Range, ScalarEvolution &SE) {
10428 assert(AddRec->getOperand(0)->isZero() &&
10429 "Starting value of addrec should be 0");
10430 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10431 << Range << ", addrec " << *AddRec << '\n');
10432 // This case is handled in getNumIterationsInRange. Here we can assume that
10433 // we start in the range.
10434 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10435 "Addrec's initial value should be in range");
10436
10437 APInt A, B, C, M;
10438 unsigned BitWidth;
10439 auto T = GetQuadraticEquation(AddRec);
10440 if (!T)
10441 return std::nullopt;
10442
10443 // Be careful about the return value: there can be two reasons for not
10444 // returning an actual number. First, if no solutions to the equations
10445 // were found, and second, if the solutions don't leave the given range.
10446 // The first case means that the actual solution is "unknown", the second
10447 // means that it's known, but not valid. If the solution is unknown, we
10448 // cannot make any conclusions.
10449 // Return a pair: the optional solution and a flag indicating if the
10450 // solution was found.
10451 auto SolveForBoundary =
10452 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10453 // Solve for signed overflow and unsigned overflow, pick the lower
10454 // solution.
10455 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10456 << Bound << " (before multiplying by " << M << ")\n");
10457 Bound *= M; // The quadratic equation multiplier.
10458
10459 std::optional<APInt> SO;
10460 if (BitWidth > 1) {
10461 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10462 "signed overflow\n");
10463       SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10464     }
10465 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10466 "unsigned overflow\n");
10467 std::optional<APInt> UO =
10468         APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10469 
10470 auto LeavesRange = [&] (const APInt &X) {
10471 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10472 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10473 if (Range.contains(V0->getValue()))
10474 return false;
10475 // X should be at least 1, so X-1 is non-negative.
10476 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10477 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10478 if (Range.contains(V1->getValue()))
10479 return true;
10480 return false;
10481 };
10482
10483 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10484 // can be a solution, but the function failed to find it. We cannot treat it
10485 // as "no solution".
10486 if (!SO || !UO)
10487 return {std::nullopt, false};
10488
10489 // Check the smaller value first to see if it leaves the range.
10490 // At this point, both SO and UO must have values.
10491 std::optional<APInt> Min = MinOptional(SO, UO);
10492 if (LeavesRange(*Min))
10493 return { Min, true };
10494 std::optional<APInt> Max = Min == SO ? UO : SO;
10495 if (LeavesRange(*Max))
10496 return { Max, true };
10497
10498 // Solutions were found, but were eliminated, hence the "true".
10499 return {std::nullopt, true};
10500 };
10501
10502 std::tie(A, B, C, M, BitWidth) = *T;
10503 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10504 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10505 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10506 auto SL = SolveForBoundary(Lower);
10507 auto SU = SolveForBoundary(Upper);
10508   // If any of the solutions was unknown, no meaningful conclusions can
10509 // be made.
10510 if (!SL.second || !SU.second)
10511 return std::nullopt;
10512
10513 // Claim: The correct solution is not some value between Min and Max.
10514 //
10515 // Justification: Assuming that Min and Max are different values, one of
10516 // them is when the first signed overflow happens, the other is when the
10517 // first unsigned overflow happens. Crossing the range boundary is only
10518 // possible via an overflow (treating 0 as a special case of it, modeling
10519 // an overflow as crossing k*2^W for some k).
10520 //
10521 // The interesting case here is when Min was eliminated as an invalid
10522 // solution, but Max was not. The argument is that if there was another
10523 // overflow between Min and Max, it would also have been eliminated if
10524 // it was considered.
10525 //
10526 // For a given boundary, it is possible to have two overflows of the same
10527 // type (signed/unsigned) without having the other type in between: this
10528 // can happen when the vertex of the parabola is between the iterations
10529 // corresponding to the overflows. This is only possible when the two
10530 // overflows cross k*2^W for the same k. In such case, if the second one
10531 // left the range (and was the first one to do so), the first overflow
10532 // would have to enter the range, which would mean that either we had left
10533 // the range before or that we started outside of it. Both of these cases
10534 // are contradictions.
10535 //
10536 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10537 // solution is not some value between the Max for this boundary and the
10538 // Min of the other boundary.
10539 //
10540 // Justification: Assume that we had such Max_A and Min_B corresponding
10541 // to range boundaries A and B and such that Max_A < Min_B. If there was
10542 // a solution between Max_A and Min_B, it would have to be caused by an
10543 // overflow corresponding to either A or B. It cannot correspond to B,
10544 // since Min_B is the first occurrence of such an overflow. If it
10545 // corresponded to A, it would have to be either a signed or an unsigned
10546 // overflow that is larger than both eliminated overflows for A. But
10547 // between the eliminated overflows and this overflow, the values would
10548 // cover the entire value space, thus crossing the other boundary, which
10549 // is a contradiction.
10550
10551 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10552}
10553
10554ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10555 const Loop *L,
10556 bool ControlsOnlyExit,
10557 bool AllowPredicates) {
10558
10559 // This is only used for loops with a "x != y" exit test. The exit condition
10560 // is now expressed as a single expression, V = x-y. So the exit test is
10561   // effectively V != 0. We know and take advantage of the fact that this
10562   // expression is only used in a comparison-by-zero context.
10563 
10564   SmallVector<const SCEVPredicate *> Predicates;
10565   // If the value is a constant
10566 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10567 // If the value is already zero, the branch will execute zero times.
10568 if (C->getValue()->isZero()) return C;
10569 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10570 }
10571
10572 const SCEVAddRecExpr *AddRec =
10573 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10574
10575 if (!AddRec && AllowPredicates)
10576 // Try to make this an AddRec using runtime tests, in the first X
10577 // iterations of this loop, where X is the SCEV expression found by the
10578 // algorithm below.
10579 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10580
10581 if (!AddRec || AddRec->getLoop() != L)
10582 return getCouldNotCompute();
10583
10584 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10585 // the quadratic equation to solve it.
10586 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10587 // We can only use this value if the chrec ends up with an exact zero
10588 // value at this index. When solving for "X*X != 5", for example, we
10589 // should not accept a root of 2.
10590 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10591 const auto *R = cast<SCEVConstant>(getConstant(*S));
10592 return ExitLimit(R, R, R, false, Predicates);
10593 }
10594 return getCouldNotCompute();
10595 }
10596
10597 // Otherwise we can only handle this if it is affine.
10598 if (!AddRec->isAffine())
10599 return getCouldNotCompute();
10600
10601 // If this is an affine expression, the execution count of this branch is
10602 // the minimum unsigned root of the following equation:
10603 //
10604 // Start + Step*N = 0 (mod 2^BW)
10605 //
10606 // equivalent to:
10607 //
10608 // Step*N = -Start (mod 2^BW)
10609 //
10610 // where BW is the common bit width of Start and Step.
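  // For example, an i8 recurrence {6,+,5} reaches zero after N = 50 backedges,
  // since 5 * 50 = 250 = -6 (mod 256) and 5 is invertible modulo 256.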
10611
10612 // Get the initial value for the loop.
10613 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10614 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10615
10616 if (!isLoopInvariant(Step, L))
10617 return getCouldNotCompute();
10618
10619 LoopGuards Guards = LoopGuards::collect(L, *this);
10620 // Specialize step for this loop so we get context sensitive facts below.
10621 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10622
10623 // For positive steps (counting up until unsigned overflow):
10624 // N = -Start/Step (as unsigned)
10625 // For negative steps (counting down to zero):
10626 // N = Start/-Step
10627 // First compute the unsigned distance from zero in the direction of Step.
10628 bool CountDown = isKnownNegative(StepWLG);
10629 if (!CountDown && !isKnownNonNegative(StepWLG))
10630 return getCouldNotCompute();
10631
10632 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10633 // Handle unitary steps, which cannot wraparound.
10634 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10635 // N = Distance (as unsigned)
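  // For example, {7,+,-1} takes the backedge 7 times before reaching zero, and
  // {-7,+,1} also takes it 7 times, since Distance is 7 in both cases.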
10636
10637 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10638 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10639 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10640
10641 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10642 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10643 // case, and see if we can improve the bound.
10644 //
10645 // Explicitly handling this here is necessary because getUnsignedRange
10646 // isn't context-sensitive; it doesn't know that we only care about the
10647 // range inside the loop.
10648 const SCEV *Zero = getZero(Distance->getType());
10649 const SCEV *One = getOne(Distance->getType());
10650 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10651 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10652 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10653 // as "unsigned_max(Distance + 1) - 1".
10654 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10655 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10656 }
10657 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10658 Predicates);
10659 }
10660
10661 // If the condition controls loop exit (the loop exits only if the expression
10662 // is true) and the addition is no-wrap we can use unsigned divide to
10663 // compute the backedge count. In this case, the step may not divide the
10664 // distance, but we don't care because if the condition is "missed" the loop
10665 // will have undefined behavior due to wrapping.
10666 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10667 loopHasNoAbnormalExits(AddRec->getLoop())) {
10668
10669 // If the stride is zero and the start is non-zero, the loop must be
10670 // infinite. In C++, most loops are finite by assumption, in which case the
10671 // step being zero implies UB must execute if the loop is entered.
10672 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10673 !isKnownNonZero(StepWLG))
10674 return getCouldNotCompute();
10675
10676 const SCEV *Exact =
10677 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10678 const SCEV *ConstantMax = getCouldNotCompute();
10679 if (Exact != getCouldNotCompute()) {
10680 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10681       ConstantMax =
10682           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10683     }
10684 const SCEV *SymbolicMax =
10685 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10686 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10687 }
10688
10689 // Solve the general equation.
10690 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10691 if (!StepC || StepC->getValue()->isZero())
10692 return getCouldNotCompute();
10693 const SCEV *E = SolveLinEquationWithOverflow(
10694 StepC->getAPInt(), getNegativeSCEV(Start),
10695 AllowPredicates ? &Predicates : nullptr, *this);
10696
10697 const SCEV *M = E;
10698 if (E != getCouldNotCompute()) {
10699 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10700 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10701 }
10702 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10703 return ExitLimit(E, M, S, false, Predicates);
10704}
10705
10706ScalarEvolution::ExitLimit
10707ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10708 // Loops that look like: while (X == 0) are very strange indeed. We don't
10709 // handle them yet except for the trivial case. This could be expanded in the
10710 // future as needed.
10711
10712 // If the value is a constant, check to see if it is known to be non-zero
10713 // already. If so, the backedge will execute zero times.
10714 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10715 if (!C->getValue()->isZero())
10716 return getZero(C->getType());
10717 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10718 }
10719
10720 // We could implement others, but I really doubt anyone writes loops like
10721 // this, and if they did, they would already be constant folded.
10722 return getCouldNotCompute();
10723}
10724
10725std::pair<const BasicBlock *, const BasicBlock *>
10726ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10727 const {
10728 // If the block has a unique predecessor, then there is no path from the
10729 // predecessor to the block that does not go through the direct edge
10730 // from the predecessor to the block.
10731 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10732 return {Pred, BB};
10733
10734 // A loop's header is defined to be a block that dominates the loop.
10735 // If the header has a unique predecessor outside the loop, it must be
10736 // a block that has exactly one successor that can reach the loop.
10737 if (const Loop *L = LI.getLoopFor(BB))
10738 return {L->getLoopPredecessor(), L->getHeader()};
10739
10740 return {nullptr, BB};
10741}
10742
10743/// SCEV structural equivalence is usually sufficient for testing whether two
10744/// expressions are equal, however for the purposes of looking for a condition
10745/// guarding a loop, it can be useful to be a little more general, since a
10746/// front-end may have replicated the controlling expression.
10747static bool HasSameValue(const SCEV *A, const SCEV *B) {
10748 // Quick check to see if they are the same SCEV.
10749 if (A == B) return true;
10750
10751 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10752 // Not all instructions that are "identical" compute the same value. For
10753 // instance, two distinct alloca instructions allocating the same type are
10754 // identical and do not read memory; but compute distinct values.
10755 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10756 };
10757
10758 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10759 // two different instructions with the same value. Check for this case.
10760 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10761 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10762 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10763 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10764 if (ComputesEqualValues(AI, BI))
10765 return true;
10766
10767 // Otherwise assume they may have a different value.
10768 return false;
10769}
10770
10771static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10772   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10773   if (!Add || Add->getNumOperands() != 2)
10774 return false;
10775 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10776 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10777 LHS = Add->getOperand(1);
10778 RHS = ME->getOperand(1);
10779 return true;
10780 }
10781 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10782 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10783 LHS = Add->getOperand(0);
10784 RHS = ME->getOperand(1);
10785 return true;
10786 }
10787 return false;
10788}
10789
10790 bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10791                                            const SCEV *&RHS, unsigned Depth) {
10792 bool Changed = false;
10793 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10794 // '0 != 0'.
10795 auto TrivialCase = [&](bool TriviallyTrue) {
10796     LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10797     Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10798 return true;
10799 };
10800 // If we hit the max recursion limit bail out.
10801 if (Depth >= 3)
10802 return false;
10803
10804 // Canonicalize a constant to the right side.
10805 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10806 // Check for both operands constant.
10807 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10808 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10809 return TrivialCase(false);
10810 return TrivialCase(true);
10811 }
10812 // Otherwise swap the operands to put the constant on the right.
10813 std::swap(LHS, RHS);
10814     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10815     Changed = true;
10816 }
10817
10818 // If we're comparing an addrec with a value which is loop-invariant in the
10819 // addrec's loop, put the addrec on the left. Also make a dominance check,
10820 // as both operands could be addrecs loop-invariant in each other's loop.
10821 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10822 const Loop *L = AR->getLoop();
10823 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10824 std::swap(LHS, RHS);
10825       Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10826       Changed = true;
10827 }
10828 }
10829
10830 // If there's a constant operand, canonicalize comparisons with boundary
10831 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
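  // For example, "X uge 5" is canonicalized to "X ugt 4", and "X sle 7" to
  // "X slt 8".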
10832 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10833 const APInt &RA = RC->getAPInt();
10834
10835 bool SimplifiedByConstantRange = false;
10836
10837 if (!ICmpInst::isEquality(Pred)) {
10838       ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10839       if (ExactCR.isFullSet())
10840 return TrivialCase(true);
10841 if (ExactCR.isEmptySet())
10842 return TrivialCase(false);
10843
10844 APInt NewRHS;
10845 CmpInst::Predicate NewPred;
10846 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10847 ICmpInst::isEquality(NewPred)) {
10848 // We were able to convert an inequality to an equality.
10849 Pred = NewPred;
10850 RHS = getConstant(NewRHS);
10851 Changed = SimplifiedByConstantRange = true;
10852 }
10853 }
10854
10855 if (!SimplifiedByConstantRange) {
10856 switch (Pred) {
10857 default:
10858 break;
10859 case ICmpInst::ICMP_EQ:
10860 case ICmpInst::ICMP_NE:
10861 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10862 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10863 Changed = true;
10864 break;
10865
10866 // The "Should have been caught earlier!" messages refer to the fact
10867 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10868 // should have fired on the corresponding cases, and canonicalized the
10869 // check to trivial case.
10870
10871 case ICmpInst::ICMP_UGE:
10872 assert(!RA.isMinValue() && "Should have been caught earlier!");
10873 Pred = ICmpInst::ICMP_UGT;
10874 RHS = getConstant(RA - 1);
10875 Changed = true;
10876 break;
10877 case ICmpInst::ICMP_ULE:
10878 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10879 Pred = ICmpInst::ICMP_ULT;
10880 RHS = getConstant(RA + 1);
10881 Changed = true;
10882 break;
10883 case ICmpInst::ICMP_SGE:
10884 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10885 Pred = ICmpInst::ICMP_SGT;
10886 RHS = getConstant(RA - 1);
10887 Changed = true;
10888 break;
10889 case ICmpInst::ICMP_SLE:
10890 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10891 Pred = ICmpInst::ICMP_SLT;
10892 RHS = getConstant(RA + 1);
10893 Changed = true;
10894 break;
10895 }
10896 }
10897 }
10898
10899 // Check for obvious equality.
10900 if (HasSameValue(LHS, RHS)) {
10901 if (ICmpInst::isTrueWhenEqual(Pred))
10902 return TrivialCase(true);
10903     if (ICmpInst::isFalseWhenEqual(Pred))
10904       return TrivialCase(false);
10905 }
10906
10907 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10908 // adding or subtracting 1 from one of the operands.
10909 switch (Pred) {
10910 case ICmpInst::ICMP_SLE:
10911 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10912 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10913                        SCEV::FlagNSW);
10914       Pred = ICmpInst::ICMP_SLT;
10915 Changed = true;
10916 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10917 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10918                        SCEV::FlagNSW);
10919       Pred = ICmpInst::ICMP_SLT;
10920 Changed = true;
10921 }
10922 break;
10923 case ICmpInst::ICMP_SGE:
10924 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10925 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10926                        SCEV::FlagNSW);
10927       Pred = ICmpInst::ICMP_SGT;
10928 Changed = true;
10929 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10930 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10931                        SCEV::FlagNSW);
10932       Pred = ICmpInst::ICMP_SGT;
10933 Changed = true;
10934 }
10935 break;
10936 case ICmpInst::ICMP_ULE:
10937 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10938 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10939                        SCEV::FlagNUW);
10940       Pred = ICmpInst::ICMP_ULT;
10941 Changed = true;
10942 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10943 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10944 Pred = ICmpInst::ICMP_ULT;
10945 Changed = true;
10946 }
10947 break;
10948 case ICmpInst::ICMP_UGE:
10949 // If RHS is an op we can fold the -1, try that first.
10950 // Otherwise prefer LHS to preserve the nuw flag.
10951 if ((isa<SCEVConstant>(RHS) ||
10952          (isa<SCEVAddExpr>(RHS) &&
10953           isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10954 !getUnsignedRangeMin(RHS).isMinValue()) {
10955 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10956 Pred = ICmpInst::ICMP_UGT;
10957 Changed = true;
10958 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10959 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10960                        SCEV::FlagNUW);
10961       Pred = ICmpInst::ICMP_UGT;
10962 Changed = true;
10963 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
10964 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10965 Pred = ICmpInst::ICMP_UGT;
10966 Changed = true;
10967 }
10968 break;
10969 default:
10970 break;
10971 }
10972
10973 // TODO: More simplifications are possible here.
10974
10975 // Recursively simplify until we either hit a recursion limit or nothing
10976 // changes.
10977 if (Changed)
10978 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10979
10980 return Changed;
10981}
10982
10983 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10984   return getSignedRangeMax(S).isNegative();
10985}
10986
10987 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10988   return getSignedRangeMin(S).isStrictlyPositive();
10989 }
10990 
10991 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10992   return !getSignedRangeMin(S).isNegative();
10993}
10994
10995 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10996   return !getSignedRangeMax(S).isStrictlyPositive();
10997 }
10998 
10999 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11000   // Query push down for cases where the unsigned range is
11001 // less than sufficient.
11002 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11003 return isKnownNonZero(SExt->getOperand(0));
11004 return getUnsignedRangeMin(S) != 0;
11005}
11006
11007bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11008 bool OrNegative) {
11009 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11010 if (auto *C = dyn_cast<SCEVConstant>(S))
11011 return C->getAPInt().isPowerOf2() ||
11012 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11013
11014 // The vscale_range indicates vscale is a power-of-two.
11015 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11016 };
11017
11018 if (NonRecursive(S))
11019 return true;
11020
11021 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11022 if (!Mul)
11023 return false;
11024 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11025}
11026
11027bool ScalarEvolution::isKnownMultipleOf(
11028 const SCEV *S, uint64_t M,
11029     SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11030 if (M == 0)
11031 return false;
11032 if (M == 1)
11033 return true;
11034
11035 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11036 // starts with a multiple of M and at every iteration step S only adds
11037 // multiples of M.
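  // For example, {8,+,4} is a multiple of 4: its start (8) and its step (4) are
  // both multiples of 4, so every value taken by the recurrence is as well.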
11038 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11039 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11040 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11041
11042 // For a constant, check that "S % M == 0".
11043 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11044 APInt C = Cst->getAPInt();
11045 return C.urem(M) == 0;
11046 }
11047
11048 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11049
11050 // Basic tests have failed.
11051 // Check "S % M == 0" at compile time and record runtime Assumptions.
11052 auto *STy = dyn_cast<IntegerType>(S->getType());
11053 const SCEV *SmodM =
11054 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11055 const SCEV *Zero = getZero(STy);
11056
11057 // Check whether "S % M == 0" is known at compile time.
11058 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11059 return true;
11060
11061 // Check whether "S % M != 0" is known at compile time.
11062 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11063 return false;
11064
11066
11067 // Detect redundant predicates.
11068 for (auto *A : Assumptions)
11069 if (A->implies(P, *this))
11070 return true;
11071
11072 // Only record non-redundant predicates.
11073 Assumptions.push_back(P);
11074 return true;
11075}
11076
11077std::pair<const SCEV *, const SCEV *>
11078ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11079 // Compute SCEV on entry of loop L.
11080 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11081 if (Start == getCouldNotCompute())
11082 return { Start, Start };
11083 // Compute post increment SCEV for loop L.
11084 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11085 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11086 return { Start, PostInc };
11087}
11088
11089bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11090 const SCEV *RHS) {
11091 // First collect all loops.
11092   SmallPtrSet<const Loop *, 8> LoopsUsed;
11093 getUsedLoops(LHS, LoopsUsed);
11094 getUsedLoops(RHS, LoopsUsed);
11095
11096 if (LoopsUsed.empty())
11097 return false;
11098
11099 // Domination relationship must be a linear order on collected loops.
11100#ifndef NDEBUG
11101 for (const auto *L1 : LoopsUsed)
11102 for (const auto *L2 : LoopsUsed)
11103 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11104 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11105 "Domination relationship is not a linear order");
11106#endif
11107
11108 const Loop *MDL =
11109 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11110 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11111 });
11112
11113 // Get init and post increment value for LHS.
11114 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11115 // if LHS contains unknown non-invariant SCEV then bail out.
11116 if (SplitLHS.first == getCouldNotCompute())
11117 return false;
11118 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11119 // Get init and post increment value for RHS.
11120 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11121 // if RHS contains unknown non-invariant SCEV then bail out.
11122 if (SplitRHS.first == getCouldNotCompute())
11123 return false;
11124 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11125   // The init SCEV may contain an invariant load that does not dominate MDL
11126   // and is therefore not available at MDL's loop entry, so we have to check
11127   // that here.
11128 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11129 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11130 return false;
11131
11132   // The backedge guard check tends to be cheaper than the entry check, so do
11133   // it first; in some cases this short-circuits the whole query.
11134 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11135 SplitRHS.second) &&
11136 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11137}
11138
11139bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11140 const SCEV *RHS) {
11141 // Canonicalize the inputs first.
11142 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11143
11144 if (isKnownViaInduction(Pred, LHS, RHS))
11145 return true;
11146
11147 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11148 return true;
11149
11150 // Otherwise see what can be done with some simple reasoning.
11151 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11152}
11153
11154std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11155 const SCEV *LHS,
11156 const SCEV *RHS) {
11157 if (isKnownPredicate(Pred, LHS, RHS))
11158 return true;
11159   if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11160 return false;
11161 return std::nullopt;
11162}
11163
11164bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11165 const SCEV *RHS,
11166 const Instruction *CtxI) {
11167 // TODO: Analyze guards and assumes from Context's block.
11168 return isKnownPredicate(Pred, LHS, RHS) ||
11169 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11170}
11171
11172std::optional<bool>
11173ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11174 const SCEV *RHS, const Instruction *CtxI) {
11175 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11176 if (KnownWithoutContext)
11177 return KnownWithoutContext;
11178
11179 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11180 return true;
11181   if (isBasicBlockEntryGuardedByCond(
11182 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11183 return false;
11184 return std::nullopt;
11185}
11186
11187bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11188 const SCEVAddRecExpr *LHS,
11189 const SCEV *RHS) {
11190 const Loop *L = LHS->getLoop();
11191 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11192 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11193}
11194
11195std::optional<ScalarEvolution::MonotonicPredicateType>
11196ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11197 ICmpInst::Predicate Pred) {
11198 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11199
11200#ifndef NDEBUG
11201 // Verify an invariant: inverting the predicate should turn a monotonically
11202 // increasing change to a monotonically decreasing one, and vice versa.
11203 if (Result) {
11204 auto ResultSwapped =
11205 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11206
11207 assert(*ResultSwapped != *Result &&
11208 "monotonicity should flip as we flip the predicate");
11209 }
11210#endif
11211
11212 return Result;
11213}
11214
11215std::optional<ScalarEvolution::MonotonicPredicateType>
11216ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11217 ICmpInst::Predicate Pred) {
11218 // A zero step value for LHS means the induction variable is essentially a
11219 // loop invariant value. We don't really depend on the predicate actually
11220 // flipping from false to true (for increasing predicates, and the other way
11221 // around for decreasing predicates), all we care about is that *if* the
11222 // predicate changes then it only changes from false to true.
11223 //
11224 // A zero step value in itself is not very useful, but there may be places
11225 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11226 // as general as possible.
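  // For example, for the addrec {0,+,1}<nuw>, the predicate "{0,+,1} u> N" can
  // only switch from false to true as the loop advances, i.e. it is
  // monotonically increasing in the sense used here.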
11227
11228 // Only handle LE/LT/GE/GT predicates.
11229 if (!ICmpInst::isRelational(Pred))
11230 return std::nullopt;
11231
11232 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11233 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11234 "Should be greater or less!");
11235
11236 // Check that AR does not wrap.
11237 if (ICmpInst::isUnsigned(Pred)) {
11238 if (!LHS->hasNoUnsignedWrap())
11239 return std::nullopt;
11240     return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11241 }
11242 assert(ICmpInst::isSigned(Pred) &&
11243 "Relational predicate is either signed or unsigned!");
11244 if (!LHS->hasNoSignedWrap())
11245 return std::nullopt;
11246
11247 const SCEV *Step = LHS->getStepRecurrence(*this);
11248
11249 if (isKnownNonNegative(Step))
11250     return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11251
11252 if (isKnownNonPositive(Step))
11253     return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11254
11255 return std::nullopt;
11256}
11257
11258std::optional<ScalarEvolution::LoopInvariantPredicate>
11259ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11260 const SCEV *RHS, const Loop *L,
11261 const Instruction *CtxI) {
11262 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11263 if (!isLoopInvariant(RHS, L)) {
11264 if (!isLoopInvariant(LHS, L))
11265 return std::nullopt;
11266
11267 std::swap(LHS, RHS);
11268     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11269 }
11270
11271 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11272 if (!ArLHS || ArLHS->getLoop() != L)
11273 return std::nullopt;
11274
11275 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11276 if (!MonotonicType)
11277 return std::nullopt;
11278 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11279 // true as the loop iterates, and the backedge is control dependent on
11280 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11281 //
11282 // * if the predicate was false in the first iteration then the predicate
11283 // is never evaluated again, since the loop exits without taking the
11284 // backedge.
11285 // * if the predicate was true in the first iteration then it will
11286 // continue to be true for all future iterations since it is
11287 // monotonically increasing.
11288 //
11289 // For both the above possibilities, we can replace the loop varying
11290 // predicate with its value on the first iteration of the loop (which is
11291 // loop invariant).
11292 //
11293 // A similar reasoning applies for a monotonically decreasing predicate, by
11294 // replacing true with false and false with true in the above two bullets.
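  // For example, if the backedge is taken only while "{Start,+,1}<nuw> u>= N"
  // holds, the loop-varying test "{Start,+,1} u>= N" can be replaced inside the
  // loop by its first-iteration value "Start u>= N", which is loop invariant.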
11295   bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11296   auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11297
11298 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11299     return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11300 RHS);
11301
11302 if (!CtxI)
11303 return std::nullopt;
11304 // Try to prove via context.
11305 // TODO: Support other cases.
11306 switch (Pred) {
11307 default:
11308 break;
11309 case ICmpInst::ICMP_ULE:
11310 case ICmpInst::ICMP_ULT: {
11311 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11312 // Given preconditions
11313 // (1) ArLHS does not cross the border of positive and negative parts of
11314 // range because of:
11315 // - Positive step; (TODO: lift this limitation)
11316 // - nuw - does not cross zero boundary;
11317 // - nsw - does not cross SINT_MAX boundary;
11318 // (2) ArLHS <s RHS
11319 // (3) RHS >=s 0
11320 // we can replace the loop variant ArLHS <u RHS condition with loop
11321 // invariant Start(ArLHS) <u RHS.
11322 //
11323 // Because of (1) there are two options:
11324 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11325 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11326 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11327 // Because of (2) ArLHS <u RHS is trivially true.
11328 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11329 // We can strengthen this to Start(ArLHS) <u RHS.
11330 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11331 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11332 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11333 isKnownNonNegative(RHS) &&
11334 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11335       return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11336 RHS);
11337 }
11338 }
11339
11340 return std::nullopt;
11341}
11342
11343std::optional<ScalarEvolution::LoopInvariantPredicate>
11345 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11346 const Instruction *CtxI, const SCEV *MaxIter) {
11347   if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11348 Pred, LHS, RHS, L, CtxI, MaxIter))
11349 return LIP;
11350 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11351 // Number of iterations expressed as UMIN isn't always great for expressing
11352 // the value on the last iteration. If the straightforward approach didn't
11353   // work, try the following trick: if a predicate is invariant for X, it
11354 // is also invariant for umin(X, ...). So try to find something that works
11355 // among subexpressions of MaxIter expressed as umin.
11356 for (auto *Op : UMin->operands())
11357       if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11358 Pred, LHS, RHS, L, CtxI, Op))
11359 return LIP;
11360 return std::nullopt;
11361}
11362
11363std::optional<ScalarEvolution::LoopInvariantPredicate>
11365 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11366 const Instruction *CtxI, const SCEV *MaxIter) {
11367 // Try to prove the following set of facts:
11368 // - The predicate is monotonic in the iteration space.
11369 // - If the check does not fail on the 1st iteration:
11370 // - No overflow will happen during first MaxIter iterations;
11371 // - It will not fail on the MaxIter'th iteration.
11372 // If the check does fail on the 1st iteration, we leave the loop and no
11373 // other checks matter.
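  // For example, for an IV {0,+,1} with MaxIter = n, the code below checks that
  // the condition still holds for the IV value n on the last iteration and that
  // Start <= Last, which rules out wrapping on the way there.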
11374
11375 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11376 if (!isLoopInvariant(RHS, L)) {
11377 if (!isLoopInvariant(LHS, L))
11378 return std::nullopt;
11379
11380 std::swap(LHS, RHS);
11381     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11382 }
11383
11384 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11385 if (!AR || AR->getLoop() != L)
11386 return std::nullopt;
11387
11388 // The predicate must be relational (i.e. <, <=, >=, >).
11389 if (!ICmpInst::isRelational(Pred))
11390 return std::nullopt;
11391
11392 // TODO: Support steps other than +/- 1.
11393 const SCEV *Step = AR->getStepRecurrence(*this);
11394 auto *One = getOne(Step->getType());
11395 auto *MinusOne = getNegativeSCEV(One);
11396 if (Step != One && Step != MinusOne)
11397 return std::nullopt;
11398
11399 // Type mismatch here means that MaxIter is potentially larger than max
11400   // unsigned value in start type, which means we cannot prove no-wrap for the
11401 // indvar.
11402 if (AR->getType() != MaxIter->getType())
11403 return std::nullopt;
11404
11405 // Value of IV on suggested last iteration.
11406 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11407 // Does it still meet the requirement?
11408 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11409 return std::nullopt;
11410 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11411 // not exceed max unsigned value of this type), this effectively proves
11412 // that there is no wrap during the iteration. To prove that there is no
11413 // signed/unsigned wrap, we need to check that
11414 // Start <= Last for step = 1 or Start >= Last for step = -1.
11415 ICmpInst::Predicate NoOverflowPred =
11416       CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11417 if (Step == MinusOne)
11418 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11419 const SCEV *Start = AR->getStart();
11420 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11421 return std::nullopt;
11422
11423 // Everything is fine.
11424 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11425}
11426
11427bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11428 const SCEV *LHS,
11429 const SCEV *RHS) {
11430 if (HasSameValue(LHS, RHS))
11431 return ICmpInst::isTrueWhenEqual(Pred);
11432
11433 auto CheckRange = [&](bool IsSigned) {
11434 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11435 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11436 return RangeLHS.icmp(Pred, RangeRHS);
11437 };
11438
11439 // The check at the top of the function catches the case where the values are
11440 // known to be equal.
11441 if (Pred == CmpInst::ICMP_EQ)
11442 return false;
11443
11444 if (Pred == CmpInst::ICMP_NE) {
11445 if (CheckRange(true) || CheckRange(false))
11446 return true;
11447 auto *Diff = getMinusSCEV(LHS, RHS);
11448 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11449 }
11450
11451 return CheckRange(CmpInst::isSigned(Pred));
11452}
11453
11454bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11455 const SCEV *LHS,
11456 const SCEV *RHS) {
11457 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11458 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11459 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11460 // OutC1 and OutC2.
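  // For example, X = (A + 5)<nsw> and Y = (A + 7)<nsw> match with C1 = 5 and
  // C2 = 7; the ICMP_SLE case below then concludes X s<= Y because 5 s<= 7.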
11461 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11462 APInt &OutC1, APInt &OutC2,
11463 SCEV::NoWrapFlags ExpectedFlags) {
11464 const SCEV *XNonConstOp, *XConstOp;
11465 const SCEV *YNonConstOp, *YConstOp;
11466 SCEV::NoWrapFlags XFlagsPresent;
11467 SCEV::NoWrapFlags YFlagsPresent;
11468
11469 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11470 XConstOp = getZero(X->getType());
11471 XNonConstOp = X;
11472 XFlagsPresent = ExpectedFlags;
11473 }
11474 if (!isa<SCEVConstant>(XConstOp))
11475 return false;
11476
11477 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11478 YConstOp = getZero(Y->getType());
11479 YNonConstOp = Y;
11480 YFlagsPresent = ExpectedFlags;
11481 }
11482
11483 if (YNonConstOp != XNonConstOp)
11484 return false;
11485
11486 if (!isa<SCEVConstant>(YConstOp))
11487 return false;
11488
11489 // When matching ADDs with NUW flags (and unsigned predicates), only the
11490 // second ADD (with the larger constant) requires NUW.
11491 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11492 return false;
11493 if (ExpectedFlags != SCEV::FlagNUW &&
11494 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11495 return false;
11496 }
11497
11498 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11499 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11500
11501 return true;
11502 };
11503
11504 APInt C1;
11505 APInt C2;
11506
11507 switch (Pred) {
11508 default:
11509 break;
11510
11511 case ICmpInst::ICMP_SGE:
11512 std::swap(LHS, RHS);
11513 [[fallthrough]];
11514 case ICmpInst::ICMP_SLE:
11515 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11516 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11517 return true;
11518
11519 break;
11520
11521 case ICmpInst::ICMP_SGT:
11522 std::swap(LHS, RHS);
11523 [[fallthrough]];
11524 case ICmpInst::ICMP_SLT:
11525 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11526 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11527 return true;
11528
11529 break;
11530
11531 case ICmpInst::ICMP_UGE:
11532 std::swap(LHS, RHS);
11533 [[fallthrough]];
11534 case ICmpInst::ICMP_ULE:
11535 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11536 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11537 return true;
11538
11539 break;
11540
11541 case ICmpInst::ICMP_UGT:
11542 std::swap(LHS, RHS);
11543 [[fallthrough]];
11544 case ICmpInst::ICMP_ULT:
11545 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11546 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11547 return true;
11548 break;
11549 }
11550
11551 return false;
11552}
11553
11554bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11555 const SCEV *LHS,
11556 const SCEV *RHS) {
11557 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11558 return false;
11559
11560 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11561 // the stack can result in exponential time complexity.
11562 SaveAndRestore Restore(ProvingSplitPredicate, true);
11563
11564 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11565 //
11566 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11567 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11568 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11569 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11570 // use isKnownPredicate later if needed.
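  // For example, to prove "I u< N" for a signed loop bound N, it suffices to
  // know that N s>= 0 together with I s>= 0 and I s< N.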
11571 return isKnownNonNegative(RHS) &&
11572          isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11573          isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11574}
11575
11576bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11577 const SCEV *LHS, const SCEV *RHS) {
11578 // No need to even try if we know the module has no guards.
11579 if (!HasGuards)
11580 return false;
11581
11582 return any_of(*BB, [&](const Instruction &I) {
11583 using namespace llvm::PatternMatch;
11584
11585 Value *Condition;
11586     return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11587 m_Value(Condition))) &&
11588 isImpliedCond(Pred, LHS, RHS, Condition, false);
11589 });
11590}
11591
11592/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11593/// protected by a conditional between LHS and RHS. This is used to
11594/// eliminate casts.
11595bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11596 CmpPredicate Pred,
11597 const SCEV *LHS,
11598 const SCEV *RHS) {
11599 // Interpret a null as meaning no loop, where there is obviously no guard
11600 // (interprocedural conditions notwithstanding). Do not bother about
11601 // unreachable loops.
11602 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11603 return true;
11604
11605 if (VerifyIR)
11606 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11607 "This cannot be done on broken IR!");
11608
11609
11610 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11611 return true;
11612
11613 BasicBlock *Latch = L->getLoopLatch();
11614 if (!Latch)
11615 return false;
11616
11617 BranchInst *LoopContinuePredicate =
11618       dyn_cast<BranchInst>(Latch->getTerminator());
11619 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11620 isImpliedCond(Pred, LHS, RHS,
11621 LoopContinuePredicate->getCondition(),
11622 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11623 return true;
11624
11625 // We don't want more than one activation of the following loops on the stack
11626 // -- that can lead to O(n!) time complexity.
11627 if (WalkingBEDominatingConds)
11628 return false;
11629
11630 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11631
11632 // See if we can exploit a trip count to prove the predicate.
11633 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11634 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11635 if (LatchBECount != getCouldNotCompute()) {
11636 // We know that Latch branches back to the loop header exactly
11637     // LatchBECount times. This means the backedge condition at Latch is
11638 // equivalent to "{0,+,1} u< LatchBECount".
11639 Type *Ty = LatchBECount->getType();
11640 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11641 const SCEV *LoopCounter =
11642 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11643 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11644 LatchBECount))
11645 return true;
11646 }
11647
11648 // Check conditions due to any @llvm.assume intrinsics.
11649 for (auto &AssumeVH : AC.assumptions()) {
11650 if (!AssumeVH)
11651 continue;
11652 auto *CI = cast<CallInst>(AssumeVH);
11653 if (!DT.dominates(CI, Latch->getTerminator()))
11654 continue;
11655
11656 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11657 return true;
11658 }
11659
11660 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11661 return true;
11662
11663 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11664 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11665 assert(DTN && "should reach the loop header before reaching the root!");
11666
11667 BasicBlock *BB = DTN->getBlock();
11668 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11669 return true;
11670
11671 BasicBlock *PBB = BB->getSinglePredecessor();
11672 if (!PBB)
11673 continue;
11674
11675 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11676 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11677 continue;
11678
11679 Value *Condition = ContinuePredicate->getCondition();
11680
11681 // If we have an edge `E` within the loop body that dominates the only
11682 // latch, the condition guarding `E` also guards the backedge. This
11683 // reasoning works only for loops with a single latch.
11684
11685 BasicBlockEdge DominatingEdge(PBB, BB);
11686 if (DominatingEdge.isSingleEdge()) {
11687 // We're constructively (and conservatively) enumerating edges within the
11688 // loop body that dominate the latch. The dominator tree better agree
11689 // with us on this:
11690 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11691
11692 if (isImpliedCond(Pred, LHS, RHS, Condition,
11693 BB != ContinuePredicate->getSuccessor(0)))
11694 return true;
11695 }
11696 }
11697
11698 return false;
11699}
11700
11701bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11702 CmpPredicate Pred,
11703 const SCEV *LHS,
11704 const SCEV *RHS) {
11705 // Do not bother proving facts for unreachable code.
11706 if (!DT.isReachableFromEntry(BB))
11707 return true;
11708 if (VerifyIR)
11709 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11710 "This cannot be done on broken IR!");
11711
11712 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11713 // the facts (a >= b && a != b) separately. A typical situation is when the
11714 // non-strict comparison is known from ranges and non-equality is known from
11715 // dominating predicates. If we are proving strict comparison, we always try
11716 // to prove non-equality and non-strict comparison separately.
11717 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11718 const bool ProvingStrictComparison =
11719 Pred != NonStrictPredicate.dropSameSign();
11720 bool ProvedNonStrictComparison = false;
11721 bool ProvedNonEquality = false;
11722
11723 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11724 if (!ProvedNonStrictComparison)
11725 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11726 if (!ProvedNonEquality)
11727 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11728 if (ProvedNonStrictComparison && ProvedNonEquality)
11729 return true;
11730 return false;
11731 };
11732
11733 if (ProvingStrictComparison) {
11734 auto ProofFn = [&](CmpPredicate P) {
11735 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11736 };
11737 if (SplitAndProve(ProofFn))
11738 return true;
11739 }
11740
11741 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11742 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11743 const Instruction *CtxI = &BB->front();
11744 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11745 return true;
11746 if (ProvingStrictComparison) {
11747 auto ProofFn = [&](CmpPredicate P) {
11748 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11749 };
11750 if (SplitAndProve(ProofFn))
11751 return true;
11752 }
11753 return false;
11754 };
11755
11756 // Starting at the block's predecessor, climb up the predecessor chain, as long
11757 // as there are predecessors that can be found that have unique successors
11758 // leading to the original block.
11759 const Loop *ContainingLoop = LI.getLoopFor(BB);
11760 const BasicBlock *PredBB;
11761 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11762 PredBB = ContainingLoop->getLoopPredecessor();
11763 else
11764 PredBB = BB->getSinglePredecessor();
11765 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11766 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11767 const BranchInst *BlockEntryPredicate =
11768 dyn_cast<BranchInst>(Pair.first->getTerminator());
11769 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11770 continue;
11771
11772 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11773 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11774 return true;
11775 }
11776
11777 // Check conditions due to any @llvm.assume intrinsics.
11778 for (auto &AssumeVH : AC.assumptions()) {
11779 if (!AssumeVH)
11780 continue;
11781 auto *CI = cast<CallInst>(AssumeVH);
11782 if (!DT.dominates(CI, BB))
11783 continue;
11784
11785 if (ProveViaCond(CI->getArgOperand(0), false))
11786 return true;
11787 }
11788
11789 // Check conditions due to any @llvm.experimental.guard intrinsics.
11790 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11791 F.getParent(), Intrinsic::experimental_guard);
11792 if (GuardDecl)
11793 for (const auto *GU : GuardDecl->users())
11794 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11795 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11796 if (ProveViaCond(Guard->getArgOperand(0), false))
11797 return true;
11798 return false;
11799}
11800
11802 const SCEV *LHS,
11803 const SCEV *RHS) {
11804 // Interpret a null as meaning no loop, where there is obviously no guard
11805 // (interprocedural conditions notwithstanding).
11806 if (!L)
11807 return false;
11808
11809 // Both LHS and RHS must be available at loop entry.
11810   assert(isAvailableAtLoopEntry(LHS, L) &&
11811          "LHS is not available at Loop Entry");
11812   assert(isAvailableAtLoopEntry(RHS, L) &&
11813          "RHS is not available at Loop Entry");
11814
11815 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11816 return true;
11817
11818 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11819}
11820
11821bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11822 const SCEV *RHS,
11823 const Value *FoundCondValue, bool Inverse,
11824 const Instruction *CtxI) {
11825   // A false condition implies anything. Do not bother analyzing it further.
11826 if (FoundCondValue ==
11827 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11828 return true;
11829
11830 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11831 return false;
11832
11833 auto ClearOnExit =
11834 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11835
11836 // Recursively handle And and Or conditions.
11837 const Value *Op0, *Op1;
11838 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11839 if (!Inverse)
11840 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11841 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11842 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11843 if (Inverse)
11844 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11845 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11846 }
11847
11848 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11849 if (!ICI) return false;
11850
11851   // We have found a conditional branch that dominates the loop or controls
11852   // the loop latch. Check to see if it is the comparison we are looking for.
11853 CmpPredicate FoundPred;
11854 if (Inverse)
11855 FoundPred = ICI->getInverseCmpPredicate();
11856 else
11857 FoundPred = ICI->getCmpPredicate();
11858
11859 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11860 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11861
11862 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11863}
11864
11865bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11866 const SCEV *RHS, CmpPredicate FoundPred,
11867 const SCEV *FoundLHS, const SCEV *FoundRHS,
11868 const Instruction *CtxI) {
11869 // Balance the types.
11870 if (getTypeSizeInBits(LHS->getType()) <
11871 getTypeSizeInBits(FoundLHS->getType())) {
11872 // For unsigned and equality predicates, try to prove that both found
11873 // operands fit into narrow unsigned range. If so, try to prove facts in
11874 // narrow types.
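    // For example, if LHS/RHS are i32 while FoundLHS/FoundRHS are i64 values
    // that provably fit in [0, 2^32), the found fact can be truncated to i32
    // and applied directly to LHS/RHS.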
11875 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11876 !FoundRHS->getType()->isPointerTy()) {
11877 auto *NarrowType = LHS->getType();
11878 auto *WideType = FoundLHS->getType();
11879 auto BitWidth = getTypeSizeInBits(NarrowType);
11880 const SCEV *MaxValue = getZeroExtendExpr(
11881           getConstant(APInt::getMaxValue(BitWidth)), WideType);
11882 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11883 MaxValue) &&
11884 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11885 MaxValue)) {
11886 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11887 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11888 // We cannot preserve samesign after truncation.
11889 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11890 TruncFoundLHS, TruncFoundRHS, CtxI))
11891 return true;
11892 }
11893 }
11894
11895 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11896 return false;
11897 if (CmpInst::isSigned(Pred)) {
11898 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11899 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11900 } else {
11901 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11902 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11903 }
11904 } else if (getTypeSizeInBits(LHS->getType()) >
11905 getTypeSizeInBits(FoundLHS->getType())) {
11906 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11907 return false;
11908 if (CmpInst::isSigned(FoundPred)) {
11909 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11910 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11911 } else {
11912 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11913 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11914 }
11915 }
11916 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11917 FoundRHS, CtxI);
11918}
11919
11920bool ScalarEvolution::isImpliedCondBalancedTypes(
11921 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11922 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11923   assert(getTypeSizeInBits(LHS->getType()) ==
11924 getTypeSizeInBits(FoundLHS->getType()) &&
11925 "Types should be balanced!");
11926 // Canonicalize the query to match the way instcombine will have
11927 // canonicalized the comparison.
11928 if (SimplifyICmpOperands(Pred, LHS, RHS))
11929 if (LHS == RHS)
11930 return CmpInst::isTrueWhenEqual(Pred);
11931 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11932 if (FoundLHS == FoundRHS)
11933 return CmpInst::isFalseWhenEqual(FoundPred);
11934
11935 // Check to see if we can make the LHS or RHS match.
11936 if (LHS == FoundRHS || RHS == FoundLHS) {
11937 if (isa<SCEVConstant>(RHS)) {
11938 std::swap(FoundLHS, FoundRHS);
11939 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11940 } else {
11941 std::swap(LHS, RHS);
11942       Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11943 }
11944 }
11945
11946 // Check whether the found predicate is the same as the desired predicate.
11947 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11948 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11949
11950 // Check whether swapping the found predicate makes it the same as the
11951 // desired predicate.
11952 if (auto P = CmpPredicate::getMatching(
11953 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
11954 // We can write the implication
11955 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11956 // using one of the following ways:
11957 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11958 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11959 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11960 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11961 // Forms 1. and 2. require swapping the operands of one condition. Don't
11962 // do this if it would break canonical constant/addrec ordering.
11963     if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11964 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
11965 LHS, FoundLHS, FoundRHS, CtxI);
11966 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11967 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11968
11969 // There's no clear preference between forms 3. and 4., try both. Avoid
11970 // forming getNotSCEV of pointer values as the resulting subtract is
11971 // not legal.
11972 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11973 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
11974 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
11975 FoundRHS, CtxI))
11976 return true;
11977
11978 if (!FoundLHS->getType()->isPointerTy() &&
11979 !FoundRHS->getType()->isPointerTy() &&
11980 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
11981 getNotSCEV(FoundRHS), CtxI))
11982 return true;
11983
11984 return false;
11985 }
11986
11987 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11988 CmpInst::Predicate P2) {
11989 assert(P1 != P2 && "Handled earlier!");
11990 return CmpInst::isRelational(P2) &&
11991            P1 == ICmpInst::getFlippedSignednessPredicate(P2);
11992 };
11993 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11994 // Unsigned comparison is the same as signed comparison when both the
11995 // operands are non-negative or negative.
11996 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11997 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11998 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11999 // Create local copies that we can freely swap and canonicalize our
12000 // conditions to "le/lt".
12001 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12002 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12003 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12004 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12005 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12006 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12007 std::swap(CanonicalLHS, CanonicalRHS);
12008 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12009 }
12010 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12011 "Must be!");
12012 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12013 ICmpInst::isLE(CanonicalFoundPred)) &&
12014 "Must be!");
12015 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12016 // Use implication:
12017 // x <u y && y >=s 0 --> x <s y.
12018 // If we can prove the left part, the right part is also proven.
12019 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12020 CanonicalRHS, CanonicalFoundLHS,
12021 CanonicalFoundRHS);
12022 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12023 // Use implication:
12024 // x <s y && y <s 0 --> x <u y.
12025 // If we can prove the left part, the right part is also proven.
12026 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12027 CanonicalRHS, CanonicalFoundLHS,
12028 CanonicalFoundRHS);
12029 }
12030
12031 // Check if we can make progress by sharpening ranges.
12032 if (FoundPred == ICmpInst::ICMP_NE &&
12033 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12034
12035 const SCEVConstant *C = nullptr;
12036 const SCEV *V = nullptr;
12037
12038 if (isa<SCEVConstant>(FoundLHS)) {
12039 C = cast<SCEVConstant>(FoundLHS);
12040 V = FoundRHS;
12041 } else {
12042 C = cast<SCEVConstant>(FoundRHS);
12043 V = FoundLHS;
12044 }
12045
12046 // The guarding predicate tells us that C != V. If the known range
12047 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12048 // range we consider has to correspond to same signedness as the
12049 // predicate we're interested in folding.
12050
12051 APInt Min = ICmpInst::isSigned(Pred) ?
12052         getSignedRangeMin(V) : getUnsignedRangeMin(V);
12053
12054 if (Min == C->getAPInt()) {
12055 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12056 // This is true even if (Min + 1) wraps around -- in case of
12057 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
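      // For example, for i8 unsigned values with Min == 255, Min + 1 wraps to 0
      // and "V u>= 0" holds trivially, so the conclusion remains sound.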
12058
12059 APInt SharperMin = Min + 1;
12060
12061 switch (Pred) {
12062 case ICmpInst::ICMP_SGE:
12063 case ICmpInst::ICMP_UGE:
12064 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12065 // RHS, we're done.
12066 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12067 CtxI))
12068 return true;
12069 [[fallthrough]];
12070
12071 case ICmpInst::ICMP_SGT:
12072 case ICmpInst::ICMP_UGT:
12073 // We know from the range information that (V `Pred` Min ||
12074 // V == Min). We know from the guarding condition that !(V
12075 // == Min). This gives us
12076 //
12077 // V `Pred` Min || V == Min && !(V == Min)
12078 // => V `Pred` Min
12079 //
12080 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12081
12082 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12083 return true;
12084 break;
12085
12086 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12087 case ICmpInst::ICMP_SLE:
12088 case ICmpInst::ICMP_ULE:
12089 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12090 LHS, V, getConstant(SharperMin), CtxI))
12091 return true;
12092 [[fallthrough]];
12093
12094 case ICmpInst::ICMP_SLT:
12095 case ICmpInst::ICMP_ULT:
12096 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12097 LHS, V, getConstant(Min), CtxI))
12098 return true;
12099 break;
12100
12101 default:
12102 // No change
12103 break;
12104 }
12105 }
12106 }
12107
12108 // Check whether the actual condition is beyond sufficient.
12109 if (FoundPred == ICmpInst::ICMP_EQ)
12110 if (ICmpInst::isTrueWhenEqual(Pred))
12111 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12112 return true;
12113 if (Pred == ICmpInst::ICMP_NE)
12114 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12115 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12116 return true;
12117
12118 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12119 return true;
12120
12121 // Otherwise assume the worst.
12122 return false;
12123}
12124
12125bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12126 const SCEV *&L, const SCEV *&R,
12127 SCEV::NoWrapFlags &Flags) {
12128 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12129 if (!AE || AE->getNumOperands() != 2)
12130 return false;
12131
12132 L = AE->getOperand(0);
12133 R = AE->getOperand(1);
12134 Flags = AE->getNoWrapFlags();
12135 return true;
12136}
12137
12138std::optional<APInt>
12139ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12140 // We avoid subtracting expressions here because this function is usually
12141 // fairly deep in the call stack (i.e. is called many times).
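  // For example, the difference between (13 + %x) and (7 + %x) is found to be 6
  // by cancelling the common %x term below, without building a subtraction.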
12142
12143 unsigned BW = getTypeSizeInBits(More->getType());
12144 APInt Diff(BW, 0);
12145 APInt DiffMul(BW, 1);
12146 // Try various simplifications to reduce the difference to a constant. Limit
12147 // the number of allowed simplifications to keep compile-time low.
12148 for (unsigned I = 0; I < 8; ++I) {
12149 if (More == Less)
12150 return Diff;
12151
12152 // Reduce addrecs with identical steps to their start value.
12153     if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12154 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12155 const auto *MAR = cast<SCEVAddRecExpr>(More);
12156
12157 if (LAR->getLoop() != MAR->getLoop())
12158 return std::nullopt;
12159
12160 // We look at affine expressions only; not for correctness but to keep
12161 // getStepRecurrence cheap.
12162 if (!LAR->isAffine() || !MAR->isAffine())
12163 return std::nullopt;
12164
12165 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12166 return std::nullopt;
12167
12168 Less = LAR->getStart();
12169 More = MAR->getStart();
12170 continue;
12171 }
12172
12173 // Try to match a common constant multiply.
12174 auto MatchConstMul =
12175 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12176 auto *M = dyn_cast<SCEVMulExpr>(S);
12177 if (!M || M->getNumOperands() != 2 ||
12178 !isa<SCEVConstant>(M->getOperand(0)))
12179 return std::nullopt;
12180 return {
12181 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12182 };
12183 if (auto MatchedMore = MatchConstMul(More)) {
12184 if (auto MatchedLess = MatchConstMul(Less)) {
12185 if (MatchedMore->second == MatchedLess->second) {
12186 More = MatchedMore->first;
12187 Less = MatchedLess->first;
12188 DiffMul *= MatchedMore->second;
12189 continue;
12190 }
12191 }
12192 }
12193
12194 // Try to cancel out common factors in two add expressions.
12195     SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12196 auto Add = [&](const SCEV *S, int Mul) {
12197 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12198 if (Mul == 1) {
12199 Diff += C->getAPInt() * DiffMul;
12200 } else {
12201 assert(Mul == -1);
12202 Diff -= C->getAPInt() * DiffMul;
12203 }
12204 } else
12205 Multiplicity[S] += Mul;
12206 };
12207 auto Decompose = [&](const SCEV *S, int Mul) {
12208 if (isa<SCEVAddExpr>(S)) {
12209 for (const SCEV *Op : S->operands())
12210 Add(Op, Mul);
12211 } else
12212 Add(S, Mul);
12213 };
12214 Decompose(More, 1);
12215 Decompose(Less, -1);
12216
12217 // Check whether all the non-constants cancel out, or reduce to new
12218 // More/Less values.
12219 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12220 for (const auto &[S, Mul] : Multiplicity) {
12221 if (Mul == 0)
12222 continue;
12223 if (Mul == 1) {
12224 if (NewMore)
12225 return std::nullopt;
12226 NewMore = S;
12227 } else if (Mul == -1) {
12228 if (NewLess)
12229 return std::nullopt;
12230 NewLess = S;
12231 } else
12232 return std::nullopt;
12233 }
12234
12235 // Values stayed the same, no point in trying further.
12236 if (NewMore == More || NewLess == Less)
12237 return std::nullopt;
12238
12239 More = NewMore;
12240 Less = NewLess;
12241
12242 // Reduced to constant.
12243 if (!More && !Less)
12244 return Diff;
12245
12246 // Left with variable on only one side, bail out.
12247 if (!More || !Less)
12248 return std::nullopt;
12249 }
12250
12251 // Did not reduce to constant.
12252 return std::nullopt;
12253}
12254
12255bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12256 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12257 const SCEV *FoundRHS, const Instruction *CtxI) {
12258 // Try to recognize the following pattern:
12259 //
12260 // FoundRHS = ...
12261 // ...
12262 // loop:
12263 // FoundLHS = {Start,+,W}
12264 // context_bb: // Basic block from the same loop
12265 // known(Pred, FoundLHS, FoundRHS)
12266 //
12267 // If some predicate is known in the context of a loop, it is also known on
12268 // each iteration of this loop, including the first iteration. Therefore, in
12269 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12270 // prove the original pred using this fact.
12271 if (!CtxI)
12272 return false;
12273 const BasicBlock *ContextBB = CtxI->getParent();
12274 // Make sure AR varies in the context block.
12275 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12276 const Loop *L = AR->getLoop();
12277 // Make sure that context belongs to the loop and executes on 1st iteration
12278 // (if it ever executes at all).
12279 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12280 return false;
12281 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12282 return false;
12283 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12284 }
12285
12286 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12287 const Loop *L = AR->getLoop();
12288 // Make sure that context belongs to the loop and executes on 1st iteration
12289 // (if it ever executes at all).
12290 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12291 return false;
12292 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12293 return false;
12294 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12295 }
12296
12297 return false;
12298}
12299
12300bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12301 const SCEV *LHS,
12302 const SCEV *RHS,
12303 const SCEV *FoundLHS,
12304 const SCEV *FoundRHS) {
12305 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12306 return false;
12307
12308 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12309 if (!AddRecLHS)
12310 return false;
12311
12312 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12313 if (!AddRecFoundLHS)
12314 return false;
12315
12316 // We'd like to let SCEV reason about control dependencies, so we constrain
12317 // both the inequalities to be about add recurrences on the same loop. This
12318 // way we can use isLoopEntryGuardedByCond later.
12319
12320 const Loop *L = AddRecFoundLHS->getLoop();
12321 if (L != AddRecLHS->getLoop())
12322 return false;
12323
12324 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12325 //
12326 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12327 // ... (2)
12328 //
12329 // Informal proof for (2), assuming (1) [*]:
12330 //
12331 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12332 //
12333 // Then
12334 //
12335 // FoundLHS s< FoundRHS s< INT_MIN - C
12336 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12337 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12338 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12339 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12340 // <=> FoundLHS + C s< FoundRHS + C
12341 //
12342 // [*]: (1) can be proved by ruling out overflow.
12343 //
12344 // [**]: This can be proved by analyzing all the four possibilities:
12345 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12346 // (A s>= 0, B s>= 0).
12347 //
12348 // Note:
12349 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12350 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12351 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12352 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12353 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12354 // C)".
12355
12356 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12357 if (!LDiff)
12358 return false;
12359 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12360 if (!RDiff || *LDiff != *RDiff)
12361 return false;
12362
12363 if (LDiff->isMinValue())
12364 return true;
12365
12366 APInt FoundRHSLimit;
12367
12368 if (Pred == CmpInst::ICMP_ULT) {
12369 FoundRHSLimit = -(*RDiff);
12370 } else {
12371 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12372 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12373 }
12374
12375 // Try to prove (1) or (2), as needed.
12376 return isAvailableAtLoopEntry(FoundRHS, L) &&
12377 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12378 getConstant(FoundRHSLimit));
12379}
12380
12381bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12382 const SCEV *RHS, const SCEV *FoundLHS,
12383 const SCEV *FoundRHS, unsigned Depth) {
12384 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12385
12386 auto ClearOnExit = make_scope_exit([&]() {
12387 if (LPhi) {
12388 bool Erased = PendingMerges.erase(LPhi);
12389 assert(Erased && "Failed to erase LPhi!");
12390 (void)Erased;
12391 }
12392 if (RPhi) {
12393 bool Erased = PendingMerges.erase(RPhi);
12394 assert(Erased && "Failed to erase RPhi!");
12395 (void)Erased;
12396 }
12397 });
12398
12399   // Find the respective Phis and check that they are not pending.
12400 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12401 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12402 if (!PendingMerges.insert(Phi).second)
12403 return false;
12404 LPhi = Phi;
12405 }
12406 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12407 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12408 // If we detect a loop of Phi nodes being processed by this method, for
12409 // example:
12410 //
12411 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12412 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12413 //
12414       // we don't want to deal with a case that complex, so we conservatively
12415       // return false.
12416 if (!PendingMerges.insert(Phi).second)
12417 return false;
12418 RPhi = Phi;
12419 }
12420
12421 // If none of LHS, RHS is a Phi, nothing to do here.
12422 if (!LPhi && !RPhi)
12423 return false;
12424
12425 // If there is a SCEVUnknown Phi we are interested in, make it left.
12426 if (!LPhi) {
12427 std::swap(LHS, RHS);
12428 std::swap(FoundLHS, FoundRHS);
12429 std::swap(LPhi, RPhi);
12430     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12431 }
12432
12433 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12434 const BasicBlock *LBB = LPhi->getParent();
12435 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12436
12437 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12438 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12439 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12440 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12441 };
12442
12443 if (RPhi && RPhi->getParent() == LBB) {
12444 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12445 // If we compare two Phis from the same block, and for each entry block
12446 // the predicate is true for incoming values from this block, then the
12447 // predicate is also true for the Phis.
12448 for (const BasicBlock *IncBB : predecessors(LBB)) {
12449 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12450 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12451 if (!ProvedEasily(L, R))
12452 return false;
12453 }
12454 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12455 // Case two: RHS is also a Phi from the same basic block, and it is an
12456 // AddRec. It means that there is a loop which has both AddRec and Unknown
12457 // PHIs, for it we can compare incoming values of AddRec from above the loop
12458 // and latch with their respective incoming values of LPhi.
12459 // TODO: Generalize to handle loops with many inputs in a header.
12460 if (LPhi->getNumIncomingValues() != 2) return false;
12461
12462 auto *RLoop = RAR->getLoop();
12463 auto *Predecessor = RLoop->getLoopPredecessor();
12464 assert(Predecessor && "Loop with AddRec with no predecessor?");
12465 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12466 if (!ProvedEasily(L1, RAR->getStart()))
12467 return false;
12468 auto *Latch = RLoop->getLoopLatch();
12469 assert(Latch && "Loop with AddRec with no latch?");
12470 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12471 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12472 return false;
12473 } else {
12474     // In all other cases go over the inputs of LHS and compare each of them to
12475     // RHS; the predicate is true for (LHS, RHS) if it is true for all such pairs.
12476 // At this point RHS is either a non-Phi, or it is a Phi from some block
12477 // different from LBB.
12478 for (const BasicBlock *IncBB : predecessors(LBB)) {
12479 // Check that RHS is available in this block.
12480 if (!dominates(RHS, IncBB))
12481 return false;
12482 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12483 // Make sure L does not refer to a value from a potentially previous
12484 // iteration of a loop.
12485 if (!properlyDominates(L, LBB))
12486 return false;
12487 // Addrecs are considered to properly dominate their loop, so are missed
12488 // by the previous check. Discard any values that have computable
12489 // evolution in this loop.
12490 if (auto *Loop = LI.getLoopFor(LBB))
12491 if (hasComputableLoopEvolution(L, Loop))
12492 return false;
12493 if (!ProvedEasily(L, RHS))
12494 return false;
12495 }
12496 }
12497 return true;
12498}
12499
12500bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12501 const SCEV *LHS,
12502 const SCEV *RHS,
12503 const SCEV *FoundLHS,
12504 const SCEV *FoundRHS) {
12505 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12506   // sure that we are dealing with the same LHS.
12507 if (RHS == FoundRHS) {
12508 std::swap(LHS, RHS);
12509 std::swap(FoundLHS, FoundRHS);
12510     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12511 }
12512 if (LHS != FoundLHS)
12513 return false;
12514
12515 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12516 if (!SUFoundRHS)
12517 return false;
12518
12519 Value *Shiftee, *ShiftValue;
12520
12521 using namespace PatternMatch;
12522 if (match(SUFoundRHS->getValue(),
12523 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12524 auto *ShifteeS = getSCEV(Shiftee);
12525 // Prove one of the following:
12526 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12527 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12528 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12529 // ---> LHS <s RHS
12530 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12531 // ---> LHS <=s RHS
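    // For example, taking shiftee == RHS: from "LHS u< (RHS >> 3)" we may
    // conclude "LHS u< RHS", since (RHS >> 3) u<= RHS for a logical shift.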
12532 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12533 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12534 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12535 if (isKnownNonNegative(ShifteeS))
12536 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12537 }
12538
12539 return false;
12540}
12541
12542bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12543 const SCEV *RHS,
12544 const SCEV *FoundLHS,
12545 const SCEV *FoundRHS,
12546 const Instruction *CtxI) {
12547 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12548 FoundRHS) ||
12549 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12550 FoundRHS) ||
12551 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12552 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12553 CtxI) ||
12554 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12555}
12556
12557/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12558template <typename MinMaxExprType>
12559static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12560 const SCEV *Candidate) {
12561 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12562 if (!MinMaxExpr)
12563 return false;
12564
12565 return is_contained(MinMaxExpr->operands(), Candidate);
12566}
12567
12568static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12569                                           CmpPredicate Pred, const SCEV *LHS,
12570 const SCEV *RHS) {
12571 // If both sides are affine addrecs for the same loop, with equal
12572 // steps, and we know the recurrences don't wrap, then we only
12573 // need to check the predicate on the starting values.
12574
12575 if (!ICmpInst::isRelational(Pred))
12576 return false;
12577
12578 const SCEV *LStart, *RStart, *Step;
12579 const Loop *L;
12580 if (!match(LHS,
12581 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12583 m_SpecificLoop(L))))
12584 return false;
12589 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12590 return false;
12591
12592 return SE.isKnownPredicate(Pred, LStart, RStart);
12593}
12594
12595/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12596/// expression?
12597static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12598                                        const SCEV *LHS, const SCEV *RHS) {
12599 switch (Pred) {
12600 default:
12601 return false;
12602
12603 case ICmpInst::ICMP_SGE:
12604 std::swap(LHS, RHS);
12605 [[fallthrough]];
12606 case ICmpInst::ICMP_SLE:
12607 return
12608 // min(A, ...) <= A
12610 // A <= max(A, ...)
12612
12613 case ICmpInst::ICMP_UGE:
12614 std::swap(LHS, RHS);
12615 [[fallthrough]];
12616 case ICmpInst::ICMP_ULE:
12617 return
12618 // min(A, ...) <= A
12619 // FIXME: what about umin_seq?
12621 // A <= max(A, ...)
12623 }
12624
12625 llvm_unreachable("covered switch fell through?!");
12626}
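// Illustrative standalone sketch (not part of this file): the two facts the
// switch above relies on, umin(A, ...) <= A and A <= umax(A, ...), spelled out
// with plain unsigned values A = 7 and B = 42.
static_assert((7u < 42u ? 7u : 42u) <= 7u, "umin(A, B) <= A");
static_assert(7u <= (7u > 42u ? 7u : 42u), "A <= umax(A, B)");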
12627
12628bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12629 const SCEV *RHS,
12630 const SCEV *FoundLHS,
12631 const SCEV *FoundRHS,
12632 unsigned Depth) {
12635 "LHS and RHS have different sizes?");
12636 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12637 getTypeSizeInBits(FoundRHS->getType()) &&
12638 "FoundLHS and FoundRHS have different sizes?");
12639 // We want to avoid hurting the compile time with analysis of too big trees.
12641 return false;
12642
12643 // We only want to work with GT comparison so far.
12644 if (ICmpInst::isLT(Pred)) {
12646 std::swap(LHS, RHS);
12647 std::swap(FoundLHS, FoundRHS);
12648 }
12649
12651
12652  // For unsigned, try to reduce it to the corresponding signed comparison.
12653 if (P == ICmpInst::ICMP_UGT)
12654 // We can replace unsigned predicate with its signed counterpart if all
12655 // involved values are non-negative.
12656 // TODO: We could have better support for unsigned.
12657 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12658 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12659 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12660 // use this fact to prove that LHS and RHS are non-negative.
12661 const SCEV *MinusOne = getMinusOne(LHS->getType());
12662 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12663 FoundRHS) &&
12664 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12665 FoundRHS))
12667 }
12668
12669 if (P != ICmpInst::ICMP_SGT)
12670 return false;
12671
12672 auto GetOpFromSExt = [&](const SCEV *S) {
12673 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12674 return Ext->getOperand();
12675 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12676 // the constant in some cases.
12677 return S;
12678 };
12679
12680 // Acquire values from extensions.
12681 auto *OrigLHS = LHS;
12682 auto *OrigFoundLHS = FoundLHS;
12683 LHS = GetOpFromSExt(LHS);
12684 FoundLHS = GetOpFromSExt(FoundLHS);
12685
12686  // Check whether the SGT predicate can be proved trivially or using the found
12686  // context.
12687 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12688 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12689 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12690 FoundRHS, Depth + 1);
12691 };
12692
12693 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12694 // We want to avoid creation of any new non-constant SCEV. Since we are
12695 // going to compare the operands to RHS, we should be certain that we don't
12696 // need any size extensions for this. So let's decline all cases when the
12697 // sizes of types of LHS and RHS do not match.
12698 // TODO: Maybe try to get RHS from sext to catch more cases?
12700 return false;
12701
12702 // Should not overflow.
12703 if (!LHSAddExpr->hasNoSignedWrap())
12704 return false;
12705
12706 auto *LL = LHSAddExpr->getOperand(0);
12707 auto *LR = LHSAddExpr->getOperand(1);
12708 auto *MinusOne = getMinusOne(RHS->getType());
12709
12710 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12711 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12712 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12713 };
12714 // Try to prove the following rule:
12715 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12716 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12717 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12718 return true;
12719 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12720 Value *LL, *LR;
12721 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12722
12723 using namespace llvm::PatternMatch;
12724
12725 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12726 // Rules for division.
12727 // We are going to perform some comparisons with Denominator and its
12728      // derivative expressions. In the general case, creating a SCEV for it may
12729      // lead to a complex analysis of the entire graph, and in particular it
12730      // can request trip count recalculation for the same loop. That result would
12731      // be cached as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12732 // this, we only want to create SCEVs that are constants in this section.
12733 // So we bail if Denominator is not a constant.
12734 if (!isa<ConstantInt>(LR))
12735 return false;
12736
12737 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12738
12739 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12740 // then a SCEV for the numerator already exists and matches with FoundLHS.
12741 auto *Numerator = getExistingSCEV(LL);
12742 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12743 return false;
12744
12745 // Make sure that the numerator matches with FoundLHS and the denominator
12746 // is positive.
12747 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12748 return false;
12749
12750 auto *DTy = Denominator->getType();
12751 auto *FRHSTy = FoundRHS->getType();
12752 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12753        // One of the types is a pointer and the other one is not. We cannot extend
12754 // them properly to a wider type, so let us just reject this case.
12755 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12756 // to avoid this check.
12757 return false;
12758
12759 // Given that:
12760 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12761 auto *WTy = getWiderType(DTy, FRHSTy);
12762 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12763 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12764
12765 // Try to prove the following rule:
12766 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12767      // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12768      // divide it by a Denominator < 4, the result is at least 1.
12769 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12770 if (isKnownNonPositive(RHS) &&
12771 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12772 return true;
12773
12774 // Try to prove the following rule:
12775 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12776      // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12777 // If we divide it by Denominator > 2, then:
12778 // 1. If FoundLHS is negative, then the result is 0.
12779 // 2. If FoundLHS is non-negative, then the result is non-negative.
12780      // Either way, the result is non-negative.
12781 auto *MinusOne = getMinusOne(WTy);
12782 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12783 if (isKnownNegative(RHS) &&
12784 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12785 return true;
12786 }
12787 }
12788
12789 // If our expression contained SCEVUnknown Phis, and we split it down and now
12790 // need to prove something for them, try to prove the predicate for every
12791 // possible incoming values of those Phis.
12792 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12793 return true;
12794
12795 return false;
12796}
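// Illustrative standalone sketch (not part of this file): one concrete
// instance of the first division rule above, using hypothetical values
// Denominator = 4, FoundRHS = 3, FoundLHS = 4, RHS = 0.
static_assert(3 > 4 - 2, "FoundRHS > Denominator - 2");
static_assert(4 > 3, "FoundLHS > FoundRHS (the found condition)");
static_assert(4 / 4 == 1, "LHS = FoundLHS / Denominator is at least 1");
static_assert(1 > 0, "so LHS > RHS holds for any RHS <= 0");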
12797
12799 const SCEV *RHS) {
12800 // zext x u<= sext x, sext x s<= zext x
12801 const SCEV *Op;
12802 switch (Pred) {
12803 case ICmpInst::ICMP_SGE:
12804 std::swap(LHS, RHS);
12805 [[fallthrough]];
12806 case ICmpInst::ICMP_SLE: {
12807 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12808 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12810 }
12811 case ICmpInst::ICMP_UGE:
12812 std::swap(LHS, RHS);
12813 [[fallthrough]];
12814 case ICmpInst::ICMP_ULE: {
12815 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12816 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12818 }
12819 default:
12820 return false;
12821 };
12822 llvm_unreachable("unhandled case");
12823}
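// Illustrative standalone sketch (not part of this file): the extend idiom on
// the 8-bit value 0xFF. Zero-extending to 16 bits gives 0x00FF (255) while
// sign-extending gives 0xFFFF (65535 unsigned, -1 signed), so
// "zext x u<= sext x" and "sext x s<= zext x" both hold on this example.
static_assert(0x00FFu <= 0xFFFFu, "zext x u<= sext x");
static_assert(-1 <= 0x00FF, "sext x s<= zext x (compared as signed)");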
12824
12825bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12826 const SCEV *LHS,
12827 const SCEV *RHS) {
12828 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12829 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12830 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12831 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12832 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12833}
12834
12835bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12836 const SCEV *LHS,
12837 const SCEV *RHS,
12838 const SCEV *FoundLHS,
12839 const SCEV *FoundRHS) {
12840 switch (Pred) {
12841 default:
12842 llvm_unreachable("Unexpected CmpPredicate value!");
12843 case ICmpInst::ICMP_EQ:
12844 case ICmpInst::ICMP_NE:
12845 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12846 return true;
12847 break;
12848 case ICmpInst::ICMP_SLT:
12849 case ICmpInst::ICMP_SLE:
12850 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12851 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12852 return true;
12853 break;
12854 case ICmpInst::ICMP_SGT:
12855 case ICmpInst::ICMP_SGE:
12856 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12857 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12858 return true;
12859 break;
12860 case ICmpInst::ICMP_ULT:
12861 case ICmpInst::ICMP_ULE:
12862 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12863 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12864 return true;
12865 break;
12866 case ICmpInst::ICMP_UGT:
12867 case ICmpInst::ICMP_UGE:
12868 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12869 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12870 return true;
12871 break;
12872 }
12873
12874 // Maybe it can be proved via operations?
12875 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12876 return true;
12877
12878 return false;
12879}
12880
12881bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12882 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12883 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12884 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12885    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12886 // reduce the compile time impact of this optimization.
12887 return false;
12888
12889 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12890 if (!Addend)
12891 return false;
12892
12893 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12894
12895 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12896 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12897 ConstantRange FoundLHSRange =
12898 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12899
12900 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12901 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12902
12903 // We can also compute the range of values for `LHS` that satisfy the
12904 // consequent, "`LHS` `Pred` `RHS`":
12905 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12906 // The antecedent implies the consequent if every value of `LHS` that
12907 // satisfies the antecedent also satisfies the consequent.
12908 return LHSRange.icmp(Pred, ConstRHS);
12909}
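// Illustrative standalone sketch (not part of this file): the range reasoning
// above for a hypothetical antecedent "FoundLHS <u 10" with Addend = 5, so
// LHS = FoundLHS + 5 lies in [5, 15) and the consequent "LHS <u 20" holds for
// every value in that range; checking the two extreme values is enough here.
static_assert(0u + 5u < 20u, "smallest possible LHS satisfies LHS <u 20");
static_assert(9u + 5u < 20u, "largest possible LHS satisfies LHS <u 20");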
12910
12911bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12912 bool IsSigned) {
12913 assert(isKnownPositive(Stride) && "Positive stride expected!");
12914
12915 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12916 const SCEV *One = getOne(Stride->getType());
12917
12918 if (IsSigned) {
12919 APInt MaxRHS = getSignedRangeMax(RHS);
12920 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12921 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12922
12923 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12924 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12925 }
12926
12927 APInt MaxRHS = getUnsignedRangeMax(RHS);
12928 APInt MaxValue = APInt::getMaxValue(BitWidth);
12929 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12930
12931 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12932 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12933}
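// Illustrative standalone sketch (not part of this file): the unsigned check
// above for an assumed 8-bit IV with MaxRHS = 250 and a stride of at most 8.
// MaxValue - MaxStrideMinusOne = 255 - 7 = 248 is u< 250, so the final
// increment could step past UINT8_MAX and overflow is reported as possible.
static_assert(255u - (8u - 1u) < 250u, "MaxRHS + MaxStrideMinusOne > MaxValue");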
12934
12935bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12936 bool IsSigned) {
12937
12938 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12939 const SCEV *One = getOne(Stride->getType());
12940
12941 if (IsSigned) {
12942 APInt MinRHS = getSignedRangeMin(RHS);
12943 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12944 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12945
12946 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12947 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12948 }
12949
12950 APInt MinRHS = getUnsignedRangeMin(RHS);
12951 APInt MinValue = APInt::getMinValue(BitWidth);
12952 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12953
12954 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12955 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12956}
12957
12959 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12960 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12961 // expression fixes the case of N=0.
12962 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12963 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12964 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12965}
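// Illustrative standalone sketch (not part of this file, helper name made up):
// the identity used above, umin(N, 1) + (N - umin(N, 1)) /u D == ceil(N /u D),
// checked on plain unsigned values, including the N = 0 case the umin fixes.
constexpr unsigned udivCeilSketch(unsigned N, unsigned D) {
  unsigned MinNOne = N < 1u ? N : 1u;   // umin(N, 1)
  return MinNOne + (N - MinNOne) / D;   // umin(N, 1) + (N - umin(N, 1)) /u D
}
static_assert(udivCeilSketch(0u, 4u) == 0u, "N == 0 yields 0");
static_assert(udivCeilSketch(7u, 4u) == 2u, "ceil(7 / 4) == 2");
static_assert(udivCeilSketch(8u, 4u) == 2u, "ceil(8 / 4) == 2");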
12966
12967const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12968 const SCEV *Stride,
12969 const SCEV *End,
12970 unsigned BitWidth,
12971 bool IsSigned) {
12972 // The logic in this function assumes we can represent a positive stride.
12973 // If we can't, the backedge-taken count must be zero.
12974 if (IsSigned && BitWidth == 1)
12975 return getZero(Stride->getType());
12976
12977  // The code below has only been closely audited for negative strides in the
12978  // unsigned comparison case; it may be correct for signed comparison, but
12979  // that needs to be established.
12980 if (IsSigned && isKnownNegative(Stride))
12981 return getCouldNotCompute();
12982
12983 // Calculate the maximum backedge count based on the range of values
12984 // permitted by Start, End, and Stride.
12985 APInt MinStart =
12986 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12987
12988 APInt MinStride =
12989 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12990
12991 // We assume either the stride is positive, or the backedge-taken count
12992 // is zero. So force StrideForMaxBECount to be at least one.
12993 APInt One(BitWidth, 1);
12994 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12995 : APIntOps::umax(One, MinStride);
12996
12997 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12998 : APInt::getMaxValue(BitWidth);
12999 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13000
13001 // Although End can be a MAX expression we estimate MaxEnd considering only
13002 // the case End = RHS of the loop termination condition. This is safe because
13003 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13004 // taken count.
13005 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13006 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13007
13008 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13009 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13010 : APIntOps::umax(MaxEnd, MinStart);
13011
13012 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13013 getConstant(StrideForMaxBECount) /* Step */);
13014}
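// Illustrative standalone sketch (not part of this file): the final formula
// above with assumed ranges MinStart = 10, MaxEnd = 100 and a clamped stride
// StrideForMaxBECount = 3; the ceiling division yields ceil(90 / 3) = 30.
static_assert(((100u - 10u) + (3u - 1u)) / 3u == 30u,
              "MaxBECount = ceil((MaxEnd - MinStart) / StrideForMaxBECount)");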
13015
13017ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13018 const Loop *L, bool IsSigned,
13019 bool ControlsOnlyExit, bool AllowPredicates) {
13021
13023 bool PredicatedIV = false;
13024 if (!IV) {
13025 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13026 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13027 if (AR && AR->getLoop() == L && AR->isAffine()) {
13028 auto canProveNUW = [&]() {
13029 // We can use the comparison to infer no-wrap flags only if it fully
13030 // controls the loop exit.
13031 if (!ControlsOnlyExit)
13032 return false;
13033
13034 if (!isLoopInvariant(RHS, L))
13035 return false;
13036
13037 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13038 // We need the sequence defined by AR to strictly increase in the
13039 // unsigned integer domain for the logic below to hold.
13040 return false;
13041
13042 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13043 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13044 // If RHS <=u Limit, then there must exist a value V in the sequence
13045 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13046 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13047 // overflow occurs. This limit also implies that a signed comparison
13048 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13049 // the high bits on both sides must be zero.
13050 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13051 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13052 Limit = Limit.zext(OuterBitWidth);
13053 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13054 };
13055 auto Flags = AR->getNoWrapFlags();
13056 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13057 Flags = setFlags(Flags, SCEV::FlagNUW);
13058
13059 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13060 if (AR->hasNoUnsignedWrap()) {
13061 // Emulate what getZeroExtendExpr would have done during construction
13062 // if we'd been able to infer the fact just above at that time.
13063 const SCEV *Step = AR->getStepRecurrence(*this);
13064 Type *Ty = ZExt->getType();
13065 auto *S = getAddRecExpr(
13067 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13069 }
13070 }
13071 }
13072 }
13073
13074
13075 if (!IV && AllowPredicates) {
13076 // Try to make this an AddRec using runtime tests, in the first X
13077 // iterations of this loop, where X is the SCEV expression found by the
13078 // algorithm below.
13079 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13080 PredicatedIV = true;
13081 }
13082
13083 // Avoid weird loops
13084 if (!IV || IV->getLoop() != L || !IV->isAffine())
13085 return getCouldNotCompute();
13086
13087 // A precondition of this method is that the condition being analyzed
13088 // reaches an exiting branch which dominates the latch. Given that, we can
13089 // assume that an increment which violates the nowrap specification and
13090 // produces poison must cause undefined behavior when the resulting poison
13091 // value is branched upon and thus we can conclude that the backedge is
13092 // taken no more often than would be required to produce that poison value.
13093 // Note that a well defined loop can exit on the iteration which violates
13094 // the nowrap specification if there is another exit (either explicit or
13095 // implicit/exceptional) which causes the loop to execute before the
13096 // exiting instruction we're analyzing would trigger UB.
13097 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13098 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13100
13101 const SCEV *Stride = IV->getStepRecurrence(*this);
13102
13103 bool PositiveStride = isKnownPositive(Stride);
13104
13105 // Avoid negative or zero stride values.
13106 if (!PositiveStride) {
13107 // We can compute the correct backedge taken count for loops with unknown
13108 // strides if we can prove that the loop is not an infinite loop with side
13109 // effects. Here's the loop structure we are trying to handle -
13110 //
13111 // i = start
13112 // do {
13113 // A[i] = i;
13114 // i += s;
13115 // } while (i < end);
13116 //
13117 // The backedge taken count for such loops is evaluated as -
13118 // (max(end, start + stride) - start - 1) /u stride
13119 //
13120 // The additional preconditions that we need to check to prove correctness
13121    // The additional preconditions that we need to check to prove correctness
13122 //
13123 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13124 // NoWrap flag).
13125 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13126 // no side effects within the loop)
13127 // c) loop has a single static exit (with no abnormal exits)
13128 //
13129 // Precondition a) implies that if the stride is negative, this is a single
13130 // trip loop. The backedge taken count formula reduces to zero in this case.
13131 //
13132 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13133 // then a zero stride means the backedge can't be taken without executing
13134 // undefined behavior.
13135 //
13136 // The positive stride case is the same as isKnownPositive(Stride) returning
13137 // true (original behavior of the function).
13138 //
13139 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13141 return getCouldNotCompute();
13142
13143 if (!isKnownNonZero(Stride)) {
13144 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13145 // if it might eventually be greater than start and if so, on which
13146 // iteration. We can't even produce a useful upper bound.
13147 if (!isLoopInvariant(RHS, L))
13148 return getCouldNotCompute();
13149
13150 // We allow a potentially zero stride, but we need to divide by stride
13151 // below. Since the loop can't be infinite and this check must control
13152 // the sole exit, we can infer the exit must be taken on the first
13153      // iteration (i.e. backedge count = 0) if the stride is zero. Given that,
13154 // we know the numerator in the divides below must be zero, so we can
13155 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13156 // and produce the right result.
13157 // FIXME: Handle the case where Stride is poison?
13158 auto wouldZeroStrideBeUB = [&]() {
13159 // Proof by contradiction. Suppose the stride were zero. If we can
13160 // prove that the backedge *is* taken on the first iteration, then since
13161 // we know this condition controls the sole exit, we must have an
13162 // infinite loop. We can't have a (well defined) infinite loop per
13163 // check just above.
13164 // Note: The (Start - Stride) term is used to get the start' term from
13165 // (start' + stride,+,stride). Remember that we only care about the
13166 // result of this expression when stride == 0 at runtime.
13167 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13168 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13169 };
13170 if (!wouldZeroStrideBeUB()) {
13171 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13172 }
13173 }
13174 } else if (!NoWrap) {
13175 // Avoid proven overflow cases: this will ensure that the backedge taken
13176 // count will not generate any unsigned overflow.
13177 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13178 return getCouldNotCompute();
13179 }
13180
13181  // On all paths just preceding, we established the following invariant:
13182 // IV can be assumed not to overflow up to and including the exiting
13183 // iteration. We proved this in one of two ways:
13184 // 1) We can show overflow doesn't occur before the exiting iteration
13185 // 1a) canIVOverflowOnLT, and b) step of one
13186 // 2) We can show that if overflow occurs, the loop must execute UB
13187 // before any possible exit.
13188 // Note that we have not yet proved RHS invariant (in general).
13189
13190 const SCEV *Start = IV->getStart();
13191
13192 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13193 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13194 // Use integer-typed versions for actual computation; we can't subtract
13195 // pointers in general.
13196 const SCEV *OrigStart = Start;
13197 const SCEV *OrigRHS = RHS;
13198 if (Start->getType()->isPointerTy()) {
13200 if (isa<SCEVCouldNotCompute>(Start))
13201 return Start;
13202 }
13203 if (RHS->getType()->isPointerTy()) {
13206 return RHS;
13207 }
13208
13209 const SCEV *End = nullptr, *BECount = nullptr,
13210 *BECountIfBackedgeTaken = nullptr;
13211 if (!isLoopInvariant(RHS, L)) {
13212 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13213 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13214 RHSAddRec->getNoWrapFlags()) {
13215 // The structure of loop we are trying to calculate backedge count of:
13216 //
13217 // left = left_start
13218 // right = right_start
13219 //
13220 // while(left < right){
13221 // ... do something here ...
13222 // left += s1; // stride of left is s1 (s1 > 0)
13223 // right += s2; // stride of right is s2 (s2 < 0)
13224 // }
13225 //
13226
13227 const SCEV *RHSStart = RHSAddRec->getStart();
13228 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13229
13230 // If Stride - RHSStride is positive and does not overflow, we can write
13231 // backedge count as ->
13232 // ceil((End - Start) /u (Stride - RHSStride))
13233 // Where, End = max(RHSStart, Start)
13234
13235 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13236 if (isKnownNegative(RHSStride) &&
13237 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13238 RHSStride)) {
13239
13240 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13241 if (isKnownPositive(Denominator)) {
13242 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13243 : getUMaxExpr(RHSStart, Start);
13244
13245 // We can do this because End >= Start, as End = max(RHSStart, Start)
13246 const SCEV *Delta = getMinusSCEV(End, Start);
13247
13248 BECount = getUDivCeilSCEV(Delta, Denominator);
13249 BECountIfBackedgeTaken =
13250 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13251 }
13252 }
13253 }
13254 if (BECount == nullptr) {
13255 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13256 // given the start, stride and max value for the end bound of the
13257 // loop (RHS), and the fact that IV does not overflow (which is
13258 // checked above).
13259 const SCEV *MaxBECount = computeMaxBECountForLT(
13260 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13261 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13262 MaxBECount, false /*MaxOrZero*/, Predicates);
13263 }
13264 } else {
13265 // We use the expression (max(End,Start)-Start)/Stride to describe the
13266    // backedge count: if the backedge is taken at least once,
13267    // max(End,Start) is End and so the result is as above, and if not,
13268    // max(End,Start) is Start, so we get a backedge count of zero.
13269 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13270 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13271 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13272 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13273    // Can we prove max(RHS,Start) > Start - Stride?
13274 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13275 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13276 // In this case, we can use a refined formula for computing backedge
13277 // taken count. The general formula remains:
13278 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13279 // We want to use the alternate formula:
13280 // "((End - 1) - (Start - Stride)) /u Stride"
13281 // Let's do a quick case analysis to show these are equivalent under
13282 // our precondition that max(RHS,Start) > Start - Stride.
13283 // * For RHS <= Start, the backedge-taken count must be zero.
13284 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13285      //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13286      //   "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
13287      //   of Stride.  For 0 stride, we've used umin(1,Stride) above,
13288 // reducing this to the stride of 1 case.
13289 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13290 // Stride".
13291 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13292 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13293      //   "(RHS - (Start - Stride) - 1) /u Stride".
13294 // Our preconditions trivially imply no overflow in that form.
13295 const SCEV *MinusOne = getMinusOne(Stride->getType());
13296 const SCEV *Numerator =
13297 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13298 BECount = getUDivExpr(Numerator, Stride);
13299 }
13300
13301 if (!BECount) {
13302 auto canProveRHSGreaterThanEqualStart = [&]() {
13303 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13304 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13305 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13306
13307 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13308 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13309 return true;
13310
13311 // (RHS > Start - 1) implies RHS >= Start.
13312 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13313 // "Start - 1" doesn't overflow.
13314 // * For signed comparison, if Start - 1 does overflow, it's equal
13315 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13316 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13317 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13318 //
13319 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13320 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13321 auto *StartMinusOne =
13322 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13323 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13324 };
13325
13326 // If we know that RHS >= Start in the context of loop, then we know
13327 // that max(RHS, Start) = RHS at this point.
13328 if (canProveRHSGreaterThanEqualStart()) {
13329 End = RHS;
13330 } else {
13331 // If RHS < Start, the backedge will be taken zero times. So in
13332 // general, we can write the backedge-taken count as:
13333 //
13334 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13335 //
13336 // We convert it to the following to make it more convenient for SCEV:
13337 //
13338 // ceil(max(RHS, Start) - Start) / Stride
13339 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13340
13341 // See what would happen if we assume the backedge is taken. This is
13342 // used to compute MaxBECount.
13343 BECountIfBackedgeTaken =
13344 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13345 }
13346
13347 // At this point, we know:
13348 //
13349 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13350 // 2. The index variable doesn't overflow.
13351 //
13352 // Therefore, we know N exists such that
13353 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13354 // doesn't overflow.
13355 //
13356 // Using this information, try to prove whether the addition in
13357 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13358 const SCEV *One = getOne(Stride->getType());
13359 bool MayAddOverflow = [&] {
13360 if (isKnownToBeAPowerOfTwo(Stride)) {
13361 // Suppose Stride is a power of two, and Start/End are unsigned
13362 // integers. Let UMAX be the largest representable unsigned
13363 // integer.
13364 //
13365 // By the preconditions of this function, we know
13366 // "(Start + Stride * N) >= End", and this doesn't overflow.
13367 // As a formula:
13368 //
13369 // End <= (Start + Stride * N) <= UMAX
13370 //
13371 // Subtracting Start from all the terms:
13372 //
13373 // End - Start <= Stride * N <= UMAX - Start
13374 //
13375 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13376 //
13377 // End - Start <= Stride * N <= UMAX
13378 //
13379 // Stride * N is a multiple of Stride. Therefore,
13380 //
13381 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13382 //
13383 // Since Stride is a power of two, UMAX + 1 is divisible by
13384 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13385 // write:
13386 //
13387          //   End - Start <= Stride * N <= UMAX - Stride + 1
13388 //
13389 // Dropping the middle term:
13390 //
13391          //   End - Start <= UMAX - Stride + 1
13392 //
13393 // Adding Stride - 1 to both sides:
13394 //
13395 // (End - Start) + (Stride - 1) <= UMAX
13396 //
13397 // In other words, the addition doesn't have unsigned overflow.
13398 //
13399 // A similar proof works if we treat Start/End as signed values.
13400 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13401 // to use signed max instead of unsigned max. Note that we're
13402 // trying to prove a lack of unsigned overflow in either case.
13403 return false;
13404 }
13405 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13406 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13407 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13408 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13409 // 1 <s End.
13410 //
13411 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13412 // End.
13413 return false;
13414 }
13415 return true;
13416 }();
13417
13418 const SCEV *Delta = getMinusSCEV(End, Start);
13419 if (!MayAddOverflow) {
13420 // floor((D + (S - 1)) / S)
13421 // We prefer this formulation if it's legal because it's fewer
13422 // operations.
13423 BECount =
13424 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13425 } else {
13426 BECount = getUDivCeilSCEV(Delta, Stride);
13427 }
13428 }
13429 }
13430
13431 const SCEV *ConstantMaxBECount;
13432 bool MaxOrZero = false;
13433 if (isa<SCEVConstant>(BECount)) {
13434 ConstantMaxBECount = BECount;
13435 } else if (BECountIfBackedgeTaken &&
13436 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13437 // If we know exactly how many times the backedge will be taken if it's
13438 // taken at least once, then the backedge count will either be that or
13439 // zero.
13440 ConstantMaxBECount = BECountIfBackedgeTaken;
13441 MaxOrZero = true;
13442 } else {
13443 ConstantMaxBECount = computeMaxBECountForLT(
13444 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13445 }
13446
13447 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13448 !isa<SCEVCouldNotCompute>(BECount))
13449 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13450
13451 const SCEV *SymbolicMaxBECount =
13452 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13453 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13454 Predicates);
13455}
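// Illustrative standalone sketch (not part of this file): the refined formula
// "((End - 1) - (Start - Stride)) /u Stride" used on the loop-invariant-RHS
// path above, checked for Start = 5, Stride = 3 against ceil((End - Start) /
// Stride) when End = 17, and against zero when End == Start.
static_assert(((17u - 1u) - (5u - 3u)) / 3u == 4u, "matches ceil(12 / 3) == 4");
static_assert(((5u - 1u) - (5u - 3u)) / 3u == 0u, "End == Start gives 0");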
13456
13457ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13458 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13459 bool ControlsOnlyExit, bool AllowPredicates) {
13461 // We handle only IV > Invariant
13462 if (!isLoopInvariant(RHS, L))
13463 return getCouldNotCompute();
13464
13465 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13466 if (!IV && AllowPredicates)
13467 // Try to make this an AddRec using runtime tests, in the first X
13468 // iterations of this loop, where X is the SCEV expression found by the
13469 // algorithm below.
13470 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13471
13472 // Avoid weird loops
13473 if (!IV || IV->getLoop() != L || !IV->isAffine())
13474 return getCouldNotCompute();
13475
13476 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13477 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13479
13480 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13481
13482 // Avoid negative or zero stride values
13483 if (!isKnownPositive(Stride))
13484 return getCouldNotCompute();
13485
13486 // Avoid proven overflow cases: this will ensure that the backedge taken count
13487 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13488  // exploit NoWrapFlags, allowing optimization in the presence of undefined
13489  // behavior, as in the C language.
13490 if (!Stride->isOne() && !NoWrap)
13491 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13492 return getCouldNotCompute();
13493
13494 const SCEV *Start = IV->getStart();
13495 const SCEV *End = RHS;
13496 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13497 // If we know that Start >= RHS in the context of loop, then we know that
13498 // min(RHS, Start) = RHS at this point.
13500 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13501 End = RHS;
13502 else
13503 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13504 }
13505
13506 if (Start->getType()->isPointerTy()) {
13508 if (isa<SCEVCouldNotCompute>(Start))
13509 return Start;
13510 }
13511 if (End->getType()->isPointerTy()) {
13512 End = getLosslessPtrToIntExpr(End);
13513 if (isa<SCEVCouldNotCompute>(End))
13514 return End;
13515 }
13516
13517 // Compute ((Start - End) + (Stride - 1)) / Stride.
13518 // FIXME: This can overflow. Holding off on fixing this for now;
13519 // howManyGreaterThans will hopefully be gone soon.
13520 const SCEV *One = getOne(Stride->getType());
13521 const SCEV *BECount = getUDivExpr(
13522 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13523
13524 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13526
13527 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13528 : getUnsignedRangeMin(Stride);
13529
13530 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13531 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13532 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13533
13534 // Although End can be a MIN expression we estimate MinEnd considering only
13535 // the case End = RHS. This is safe because in the other case (Start - End)
13536 // is zero, leading to a zero maximum backedge taken count.
13537 APInt MinEnd =
13538 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13539 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13540
13541 const SCEV *ConstantMaxBECount =
13542 isa<SCEVConstant>(BECount)
13543 ? BECount
13544 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13545 getConstant(MinStride));
13546
13547 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13548 ConstantMaxBECount = BECount;
13549 const SCEV *SymbolicMaxBECount =
13550 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13551
13552 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13553 Predicates);
13554}
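// Illustrative standalone sketch (not part of this file): the count
// "((Start - End) + (Stride - 1)) /u Stride" computed above, for a
// hypothetical countdown with Start = 17, End = 5 and (negated) Stride = 3.
static_assert(((17u - 5u) + (3u - 1u)) / 3u == 4u,
              "ceil((Start - End) / Stride) == 4");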
13555
13557 ScalarEvolution &SE) const {
13558 if (Range.isFullSet()) // Infinite loop.
13559 return SE.getCouldNotCompute();
13560
13561 // If the start is a non-zero constant, shift the range to simplify things.
13562 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13563 if (!SC->getValue()->isZero()) {
13565 Operands[0] = SE.getZero(SC->getType());
13566 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13568 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13569 return ShiftedAddRec->getNumIterationsInRange(
13570 Range.subtract(SC->getAPInt()), SE);
13571 // This is strange and shouldn't happen.
13572 return SE.getCouldNotCompute();
13573 }
13574
13575 // The only time we can solve this is when we have all constant indices.
13576 // Otherwise, we cannot determine the overflow conditions.
13577 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13578 return SE.getCouldNotCompute();
13579
13580 // Okay at this point we know that all elements of the chrec are constants and
13581 // that the start element is zero.
13582
13583 // First check to see if the range contains zero. If not, the first
13584 // iteration exits.
13585 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13586 if (!Range.contains(APInt(BitWidth, 0)))
13587 return SE.getZero(getType());
13588
13589 if (isAffine()) {
13590 // If this is an affine expression then we have this situation:
13591 // Solve {0,+,A} in Range === Ax in Range
13592
13593 // We know that zero is in the range. If A is positive then we know that
13594 // the upper value of the range must be the first possible exit value.
13595 // If A is negative then the lower of the range is the last possible loop
13596 // value. Also note that we already checked for a full range.
13597 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13598 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13599
13600 // The exit value should be (End+A)/A.
13601 APInt ExitVal = (End + A).udiv(A);
13602 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13603
13604 // Evaluate at the exit value. If we really did fall out of the valid
13605 // range, then we computed our trip count, otherwise wrap around or other
13606 // things must have happened.
13607 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13608 if (Range.contains(Val->getValue()))
13609 return SE.getCouldNotCompute(); // Something strange happened
13610
13611 // Ensure that the previous value is in the range.
13612 assert(Range.contains(
13614 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13615 "Linear scev computation is off in a bad way!");
13616 return SE.getConstant(ExitValue);
13617 }
13618
13619 if (isQuadratic()) {
13620 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13621 return SE.getConstant(*S);
13622 }
13623
13624 return SE.getCouldNotCompute();
13625}
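// Illustrative standalone sketch (not part of this file): the affine case
// above for the chrec {0,+,7} and a hypothetical range [0, 100). End is the
// upper bound minus one (99), ExitVal = (End + A) /u A = 15, the value at
// iteration 15 leaves the range, and the value at iteration 14 is still in it.
static_assert((99u + 7u) / 7u == 15u, "ExitVal = (End + A) /u A");
static_assert(7u * 15u >= 100u, "value at ExitVal is outside [0, 100)");
static_assert(7u * 14u < 100u, "value at ExitVal - 1 is still in range");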
13626
13627const SCEVAddRecExpr *
13629 assert(getNumOperands() > 1 && "AddRec with zero step?");
13630 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13631 // but in this case we cannot guarantee that the value returned will be an
13632 // AddRec because SCEV does not have a fixed point where it stops
13633 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13634  // may happen if we reach the arithmetic depth limit while simplifying. So we
13635 // construct the returned value explicitly.
13637 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13638 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13639 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13640 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13641 // We know that the last operand is not a constant zero (otherwise it would
13642 // have been popped out earlier). This guarantees us that if the result has
13643 // the same last operand, then it will also not be popped out, meaning that
13644 // the returned value will be an AddRec.
13645 const SCEV *Last = getOperand(getNumOperands() - 1);
13646 assert(!Last->isZero() && "Recurrency with zero step?");
13647 Ops.push_back(Last);
13650}
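// Illustrative standalone sketch (not part of this file, helper name made up):
// for a chrec {C0,+,C1,+,C2}, whose value at iteration I is
// C0 + C1*I + C2*I*(I-1)/2, the post-increment chrec built above,
// {C0+C1,+,C1+C2,+,C2}, evaluated at I equals the original chrec at I + 1.
constexpr unsigned chrecAtSketch(unsigned C0, unsigned C1, unsigned C2,
                                 unsigned I) {
  return C0 + C1 * I + C2 * I * (I - 1) / 2;
}
static_assert(chrecAtSketch(0u + 1u, 1u + 1u, 1u, 3u) ==
                  chrecAtSketch(0u, 1u, 1u, 4u),
              "post-increment chrec at I equals original chrec at I + 1");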
13651
13652// Return true when S contains at least one undef value.
13654 return SCEVExprContains(S, [](const SCEV *S) {
13655 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13656 return isa<UndefValue>(SU->getValue());
13657 return false;
13658 });
13659}
13660
13661// Return true when S contains a value that is a nullptr.
13663 return SCEVExprContains(S, [](const SCEV *S) {
13664 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13665 return SU->getValue() == nullptr;
13666 return false;
13667 });
13668}
13669
13670/// Return the size of an element read or written by Inst.
13672 Type *Ty;
13673 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13674 Ty = Store->getValueOperand()->getType();
13675 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13676 Ty = Load->getType();
13677 else
13678 return nullptr;
13679
13681 return getSizeOfExpr(ETy, Ty);
13682}
13683
13684//===----------------------------------------------------------------------===//
13685// SCEVCallbackVH Class Implementation
13686//===----------------------------------------------------------------------===//
13687
13689 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13690 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13691 SE->ConstantEvolutionLoopExitValue.erase(PN);
13692 SE->eraseValueFromMap(getValPtr());
13693 // this now dangles!
13694}
13695
13696void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13697 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13698
13699 // Forget all the expressions associated with users of the old value,
13700 // so that future queries will recompute the expressions using the new
13701 // value.
13702 SE->forgetValue(getValPtr());
13703 // this now dangles!
13704}
13705
13706ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13707 : CallbackVH(V), SE(se) {}
13708
13709//===----------------------------------------------------------------------===//
13710// ScalarEvolution Class Implementation
13711//===----------------------------------------------------------------------===//
13712
13715 LoopInfo &LI)
13716 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13717 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13718 LoopDispositions(64), BlockDispositions(64) {
13719 // To use guards for proving predicates, we need to scan every instruction in
13720 // relevant basic blocks, and not just terminators. Doing this is a waste of
13721 // time if the IR does not actually contain any calls to
13722 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13723 //
13724 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13725 // to _add_ guards to the module when there weren't any before, and wants
13726 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13727 // efficient in lieu of being smart in that rather obscure case.
13728
13729 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13730 F.getParent(), Intrinsic::experimental_guard);
13731 HasGuards = GuardDecl && !GuardDecl->use_empty();
13732}
13733
13735 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13736 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13737 ValueExprMap(std::move(Arg.ValueExprMap)),
13738 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13739 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13740 PendingMerges(std::move(Arg.PendingMerges)),
13741 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13742 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13743 PredicatedBackedgeTakenCounts(
13744 std::move(Arg.PredicatedBackedgeTakenCounts)),
13745 BECountUsers(std::move(Arg.BECountUsers)),
13746 ConstantEvolutionLoopExitValue(
13747 std::move(Arg.ConstantEvolutionLoopExitValue)),
13748 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13749 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13750 LoopDispositions(std::move(Arg.LoopDispositions)),
13751 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13752 BlockDispositions(std::move(Arg.BlockDispositions)),
13753 SCEVUsers(std::move(Arg.SCEVUsers)),
13754 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13755 SignedRanges(std::move(Arg.SignedRanges)),
13756 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13757 UniquePreds(std::move(Arg.UniquePreds)),
13758 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13759 LoopUsers(std::move(Arg.LoopUsers)),
13760 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13761 FirstUnknown(Arg.FirstUnknown) {
13762 Arg.FirstUnknown = nullptr;
13763}
13764
13766 // Iterate through all the SCEVUnknown instances and call their
13767 // destructors, so that they release their references to their values.
13768 for (SCEVUnknown *U = FirstUnknown; U;) {
13769 SCEVUnknown *Tmp = U;
13770 U = U->Next;
13771 Tmp->~SCEVUnknown();
13772 }
13773 FirstUnknown = nullptr;
13774
13775 ExprValueMap.clear();
13776 ValueExprMap.clear();
13777 HasRecMap.clear();
13778 BackedgeTakenCounts.clear();
13779 PredicatedBackedgeTakenCounts.clear();
13780
13781 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13782 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13783 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13784 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13785 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13786}
13787
13791
13792/// When printing a top-level SCEV for trip counts, it's helpful to include
13793/// a type for constants which are otherwise hard to disambiguate.
13794static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13795 if (isa<SCEVConstant>(S))
13796 OS << *S->getType() << " ";
13797 OS << *S;
13798}
13799
13801 const Loop *L) {
13802 // Print all inner loops first
13803 for (Loop *I : *L)
13804 PrintLoopInfo(OS, SE, I);
13805
13806 OS << "Loop ";
13807 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13808 OS << ": ";
13809
13810 SmallVector<BasicBlock *, 8> ExitingBlocks;
13811 L->getExitingBlocks(ExitingBlocks);
13812 if (ExitingBlocks.size() != 1)
13813 OS << "<multiple exits> ";
13814
13815 auto *BTC = SE->getBackedgeTakenCount(L);
13816 if (!isa<SCEVCouldNotCompute>(BTC)) {
13817 OS << "backedge-taken count is ";
13818 PrintSCEVWithTypeHint(OS, BTC);
13819 } else
13820 OS << "Unpredictable backedge-taken count.";
13821 OS << "\n";
13822
13823 if (ExitingBlocks.size() > 1)
13824 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13825 OS << " exit count for " << ExitingBlock->getName() << ": ";
13826 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13827 PrintSCEVWithTypeHint(OS, EC);
13828 if (isa<SCEVCouldNotCompute>(EC)) {
13829 // Retry with predicates.
13831 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13832 if (!isa<SCEVCouldNotCompute>(EC)) {
13833 OS << "\n predicated exit count for " << ExitingBlock->getName()
13834 << ": ";
13835 PrintSCEVWithTypeHint(OS, EC);
13836 OS << "\n Predicates:\n";
13837 for (const auto *P : Predicates)
13838 P->print(OS, 4);
13839 }
13840 }
13841 OS << "\n";
13842 }
13843
13844 OS << "Loop ";
13845 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13846 OS << ": ";
13847
13848 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13849 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13850 OS << "constant max backedge-taken count is ";
13851 PrintSCEVWithTypeHint(OS, ConstantBTC);
13853 OS << ", actual taken count either this or zero.";
13854 } else {
13855 OS << "Unpredictable constant max backedge-taken count. ";
13856 }
13857
13858 OS << "\n"
13859 "Loop ";
13860 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13861 OS << ": ";
13862
13863 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13864 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13865 OS << "symbolic max backedge-taken count is ";
13866 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13868 OS << ", actual taken count either this or zero.";
13869 } else {
13870 OS << "Unpredictable symbolic max backedge-taken count. ";
13871 }
13872 OS << "\n";
13873
13874 if (ExitingBlocks.size() > 1)
13875 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13876 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13877 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13879 PrintSCEVWithTypeHint(OS, ExitBTC);
13880 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13881 // Retry with predicates.
13883 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13885 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13886 OS << "\n predicated symbolic max exit count for "
13887 << ExitingBlock->getName() << ": ";
13888 PrintSCEVWithTypeHint(OS, ExitBTC);
13889 OS << "\n Predicates:\n";
13890 for (const auto *P : Predicates)
13891 P->print(OS, 4);
13892 }
13893 }
13894 OS << "\n";
13895 }
13896
13898 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13899 if (PBT != BTC) {
13900 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13901 OS << "Loop ";
13902 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13903 OS << ": ";
13904 if (!isa<SCEVCouldNotCompute>(PBT)) {
13905 OS << "Predicated backedge-taken count is ";
13906 PrintSCEVWithTypeHint(OS, PBT);
13907 } else
13908 OS << "Unpredictable predicated backedge-taken count.";
13909 OS << "\n";
13910 OS << " Predicates:\n";
13911 for (const auto *P : Preds)
13912 P->print(OS, 4);
13913 }
13914 Preds.clear();
13915
13916 auto *PredConstantMax =
13918 if (PredConstantMax != ConstantBTC) {
13919 assert(!Preds.empty() &&
13920 "different predicated constant max BTC but no predicates");
13921 OS << "Loop ";
13922 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13923 OS << ": ";
13924 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13925 OS << "Predicated constant max backedge-taken count is ";
13926 PrintSCEVWithTypeHint(OS, PredConstantMax);
13927 } else
13928 OS << "Unpredictable predicated constant max backedge-taken count.";
13929 OS << "\n";
13930 OS << " Predicates:\n";
13931 for (const auto *P : Preds)
13932 P->print(OS, 4);
13933 }
13934 Preds.clear();
13935
13936 auto *PredSymbolicMax =
13938 if (SymbolicBTC != PredSymbolicMax) {
13939 assert(!Preds.empty() &&
13940 "Different predicated symbolic max BTC, but no predicates");
13941 OS << "Loop ";
13942 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13943 OS << ": ";
13944 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13945 OS << "Predicated symbolic max backedge-taken count is ";
13946 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13947 } else
13948 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13949 OS << "\n";
13950 OS << " Predicates:\n";
13951 for (const auto *P : Preds)
13952 P->print(OS, 4);
13953 }
13954
13956 OS << "Loop ";
13957 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13958 OS << ": ";
13959 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13960 }
13961}
13962
13963namespace llvm {
13965 switch (LD) {
13967 OS << "Variant";
13968 break;
13970 OS << "Invariant";
13971 break;
13973 OS << "Computable";
13974 break;
13975 }
13976 return OS;
13977}
13978
13980 switch (BD) {
13982 OS << "DoesNotDominate";
13983 break;
13985 OS << "Dominates";
13986 break;
13988 OS << "ProperlyDominates";
13989 break;
13990 }
13991 return OS;
13992}
13993} // namespace llvm
13994
13995void ScalarEvolution::print(raw_ostream &OS) const {
13996 // ScalarEvolution's implementation of the print method is to print
13997 // out SCEV values of all instructions that are interesting. Doing
13998 // this potentially causes it to create new SCEV objects though,
13999 // which technically conflicts with the const qualifier. This isn't
14000 // observable from outside the class though, so casting away the
14001 // const isn't dangerous.
14002 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14003
14004 if (ClassifyExpressions) {
14005 OS << "Classifying expressions for: ";
14006 F.printAsOperand(OS, /*PrintType=*/false);
14007 OS << "\n";
14008 for (Instruction &I : instructions(F))
14009 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14010 OS << I << '\n';
14011 OS << " --> ";
14012 const SCEV *SV = SE.getSCEV(&I);
14013 SV->print(OS);
14014 if (!isa<SCEVCouldNotCompute>(SV)) {
14015 OS << " U: ";
14016 SE.getUnsignedRange(SV).print(OS);
14017 OS << " S: ";
14018 SE.getSignedRange(SV).print(OS);
14019 }
14020
14021 const Loop *L = LI.getLoopFor(I.getParent());
14022
14023 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14024 if (AtUse != SV) {
14025 OS << " --> ";
14026 AtUse->print(OS);
14027 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14028 OS << " U: ";
14029 SE.getUnsignedRange(AtUse).print(OS);
14030 OS << " S: ";
14031 SE.getSignedRange(AtUse).print(OS);
14032 }
14033 }
14034
14035 if (L) {
14036 OS << "\t\t" "Exits: ";
14037 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14038 if (!SE.isLoopInvariant(ExitValue, L)) {
14039 OS << "<<Unknown>>";
14040 } else {
14041 OS << *ExitValue;
14042 }
14043
14044 bool First = true;
14045 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14046 if (First) {
14047 OS << "\t\t" "LoopDispositions: { ";
14048 First = false;
14049 } else {
14050 OS << ", ";
14051 }
14052
14053 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14054 OS << ": " << SE.getLoopDisposition(SV, Iter);
14055 }
14056
14057 for (const auto *InnerL : depth_first(L)) {
14058 if (InnerL == L)
14059 continue;
14060 if (First) {
14061 OS << "\t\t" "LoopDispositions: { ";
14062 First = false;
14063 } else {
14064 OS << ", ";
14065 }
14066
14067 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14068 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14069 }
14070
14071 OS << " }";
14072 }
14073
14074 OS << "\n";
14075 }
14076 }
14077
14078 OS << "Determining loop execution counts for: ";
14079 F.printAsOperand(OS, /*PrintType=*/false);
14080 OS << "\n";
14081 for (Loop *I : LI)
14082 PrintLoopInfo(OS, &SE, I);
14083}
14084
14085ScalarEvolution::LoopDisposition
14086ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14087 auto &Values = LoopDispositions[S];
14088 for (auto &V : Values) {
14089 if (V.getPointer() == L)
14090 return V.getInt();
14091 }
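 // Note: a conservative LoopVariant placeholder is recorded first and the
 // entry is looked up again after the recursive computation, because
 // computeLoopDisposition may itself add entries to LoopDispositions and
 // invalidate references into the cached vector.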
14092 Values.emplace_back(L, LoopVariant);
14093 LoopDisposition D = computeLoopDisposition(S, L);
14094 auto &Values2 = LoopDispositions[S];
14095 for (auto &V : llvm::reverse(Values2)) {
14096 if (V.getPointer() == L) {
14097 V.setInt(D);
14098 break;
14099 }
14100 }
14101 return D;
14102}
14103
14104ScalarEvolution::LoopDisposition
14105ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14106 switch (S->getSCEVType()) {
14107 case scConstant:
14108 case scVScale:
14109 return LoopInvariant;
14110 case scAddRecExpr: {
14111 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14112
14113 // If L is the addrec's loop, it's computable.
14114 if (AR->getLoop() == L)
14115 return LoopComputable;
14116
14117 // Add recurrences are never invariant in the function-body (null loop).
14118 if (!L)
14119 return LoopVariant;
14120
14121 // Everything that is not defined at loop entry is variant.
14122 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14123 return LoopVariant;
14124 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14125 " dominate the contained loop's header?");
14126
14127 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14128 if (AR->getLoop()->contains(L))
14129 return LoopInvariant;
14130
14131 // This recurrence is variant w.r.t. L if any of its operands
14132 // are variant.
14133 for (const auto *Op : AR->operands())
14134 if (!isLoopInvariant(Op, L))
14135 return LoopVariant;
14136
14137 // Otherwise it's loop-invariant.
14138 return LoopInvariant;
14139 }
14140 case scTruncate:
14141 case scZeroExtend:
14142 case scSignExtend:
14143 case scPtrToInt:
14144 case scAddExpr:
14145 case scMulExpr:
14146 case scUDivExpr:
14147 case scUMaxExpr:
14148 case scSMaxExpr:
14149 case scUMinExpr:
14150 case scSMinExpr:
14151 case scSequentialUMinExpr: {
14152 bool HasVarying = false;
14153 for (const auto *Op : S->operands()) {
14154 LoopDisposition D = getLoopDisposition(Op, L);
14155 if (D == LoopVariant)
14156 return LoopVariant;
14157 if (D == LoopComputable)
14158 HasVarying = true;
14159 }
14160 return HasVarying ? LoopComputable : LoopInvariant;
14161 }
14162 case scUnknown:
14163 // All non-instruction values are loop invariant. All instructions are loop
14164 // invariant if they are not contained in the specified loop.
14165 // Instructions are never considered invariant in the function body
14166 // (null loop) because they are defined within the "loop".
14167 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14168 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14169 return LoopInvariant;
14170 case scCouldNotCompute:
14171 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14172 }
14173 llvm_unreachable("Unknown SCEV kind!");
14174}
14175
14176bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14177 return getLoopDisposition(S, L) == LoopInvariant;
14178}
14179
14180bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14181 return getLoopDisposition(S, L) == LoopComputable;
14182}
14183
14184ScalarEvolution::BlockDisposition
14185ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14186 auto &Values = BlockDispositions[S];
14187 for (auto &V : Values) {
14188 if (V.getPointer() == BB)
14189 return V.getInt();
14190 }
14191 Values.emplace_back(BB, DoesNotDominateBlock);
14192 BlockDisposition D = computeBlockDisposition(S, BB);
14193 auto &Values2 = BlockDispositions[S];
14194 for (auto &V : llvm::reverse(Values2)) {
14195 if (V.getPointer() == BB) {
14196 V.setInt(D);
14197 break;
14198 }
14199 }
14200 return D;
14201}
14202
14203ScalarEvolution::BlockDisposition
14204ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14205 switch (S->getSCEVType()) {
14206 case scConstant:
14207 case scVScale:
14208 return ProperlyDominatesBlock;
14209 case scAddRecExpr: {
14210 // This uses a "dominates" query instead of a "properly dominates" query
14211 // to test for proper dominance too, because the instruction which
14212 // produces the addrec's value is a PHI, and a PHI effectively properly
14213 // dominates its entire containing block.
14214 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14215 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14216 return DoesNotDominateBlock;
14217
14218 // Fall through into SCEVNAryExpr handling.
14219 [[fallthrough]];
14220 }
14221 case scTruncate:
14222 case scZeroExtend:
14223 case scSignExtend:
14224 case scPtrToInt:
14225 case scAddExpr:
14226 case scMulExpr:
14227 case scUDivExpr:
14228 case scUMaxExpr:
14229 case scSMaxExpr:
14230 case scUMinExpr:
14231 case scSMinExpr:
14232 case scSequentialUMinExpr: {
14233 bool Proper = true;
14234 for (const SCEV *NAryOp : S->operands()) {
14235 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14236 if (D == DoesNotDominateBlock)
14237 return DoesNotDominateBlock;
14238 if (D == DominatesBlock)
14239 Proper = false;
14240 }
14241 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14242 }
14243 case scUnknown:
14244 if (Instruction *I =
14245 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14246 if (I->getParent() == BB)
14247 return DominatesBlock;
14248 if (DT.properlyDominates(I->getParent(), BB))
14249 return ProperlyDominatesBlock;
14250 return DoesNotDominateBlock;
14251 }
14252 return ProperlyDominatesBlock;
14253 case scCouldNotCompute:
14254 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14255 }
14256 llvm_unreachable("Unknown SCEV kind!");
14257}
14258
14259bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14260 return getBlockDisposition(S, BB) >= DominatesBlock;
14261}
14262
14263bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14264 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14265}
14266
14267bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14268 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14269}
14270
14271void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14272 bool Predicated) {
14273 auto &BECounts =
14274 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14275 auto It = BECounts.find(L);
14276 if (It != BECounts.end()) {
14277 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14278 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14279 if (!isa<SCEVConstant>(S)) {
14280 auto UserIt = BECountUsers.find(S);
14281 assert(UserIt != BECountUsers.end());
14282 UserIt->second.erase({L, Predicated});
14283 }
14284 }
14285 }
14286 BECounts.erase(It);
14287 }
14288}
14289
14290void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14291 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14292 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14293
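 // Transitively collect the users of every expression being forgotten: any
 // SCEV that was built on top of a forgotten expression must be dropped too.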
14294 while (!Worklist.empty()) {
14295 const SCEV *Curr = Worklist.pop_back_val();
14296 auto Users = SCEVUsers.find(Curr);
14297 if (Users != SCEVUsers.end())
14298 for (const auto *User : Users->second)
14299 if (ToForget.insert(User).second)
14300 Worklist.push_back(User);
14301 }
14302
14303 for (const auto *S : ToForget)
14304 forgetMemoizedResultsImpl(S);
14305
14306 for (auto I = PredicatedSCEVRewrites.begin();
14307 I != PredicatedSCEVRewrites.end();) {
14308 std::pair<const SCEV *, const Loop *> Entry = I->first;
14309 if (ToForget.count(Entry.first))
14310 PredicatedSCEVRewrites.erase(I++);
14311 else
14312 ++I;
14313 }
14314}
14315
14316void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14317 LoopDispositions.erase(S);
14318 BlockDispositions.erase(S);
14319 UnsignedRanges.erase(S);
14320 SignedRanges.erase(S);
14321 HasRecMap.erase(S);
14322 ConstantMultipleCache.erase(S);
14323
14324 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14325 UnsignedWrapViaInductionTried.erase(AR);
14326 SignedWrapViaInductionTried.erase(AR);
14327 }
14328
14329 auto ExprIt = ExprValueMap.find(S);
14330 if (ExprIt != ExprValueMap.end()) {
14331 for (Value *V : ExprIt->second) {
14332 auto ValueIt = ValueExprMap.find_as(V);
14333 if (ValueIt != ValueExprMap.end())
14334 ValueExprMap.erase(ValueIt);
14335 }
14336 ExprValueMap.erase(ExprIt);
14337 }
14338
14339 auto ScopeIt = ValuesAtScopes.find(S);
14340 if (ScopeIt != ValuesAtScopes.end()) {
14341 for (const auto &Pair : ScopeIt->second)
14342 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14343 llvm::erase(ValuesAtScopesUsers[Pair.second],
14344 std::make_pair(Pair.first, S));
14345 ValuesAtScopes.erase(ScopeIt);
14346 }
14347
14348 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14349 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14350 for (const auto &Pair : ScopeUserIt->second)
14351 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14352 ValuesAtScopesUsers.erase(ScopeUserIt);
14353 }
14354
14355 auto BEUsersIt = BECountUsers.find(S);
14356 if (BEUsersIt != BECountUsers.end()) {
14357 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14358 auto Copy = BEUsersIt->second;
14359 for (const auto &Pair : Copy)
14360 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14361 BECountUsers.erase(BEUsersIt);
14362 }
14363
14364 auto FoldUser = FoldCacheUser.find(S);
14365 if (FoldUser != FoldCacheUser.end())
14366 for (auto &KV : FoldUser->second)
14367 FoldCache.erase(KV);
14368 FoldCacheUser.erase(S);
14369}
14370
14371void
14372ScalarEvolution::getUsedLoops(const SCEV *S,
14373 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14374 struct FindUsedLoops {
14375 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14376 : LoopsUsed(LoopsUsed) {}
14377 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14378 bool follow(const SCEV *S) {
14379 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14380 LoopsUsed.insert(AR->getLoop());
14381 return true;
14382 }
14383
14384 bool isDone() const { return false; }
14385 };
14386
14387 FindUsedLoops F(LoopsUsed);
14388 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14389}
14390
14391void ScalarEvolution::getReachableBlocks(
14392 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14393 SmallVector<BasicBlock *> Worklist;
14394 Worklist.push_back(&F.getEntryBlock());
14395 while (!Worklist.empty()) {
14396 BasicBlock *BB = Worklist.pop_back_val();
14397 if (!Reachable.insert(BB).second)
14398 continue;
14399
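 // For conditional branches whose condition can be decided here (a constant,
 // or an integer compare known via constant ranges), only the taken successor
 // is followed; otherwise all successors are enqueued below.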
14400 Value *Cond;
14401 BasicBlock *TrueBB, *FalseBB;
14402 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14403 m_BasicBlock(FalseBB)))) {
14404 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14405 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14406 continue;
14407 }
14408
14409 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14410 const SCEV *L = getSCEV(Cmp->getOperand(0));
14411 const SCEV *R = getSCEV(Cmp->getOperand(1));
14412 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14413 Worklist.push_back(TrueBB);
14414 continue;
14415 }
14416 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14417 R)) {
14418 Worklist.push_back(FalseBB);
14419 continue;
14420 }
14421 }
14422 }
14423
14424 append_range(Worklist, successors(BB));
14425 }
14426}
14427
14428void ScalarEvolution::verify() const {
14429 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14430 ScalarEvolution SE2(F, TLI, AC, DT, LI);
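 // SE2 is a freshly computed ScalarEvolution for the same function; cached
 // results in this instance are mapped into SE2's universe and compared
 // against the recomputed ones below.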
14431
14432 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14433
14434 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14435 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14436 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14437
14438 const SCEV *visitConstant(const SCEVConstant *Constant) {
14439 return SE.getConstant(Constant->getAPInt());
14440 }
14441
14442 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14443 return SE.getUnknown(Expr->getValue());
14444 }
14445
14446 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14447 return SE.getCouldNotCompute();
14448 }
14449 };
14450
14451 SCEVMapper SCM(SE2);
14452 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14453 SE2.getReachableBlocks(ReachableBlocks, F);
14454
14455 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14456 if (containsUndefs(Old) || containsUndefs(New)) {
14457 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14458 // not propagate undef aggressively). This means we can (and do) fail
14459 // verification in cases where a transform makes a value go from "undef"
14460 // to "undef+1" (say). The transform is fine, since in both cases the
14461 // result is "undef", but SCEV thinks the value increased by 1.
14462 return nullptr;
14463 }
14464
14465 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14466 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14467 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14468 return nullptr;
14469
14470 return Delta;
14471 };
14472
14473 while (!LoopStack.empty()) {
14474 auto *L = LoopStack.pop_back_val();
14475 llvm::append_range(LoopStack, *L);
14476
14477 // Only verify BECounts in reachable loops. For an unreachable loop,
14478 // any BECount is legal.
14479 if (!ReachableBlocks.contains(L->getHeader()))
14480 continue;
14481
14482 // Only verify cached BECounts. Computing new BECounts may change the
14483 // results of subsequent SCEV uses.
14484 auto It = BackedgeTakenCounts.find(L);
14485 if (It == BackedgeTakenCounts.end())
14486 continue;
14487
14488 auto *CurBECount =
14489 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14490 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14491
14492 if (CurBECount == SE2.getCouldNotCompute() ||
14493 NewBECount == SE2.getCouldNotCompute()) {
14494 // NB! This situation is legal, but is very suspicious -- whatever pass
14495 // changed the loop to make a trip count go from could not compute to
14496 // computable or vice-versa *should have* invalidated SCEV. However, we
14497 // choose not to assert here (for now) since we don't want false
14498 // positives.
14499 continue;
14500 }
14501
14502 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14503 SE.getTypeSizeInBits(NewBECount->getType()))
14504 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14505 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14506 SE.getTypeSizeInBits(NewBECount->getType()))
14507 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14508
14509 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14510 if (Delta && !Delta->isZero()) {
14511 dbgs() << "Trip Count for " << *L << " Changed!\n";
14512 dbgs() << "Old: " << *CurBECount << "\n";
14513 dbgs() << "New: " << *NewBECount << "\n";
14514 dbgs() << "Delta: " << *Delta << "\n";
14515 std::abort();
14516 }
14517 }
14518
14519 // Collect all valid loops currently in LoopInfo.
14520 SmallPtrSet<Loop *, 32> ValidLoops;
14521 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14522 while (!Worklist.empty()) {
14523 Loop *L = Worklist.pop_back_val();
14524 if (ValidLoops.insert(L).second)
14525 Worklist.append(L->begin(), L->end());
14526 }
14527 for (const auto &KV : ValueExprMap) {
14528#ifndef NDEBUG
14529 // Check for SCEV expressions referencing invalid/deleted loops.
14530 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14531 assert(ValidLoops.contains(AR->getLoop()) &&
14532 "AddRec references invalid loop");
14533 }
14534#endif
14535
14536 // Check that the value is also part of the reverse map.
14537 auto It = ExprValueMap.find(KV.second);
14538 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14539 dbgs() << "Value " << *KV.first
14540 << " is in ValueExprMap but not in ExprValueMap\n";
14541 std::abort();
14542 }
14543
14544 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14545 if (!ReachableBlocks.contains(I->getParent()))
14546 continue;
14547 const SCEV *OldSCEV = SCM.visit(KV.second);
14548 const SCEV *NewSCEV = SE2.getSCEV(I);
14549 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14550 if (Delta && !Delta->isZero()) {
14551 dbgs() << "SCEV for value " << *I << " changed!\n"
14552 << "Old: " << *OldSCEV << "\n"
14553 << "New: " << *NewSCEV << "\n"
14554 << "Delta: " << *Delta << "\n";
14555 std::abort();
14556 }
14557 }
14558 }
14559
14560 for (const auto &KV : ExprValueMap) {
14561 for (Value *V : KV.second) {
14562 const SCEV *S = ValueExprMap.lookup(V);
14563 if (!S) {
14564 dbgs() << "Value " << *V
14565 << " is in ExprValueMap but not in ValueExprMap\n";
14566 std::abort();
14567 }
14568 if (S != KV.first) {
14569 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14570 << *KV.first << "\n";
14571 std::abort();
14572 }
14573 }
14574 }
14575
14576 // Verify integrity of SCEV users.
14577 for (const auto &S : UniqueSCEVs) {
14578 for (const auto *Op : S.operands()) {
14579 // We do not store dependencies of constants.
14580 if (isa<SCEVConstant>(Op))
14581 continue;
14582 auto It = SCEVUsers.find(Op);
14583 if (It != SCEVUsers.end() && It->second.count(&S))
14584 continue;
14585 dbgs() << "Use of operand " << *Op << " by user " << S
14586 << " is not being tracked!\n";
14587 std::abort();
14588 }
14589 }
14590
14591 // Verify integrity of ValuesAtScopes users.
14592 for (const auto &ValueAndVec : ValuesAtScopes) {
14593 const SCEV *Value = ValueAndVec.first;
14594 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14595 const Loop *L = LoopAndValueAtScope.first;
14596 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14597 if (!isa<SCEVConstant>(ValueAtScope)) {
14598 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14599 if (It != ValuesAtScopesUsers.end() &&
14600 is_contained(It->second, std::make_pair(L, Value)))
14601 continue;
14602 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14603 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14604 std::abort();
14605 }
14606 }
14607 }
14608
14609 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14610 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14611 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14612 const Loop *L = LoopAndValue.first;
14613 const SCEV *Value = LoopAndValue.second;
14615 auto It = ValuesAtScopes.find(Value);
14616 if (It != ValuesAtScopes.end() &&
14617 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14618 continue;
14619 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14620 << *ValueAtScope << " missing in ValuesAtScopes\n";
14621 std::abort();
14622 }
14623 }
14624
14625 // Verify integrity of BECountUsers.
14626 auto VerifyBECountUsers = [&](bool Predicated) {
14627 auto &BECounts =
14628 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14629 for (const auto &LoopAndBEInfo : BECounts) {
14630 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14631 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14632 if (!isa<SCEVConstant>(S)) {
14633 auto UserIt = BECountUsers.find(S);
14634 if (UserIt != BECountUsers.end() &&
14635 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14636 continue;
14637 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14638 << " missing from BECountUsers\n";
14639 std::abort();
14640 }
14641 }
14642 }
14643 }
14644 };
14645 VerifyBECountUsers(/* Predicated */ false);
14646 VerifyBECountUsers(/* Predicated */ true);
14647
14648 // Verify integrity of the loop disposition cache.
14649 for (auto &[S, Values] : LoopDispositions) {
14650 for (auto [Loop, CachedDisposition] : Values) {
14651 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14652 if (CachedDisposition != RecomputedDisposition) {
14653 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14654 << " is incorrect: cached " << CachedDisposition << ", actual "
14655 << RecomputedDisposition << "\n";
14656 std::abort();
14657 }
14658 }
14659 }
14660
14661 // Verify integrity of the block disposition cache.
14662 for (auto &[S, Values] : BlockDispositions) {
14663 for (auto [BB, CachedDisposition] : Values) {
14664 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14665 if (CachedDisposition != RecomputedDisposition) {
14666 dbgs() << "Cached disposition of " << *S << " for block %"
14667 << BB->getName() << " is incorrect: cached " << CachedDisposition
14668 << ", actual " << RecomputedDisposition << "\n";
14669 std::abort();
14670 }
14671 }
14672 }
14673
14674 // Verify FoldCache/FoldCacheUser caches.
14675 for (auto [FoldID, Expr] : FoldCache) {
14676 auto I = FoldCacheUser.find(Expr);
14677 if (I == FoldCacheUser.end()) {
14678 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14679 << "!\n";
14680 std::abort();
14681 }
14682 if (!is_contained(I->second, FoldID)) {
14683 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14684 std::abort();
14685 }
14686 }
14687 for (auto [Expr, IDs] : FoldCacheUser) {
14688 for (auto &FoldID : IDs) {
14689 const SCEV *S = FoldCache.lookup(FoldID);
14690 if (!S) {
14691 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14692 << "!\n";
14693 std::abort();
14694 }
14695 if (S != Expr) {
14696 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14697 << " != " << *Expr << "!\n";
14698 std::abort();
14699 }
14700 }
14701 }
14702
14703 // Verify that ConstantMultipleCache computations are correct. We check that
14704 // cached multiples and recomputed multiples are multiples of each other to
14705 // verify correctness. It is possible that a recomputed multiple is different
14706 // from the cached multiple due to strengthened no wrap flags or changes in
14707 // KnownBits computations.
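 // For example, a cached multiple of 4 with a recomputed multiple of 8 (or
 // vice versa) is accepted, while 4 vs. 6 would be flagged as incorrect.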
14708 for (auto [S, Multiple] : ConstantMultipleCache) {
14709 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14710 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14711 Multiple.urem(RecomputedMultiple) != 0 &&
14712 RecomputedMultiple.urem(Multiple) != 0)) {
14713 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14714 << *S << " : Computed " << RecomputedMultiple
14715 << " but cache contains " << Multiple << "!\n";
14716 std::abort();
14717 }
14718 }
14719}
14720
14721bool ScalarEvolution::invalidate(
14722 Function &F, const PreservedAnalyses &PA,
14723 FunctionAnalysisManager::Invalidator &Inv) {
14724 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14725 // of its dependencies is invalidated.
14726 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14727 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14728 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14729 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14730 Inv.invalidate<LoopAnalysis>(F, PA);
14731}
14732
14733AnalysisKey ScalarEvolutionAnalysis::Key;
14734
14735ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14736 FunctionAnalysisManager &AM) {
14737 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14738 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14739 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14740 auto &LI = AM.getResult<LoopAnalysis>(F);
14741 return ScalarEvolution(F, TLI, AC, DT, LI);
14742}
14743
14744PreservedAnalyses
14745ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14746 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14747 return PreservedAnalyses::all();
14748}
14749
14750PreservedAnalyses
14751ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14752 // For compatibility with opt's -analyze feature under legacy pass manager
14753 // which was not ported to NPM. This keeps tests using
14754 // update_analyze_test_checks.py working.
14755 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14756 << F.getName() << "':\n";
14757 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14758 return PreservedAnalyses::all();
14759}
14760
14761INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14762 "Scalar Evolution Analysis", false, true)
14763INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14764INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14765INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14766INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14767INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14768 "Scalar Evolution Analysis", false, true)
14769
14771
14773
14774bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14775 SE.reset(new ScalarEvolution(
14776 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14777 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14778 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14779 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14780 return false;
14781}
14782
14783void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14784
14785void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14786 SE->print(OS);
14787}
14788
14789void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14790 if (!VerifySCEV)
14791 return;
14792
14793 SE->verify();
14794}
14795
14796void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14797 AU.setPreservesAll();
14798 AU.addRequiredTransitive<AssumptionCacheTracker>();
14799 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14800 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14801 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14802}
14803
14804const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14805 const SCEV *RHS) {
14806 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14807}
14808
14809const SCEVPredicate *
14810ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14811 const SCEV *LHS, const SCEV *RHS) {
14812 FoldingSetNodeID ID;
14813 assert(LHS->getType() == RHS->getType() &&
14814 "Type mismatch between LHS and RHS");
14815 // Unique this node based on the arguments
14816 ID.AddInteger(SCEVPredicate::P_Compare);
14817 ID.AddInteger(Pred);
14818 ID.AddPointer(LHS);
14819 ID.AddPointer(RHS);
14820 void *IP = nullptr;
14821 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14822 return S;
14823 SCEVComparePredicate *Eq = new (SCEVAllocator)
14824 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14825 UniquePreds.InsertNode(Eq, IP);
14826 return Eq;
14827}
14828
14829const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14830 const SCEVAddRecExpr *AR,
14831 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14832 FoldingSetNodeID ID;
14833 // Unique this node based on the arguments
14834 ID.AddInteger(SCEVPredicate::P_Wrap);
14835 ID.AddPointer(AR);
14836 ID.AddInteger(AddedFlags);
14837 void *IP = nullptr;
14838 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14839 return S;
14840 auto *OF = new (SCEVAllocator)
14841 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14842 UniquePreds.InsertNode(OF, IP);
14843 return OF;
14844}
14845
14846namespace {
14847
14848class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14849public:
14850
14851 /// Rewrites \p S in the context of a loop L and the SCEV predication
14852 /// infrastructure.
14853 ///
14854 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14855 /// equivalences present in \p Pred.
14856 ///
14857 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14858 /// \p NewPreds such that the result will be an AddRecExpr.
14859 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14860 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14861 const SCEVPredicate *Pred) {
14862 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14863 return Rewriter.visit(S);
14864 }
14865
14866 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14867 if (Pred) {
14868 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14869 for (const auto *Pred : U->getPredicates())
14870 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14871 if (IPred->getLHS() == Expr &&
14872 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14873 return IPred->getRHS();
14874 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14875 if (IPred->getLHS() == Expr &&
14876 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14877 return IPred->getRHS();
14878 }
14879 }
14880 return convertToAddRecWithPreds(Expr);
14881 }
14882
14883 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14884 const SCEV *Operand = visit(Expr->getOperand());
14885 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14886 if (AR && AR->getLoop() == L && AR->isAffine()) {
14887 // This couldn't be folded because the operand didn't have the nuw
14888 // flag. Add the nusw flag as an assumption that we could make.
14889 const SCEV *Step = AR->getStepRecurrence(SE);
14890 Type *Ty = Expr->getType();
14891 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14892 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14893 SE.getSignExtendExpr(Step, Ty), L,
14894 AR->getNoWrapFlags());
14895 }
14896 return SE.getZeroExtendExpr(Operand, Expr->getType());
14897 }
14898
14899 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14900 const SCEV *Operand = visit(Expr->getOperand());
14901 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14902 if (AR && AR->getLoop() == L && AR->isAffine()) {
14903 // This couldn't be folded because the operand didn't have the nsw
14904 // flag. Add the nssw flag as an assumption that we could make.
14905 const SCEV *Step = AR->getStepRecurrence(SE);
14906 Type *Ty = Expr->getType();
14907 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14908 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14909 SE.getSignExtendExpr(Step, Ty), L,
14910 AR->getNoWrapFlags());
14911 }
14912 return SE.getSignExtendExpr(Operand, Expr->getType());
14913 }
14914
14915private:
14916 explicit SCEVPredicateRewriter(
14917 const Loop *L, ScalarEvolution &SE,
14918 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14919 const SCEVPredicate *Pred)
14920 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14921
14922 bool addOverflowAssumption(const SCEVPredicate *P) {
14923 if (!NewPreds) {
14924 // Check if we've already made this assumption.
14925 return Pred && Pred->implies(P, SE);
14926 }
14927 NewPreds->push_back(P);
14928 return true;
14929 }
14930
14931 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14932 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14933 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14934 return addOverflowAssumption(A);
14935 }
14936
14937 // If \p Expr represents a PHINode, we try to see if it can be represented
14938 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14939 // to add this predicate as a runtime overflow check, we return the AddRec.
14940 // If \p Expr does not meet these conditions (is not a PHI node, or we
14941 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14942 // return \p Expr.
14943 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14944 if (!isa<PHINode>(Expr->getValue()))
14945 return Expr;
14946 std::optional<
14947 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14948 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14949 if (!PredicatedRewrite)
14950 return Expr;
14951 for (const auto *P : PredicatedRewrite->second){
14952 // Wrap predicates from outer loops are not supported.
14953 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14954 if (L != WP->getExpr()->getLoop())
14955 return Expr;
14956 }
14957 if (!addOverflowAssumption(P))
14958 return Expr;
14959 }
14960 return PredicatedRewrite->first;
14961 }
14962
14963 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14964 const SCEVPredicate *Pred;
14965 const Loop *L;
14966};
14967
14968} // end anonymous namespace
14969
14970const SCEV *
14971ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14972 const SCEVPredicate &Preds) {
14973 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14974}
14975
14976const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14977 const SCEV *S, const Loop *L,
14978 SmallVectorImpl<const SCEVPredicate *> &Preds) {
14979 SmallVector<const SCEVPredicate *, 3> TransformPreds;
14980 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14981 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14982
14983 if (!AddRec)
14984 return nullptr;
14985
14986 // Check if any of the transformed predicates is known to be false. In that
14987 // case, it doesn't make sense to convert to a predicated AddRec, as the
14988 // versioned loop will never execute.
14989 for (const SCEVPredicate *Pred : TransformPreds) {
14990 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
14991 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
14992 continue;
14993
14994 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
14995 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
14996 if (isa<SCEVCouldNotCompute>(ExitCount))
14997 continue;
14998
14999 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15000 if (!Step->isOne())
15001 continue;
15002
15003 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15004 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15005 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15006 return nullptr;
15007 }
15008
15009 // Since the transformation was successful, we can now transfer the SCEV
15010 // predicates.
15011 Preds.append(TransformPreds.begin(), TransformPreds.end());
15012
15013 return AddRec;
15014}
15015
15016/// SCEV predicates
15017SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
15018 SCEVPredicateKind Kind)
15019 : FastID(ID), Kind(Kind) {}
15020
15021SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15022 const ICmpInst::Predicate Pred,
15023 const SCEV *LHS, const SCEV *RHS)
15024 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15025 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15026 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15027}
15028
15029bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15030 ScalarEvolution &SE) const {
15031 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15032
15033 if (!Op)
15034 return false;
15035
15036 if (Pred != ICmpInst::ICMP_EQ)
15037 return false;
15038
15039 return Op->LHS == LHS && Op->RHS == RHS;
15040}
15041
15042bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15043
15044void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15045 if (Pred == ICmpInst::ICMP_EQ)
15046 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15047 else
15048 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15049 << *RHS << "\n";
15050
15051}
15052
15053SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15054 const SCEVAddRecExpr *AR,
15055 IncrementWrapFlags Flags)
15056 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15057
15058const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15059
15060bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15061 ScalarEvolution &SE) const {
15062 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15063 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15064 return false;
15065
15066 if (Op->AR == AR)
15067 return true;
15068
15069 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15070 Flags != SCEVWrapPredicate::IncrementNUSW)
15071 return false;
15072
15073 const SCEV *Start = AR->getStart();
15074 const SCEV *OpStart = Op->AR->getStart();
15075 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15076 return false;
15077
15078 // Reject pointers to different address spaces.
15079 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15080 return false;
15081
15082 const SCEV *Step = AR->getStepRecurrence(SE);
15083 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15084 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15085 return false;
15086
15087 // If both steps are positive, this implies N if N's start and step are
15088 // ULE/SLE (for NUSW/NSSW) compared to this predicate's start and step.
15089 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15090 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15091 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15092
15093 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15094 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15095 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15096 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15097 : SE.getNoopOrSignExtend(Start, WiderTy);
15098 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15099 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15100 SE.isKnownPredicate(Pred, OpStart, Start);
15101}
15102
15103bool SCEVWrapPredicate::isAlwaysTrue() const {
15104 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15105 IncrementWrapFlags IFlags = Flags;
15106
15107 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15108 IFlags = clearFlags(IFlags, IncrementNSSW);
15109
15110 return IFlags == IncrementAnyWrap;
15111}
15112
15113void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15114 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15115 if (SCEVWrapPredicate::IncrementNUSW == Flags)
15116 OS << "<nusw>";
15117 if (SCEVWrapPredicate::IncrementNSSW == Flags)
15118 OS << "<nssw>";
15119 OS << "\n";
15120}
15121
15124 ScalarEvolution &SE) {
15125 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15126 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15127
15128 // We can safely transfer the NSW flag as NSSW.
15129 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15130 ImpliedFlags = IncrementNSSW;
15131
15132 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15133 // If the increment is positive, the SCEV NUW flag will also imply the
15134 // WrapPredicate NUSW flag.
15135 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15136 if (Step->getValue()->getValue().isNonNegative())
15137 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15138 }
15139
15140 return ImpliedFlags;
15141}
15142
15143/// Union predicates don't get cached so create a dummy set ID for it.
15144SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15145 ScalarEvolution &SE)
15146 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15147 for (const auto *P : Preds)
15148 add(P, SE);
15149}
15150
15151bool SCEVUnionPredicate::isAlwaysTrue() const {
15152 return all_of(Preds,
15153 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15154}
15155
15156bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15157 ScalarEvolution &SE) const {
15158 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15159 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15160 return this->implies(I, SE);
15161 });
15162
15163 return any_of(Preds,
15164 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15165}
15166
15167void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15168 for (const auto *Pred : Preds)
15169 Pred->print(OS, Depth);
15170}
15171
15172void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15173 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15174 for (const auto *Pred : Set->Preds)
15175 add(Pred, SE);
15176 return;
15177 }
15178
15179 // Only add predicate if it is not already implied by this union predicate.
15180 if (implies(N, SE))
15181 return;
15182
15183 // Build a new vector containing the current predicates, except the ones that
15184 // are implied by the new predicate N.
15185 SmallVector<const SCEVPredicate *> PrunedPreds;
15186 for (auto *P : Preds) {
15187 if (N->implies(P, SE))
15188 continue;
15189 PrunedPreds.push_back(P);
15190 }
15191 Preds = std::move(PrunedPreds);
15192 Preds.push_back(N);
15193}
15194
15195PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15196 Loop &L)
15197 : SE(SE), L(L) {
15198 SmallVector<const SCEVPredicate *, 4> Empty;
15199 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15200}
15201
15202void ScalarEvolution::registerUser(const SCEV *User,
15203 ArrayRef<const SCEV *> Ops) {
15204 for (const auto *Op : Ops)
15205 // We do not expect that forgetting cached data for SCEVConstants will ever
15206 // open any prospects for sharpening or introduce any correctness issues,
15207 // so we don't bother storing their dependencies.
15208 if (!isa<SCEVConstant>(Op))
15209 SCEVUsers[Op].insert(User);
15210}
15211
15212const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15213 const SCEV *Expr = SE.getSCEV(V);
15214 RewriteEntry &Entry = RewriteMap[Expr];
15215
15216 // If we already have an entry and the version matches, return it.
15217 if (Entry.second && Generation == Entry.first)
15218 return Entry.second;
15219
15220 // We found an entry but it's stale. Rewrite the stale entry
15221 // according to the current predicate.
15222 if (Entry.second)
15223 Expr = Entry.second;
15224
15225 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15226 Entry = {Generation, NewSCEV};
15227
15228 return NewSCEV;
15229}
15230
15231const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15232 if (!BackedgeCount) {
15233 SmallVector<const SCEVPredicate *, 4> Preds;
15234 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15235 for (const auto *P : Preds)
15236 addPredicate(*P);
15237 }
15238 return BackedgeCount;
15239}
15240
15241const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15242 if (!SymbolicMaxBackedgeCount) {
15243 SmallVector<const SCEVPredicate *, 4> Preds;
15244 SymbolicMaxBackedgeCount =
15245 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15246 for (const auto *P : Preds)
15247 addPredicate(*P);
15248 }
15249 return SymbolicMaxBackedgeCount;
15250}
15251
15252unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15253 if (!SmallConstantMaxTripCount) {
15254 SmallVector<const SCEVPredicate *, 4> Preds;
15255 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15256 for (const auto *P : Preds)
15257 addPredicate(*P);
15258 }
15259 return *SmallConstantMaxTripCount;
15260}
15261
15262void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15263 if (Preds->implies(&Pred, SE))
15264 return;
15265
15266 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15267 NewPreds.push_back(&Pred);
15268 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15269 updateGeneration();
15270}
15271
15272const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15273 return *Preds;
15274}
15275
15276void PredicatedScalarEvolution::updateGeneration() {
15277 // If the generation number wrapped recompute everything.
15278 if (++Generation == 0) {
15279 for (auto &II : RewriteMap) {
15280 const SCEV *Rewritten = II.second.second;
15281 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15282 }
15283 }
15284}
15285
15286void PredicatedScalarEvolution::setNoOverflow(
15287 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15288 const SCEV *Expr = getSCEV(V);
15289 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15290
15291 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15292
15293 // Clear the statically implied flags.
15294 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15295 addPredicate(*SE.getWrapPredicate(AR, Flags));
15296
15297 auto II = FlagsMap.insert({V, Flags});
15298 if (!II.second)
15299 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15300}
15301
15302bool PredicatedScalarEvolution::hasNoOverflow(
15303 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15304 const SCEV *Expr = getSCEV(V);
15305 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15306
15307 Flags = SCEVWrapPredicate::clearFlags(
15308 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15309
15310 auto II = FlagsMap.find(V);
15311
15312 if (II != FlagsMap.end())
15313 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15314
15315 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15316}
15317
15318const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15319 const SCEV *Expr = this->getSCEV(V);
15320 SmallVector<const SCEVPredicate *, 4> NewPreds;
15321 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15322
15323 if (!New)
15324 return nullptr;
15325
15326 for (const auto *P : NewPreds)
15327 addPredicate(*P);
15328
15329 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15330 return New;
15331}
15332
15333PredicatedScalarEvolution::PredicatedScalarEvolution(
15334 const PredicatedScalarEvolution &Init)
15335 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15336 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15337 SE)),
15338 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15339 for (auto I : Init.FlagsMap)
15340 FlagsMap.insert(I);
15341}
15342
15343void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15344 // For each block.
15345 for (auto *BB : L.getBlocks())
15346 for (auto &I : *BB) {
15347 if (!SE.isSCEVable(I.getType()))
15348 continue;
15349
15350 auto *Expr = SE.getSCEV(&I);
15351 auto II = RewriteMap.find(Expr);
15352
15353 if (II == RewriteMap.end())
15354 continue;
15355
15356 // Don't print things that are not interesting.
15357 if (II->second.second == Expr)
15358 continue;
15359
15360 OS.indent(Depth) << "[PSE]" << I << ":\n";
15361 OS.indent(Depth + 2) << *Expr << "\n";
15362 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15363 }
15364}
15365
15366// Match the mathematical pattern A - (A / B) * B, where A and B can be
15367// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15368// for URem with constant power-of-2 second operands.
15369// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15370// 4, A / B becomes X / 8).
15371bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15372 const SCEV *&RHS) {
15373 if (Expr->getType()->isPointerTy())
15374 return false;
15375
15376 // Try to match 'zext (trunc A to iB) to iY', which is used
15377 // for URem with constant power-of-2 second operands. Make sure the size of
15378 // the operand A matches the size of the whole expression.
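 // For example, 'zext (trunc i64 %a to i8) to i64' is matched as %a urem 256.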
15379 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15380 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15381 LHS = Trunc->getOperand();
15382 // Bail out if the type of the LHS is larger than the type of the
15383 // expression for now.
15384 if (getTypeSizeInBits(LHS->getType()) >
15385 getTypeSizeInBits(Expr->getType()))
15386 return false;
15387 if (LHS->getType() != Expr->getType())
15388 LHS = getZeroExtendExpr(LHS, Expr->getType());
15389 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15390 << getTypeSizeInBits(Trunc->getType()));
15391 return true;
15392 }
15393 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15394 if (Add == nullptr || Add->getNumOperands() != 2)
15395 return false;
15396
15397 const SCEV *A = Add->getOperand(1);
15398 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15399
15400 if (Mul == nullptr)
15401 return false;
15402
15403 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15404 // (SomeExpr + (-(SomeExpr / B) * B)).
15405 if (Expr == getURemExpr(A, B)) {
15406 LHS = A;
15407 RHS = B;
15408 return true;
15409 }
15410 return false;
15411 };
15412
15413 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15414 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15415 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15416 MatchURemWithDivisor(Mul->getOperand(2));
15417
15418 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15419 if (Mul->getNumOperands() == 2)
15420 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15421 MatchURemWithDivisor(Mul->getOperand(0)) ||
15422 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15423 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15424 return false;
15425}
15426
15427ScalarEvolution::LoopGuards
15428ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15429 BasicBlock *Header = L->getHeader();
15430 BasicBlock *Pred = L->getLoopPredecessor();
15431 LoopGuards Guards(SE);
15432 if (!Pred)
15433 return Guards;
15434 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15435 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15436 return Guards;
15437}
15438
15439void ScalarEvolution::LoopGuards::collectFromPHI(
15440 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15441 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15442 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15443 unsigned Depth) {
15444 if (!SE.isSCEVable(Phi.getType()))
15445 return;
15446
15447 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15448 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15449 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15450 if (!VisitedBlocks.insert(InBlock).second)
15451 return {nullptr, scCouldNotCompute};
15452 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15453 if (Inserted)
15454 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15455 Depth + 1);
15456 auto &RewriteMap = G->second.RewriteMap;
15457 if (RewriteMap.empty())
15458 return {nullptr, scCouldNotCompute};
15459 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15460 if (S == RewriteMap.end())
15461 return {nullptr, scCouldNotCompute};
15462 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15463 if (!SM)
15464 return {nullptr, scCouldNotCompute};
15465 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15466 return {C0, SM->getSCEVType()};
15467 return {nullptr, scCouldNotCompute};
15468 };
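 // Merge two min/max facts by keeping the weaker constant bound, so that the
 // combined fact holds on every incoming edge (e.g. for umax guards the
 // smaller constant is kept, for umin guards the larger one).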
15469 auto MergeMinMaxConst = [](MinMaxPattern P1,
15470 MinMaxPattern P2) -> MinMaxPattern {
15471 auto [C1, T1] = P1;
15472 auto [C2, T2] = P2;
15473 if (!C1 || !C2 || T1 != T2)
15474 return {nullptr, scCouldNotCompute};
15475 switch (T1) {
15476 case scUMaxExpr:
15477 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15478 case scSMaxExpr:
15479 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15480 case scUMinExpr:
15481 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15482 case scSMinExpr:
15483 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15484 default:
15485 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15486 }
15487 };
15488 auto P = GetMinMaxConst(0);
15489 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15490 if (!P.first)
15491 break;
15492 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15493 }
15494 if (P.first) {
15495 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15496 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15497 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15498 Guards.RewriteMap.insert({LHS, RHS});
15499 }
15500}
15501
15502void ScalarEvolution::LoopGuards::collectFromBlock(
15503 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15504 const BasicBlock *Block, const BasicBlock *Pred,
15505 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15506 SmallVector<const SCEV *> ExprsToRewrite;
15507 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15508 const SCEV *RHS,
15509 DenseMap<const SCEV *, const SCEV *>
15510 &RewriteMap) {
15511 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15512 // replacement SCEV which isn't directly implied by the structure of that
15513 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15514 // legal. See the scoping rules for flags in the header to understand why.
15515
15516 // If LHS is a constant, apply information to the other expression.
15517 if (isa<SCEVConstant>(LHS)) {
15518 std::swap(LHS, RHS);
15519 Predicate = CmpInst::getSwappedPredicate(Predicate);
15520 }
15521
15522 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15523 // create this form when combining two checks of the form (X u< C2 + C1) and
15524 // (X >=u C1).
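 // For example (illustrative): a guard '(-1 + %x) u< 4' constrains %x to the
 // range [1, 5), so %x is rewritten below to umax(1, umin(%x, 4)).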
15525 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15526 &ExprsToRewrite]() {
15527 const SCEVConstant *C1;
15528 const SCEVUnknown *LHSUnknown;
15529 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15530 if (!match(LHS,
15531 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15532 !C2)
15533 return false;
15534
15535 auto ExactRegion =
15536 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15537 .sub(C1->getAPInt());
15538
15539 // Bail out, unless we have a non-wrapping, monotonic range.
15540 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15541 return false;
15542 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15543 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15544 I->second = SE.getUMaxExpr(
15545 SE.getConstant(ExactRegion.getUnsignedMin()),
15546 SE.getUMinExpr(RewrittenLHS,
15547 SE.getConstant(ExactRegion.getUnsignedMax())));
15548 ExprsToRewrite.push_back(LHSUnknown);
15549 return true;
15550 };
15551 if (MatchRangeCheckIdiom())
15552 return;
15553
15554 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15555 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15556 // the non-constant operand and in \p LHS the constant operand.
15557 auto IsMinMaxSCEVWithNonNegativeConstant =
15558 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15559 const SCEV *&RHS) {
15560 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15561 if (MinMax->getNumOperands() != 2)
15562 return false;
15563 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15564 if (C->getAPInt().isNegative())
15565 return false;
15566 SCTy = MinMax->getSCEVType();
15567 LHS = MinMax->getOperand(0);
15568 RHS = MinMax->getOperand(1);
15569 return true;
15570 }
15571 }
15572 return false;
15573 };
15574
15575 // Checks whether Expr is a non-negative constant and Divisor is a positive
15576 // constant, and returns their APInt values in ExprVal and DivisorVal.
15577 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15578 APInt &ExprVal, APInt &DivisorVal) {
15579 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15580 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15581 if (!ConstExpr || !ConstDivisor)
15582 return false;
15583 ExprVal = ConstExpr->getAPInt();
15584 DivisorVal = ConstDivisor->getAPInt();
15585 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15586 };
15587
15588 // Return a new SCEV that modifies \p Expr to the closest number that is
15589 // divisible by \p Divisor and greater than or equal to \p Expr.
15590 // For now, only handle constant Expr and Divisor.
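 // For example, Expr = 5 and Divisor = 4 yields 8, while Expr = 8 stays 8.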
15591 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15592 const SCEV *Divisor) {
15593 APInt ExprVal;
15594 APInt DivisorVal;
15595 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15596 return Expr;
15597 APInt Rem = ExprVal.urem(DivisorVal);
15598 if (!Rem.isZero())
15599 // return the SCEV: Expr + Divisor - Expr % Divisor
15600 return SE.getConstant(ExprVal + DivisorVal - Rem);
15601 return Expr;
15602 };
15603
15604 // Return a new SCEV that modifies \p Expr to the closest number that is
15605 // divisible by \p Divisor and less than or equal to \p Expr.
15606 // For now, only handle constant Expr and Divisor.
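 // For example, Expr = 5 and Divisor = 4 yields 4, while Expr = 8 stays 8.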
15607 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15608 const SCEV *Divisor) {
15609 APInt ExprVal;
15610 APInt DivisorVal;
15611 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15612 return Expr;
15613 APInt Rem = ExprVal.urem(DivisorVal);
15614 // return the SCEV: Expr - Expr % Divisor
15615 return SE.getConstant(ExprVal - Rem);
15616 };
15617
15618 // Apply divisibility by \p Divisor to a MinMaxExpr with constant values,
15619 // recursively. This is done by aligning the constant value up/down to the
15620 // Divisor.
15621 std::function<const SCEV *(const SCEV *, const SCEV *)>
15622 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15623 const SCEV *Divisor) {
15624 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15625 SCEVTypes SCTy;
15626 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15627 MinMaxRHS))
15628 return MinMaxExpr;
15629 auto IsMin =
15630 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15631 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15632 "Expected non-negative operand!");
15633 auto *DivisibleExpr =
15634 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15635 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15636 SmallVector<const SCEV *> Ops = {
15637 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15638 return SE.getMinMaxExpr(SCTy, Ops);
15639 };
15640
15641 // If we have LHS == 0, check if LHS is computing a property of some unknown
15642 // SCEV %v which we can rewrite %v to express explicitly.
15643 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15644 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15645 // explicitly express that.
15646 const SCEV *URemLHS = nullptr;
15647 const SCEV *URemRHS = nullptr;
15648 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15649 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15650 auto I = RewriteMap.find(LHSUnknown);
15651 const SCEV *RewrittenLHS =
15652 I != RewriteMap.end() ? I->second : LHSUnknown;
15653 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15654 const auto *Multiple =
15655 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15656 RewriteMap[LHSUnknown] = Multiple;
15657 ExprsToRewrite.push_back(LHSUnknown);
15658 return;
15659 }
15660 }
15661 }
15662
15663 // Do not apply information for constants or if RHS contains an AddRec.
15664 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15665 return;
15666
15667 // If RHS is SCEVUnknown, make sure the information is applied to it.
15668 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15669 std::swap(LHS, RHS);
15670 Predicate = CmpInst::getSwappedPredicate(Predicate);
15671 }
15672
15673 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15674 // and \p FromRewritten are the same (i.e. there has been no rewrite
15675 // registered for \p From), then puts this value in the list of rewritten
15676 // expressions.
15677 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15678 const SCEV *To) {
15679 if (From == FromRewritten)
15680 ExprsToRewrite.push_back(From);
15681 RewriteMap[From] = To;
15682 };
15683
15684 // Checks whether \p S has already been rewritten. In that case returns the
15685 // existing rewrite because we want to chain further rewrites onto the
15686 // already rewritten value. Otherwise returns \p S.
15687 auto GetMaybeRewritten = [&](const SCEV *S) {
15688 return RewriteMap.lookup_or(S, S);
15689 };
15690
15691 // Check for the SCEV expression (A /u B) * B, where B is a constant, inside
15692 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15693 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15694 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15695 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15696 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15697 // DividesBy.
15698 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15699 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15700 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15701 if (Mul->getNumOperands() != 2)
15702 return false;
15703 auto *MulLHS = Mul->getOperand(0);
15704 auto *MulRHS = Mul->getOperand(1);
15705 if (isa<SCEVConstant>(MulLHS))
15706 std::swap(MulLHS, MulRHS);
15707 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15708 if (Div->getOperand(1) == MulRHS) {
15709 DividesBy = MulRHS;
15710 return true;
15711 }
15712 }
15713 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15714 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15715 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15716 return false;
15717 };
15718
15719 // Return true if \p Expr is known to be divisible by \p DividesBy.
15720 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15721 [&](const SCEV *Expr, const SCEV *DividesBy) {
15722 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15723 return true;
15724 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15725 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15726 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15727 return false;
15728 };
15729
15730 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15731 const SCEV *DividesBy = nullptr;
15732 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15733 // Check that the whole expression is divided by DividesBy
15734 DividesBy =
15735 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
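// Editor's sketch, not part of ScalarEvolution.cpp: together, the two lambdas
// above first spot a (A /u B) * B factor somewhere inside a min/max tree and
// then verify that every operand of that tree is a multiple of B. Mirroring
// the umin(umax((A /u 8) * 8, 16), 64) example with A = 42 on plain integers:
constexpr uint64_t SketchA = 42;
constexpr uint64_t SketchMax = (SketchA / 8) * 8 > 16 ? (SketchA / 8) * 8 : 16;
constexpr uint64_t SketchMin = SketchMax < 64 ? SketchMax : 64;
static_assert((SketchA / 8) * 8 % 8 == 0 && 16 % 8 == 0 && 64 % 8 == 0 &&
                  SketchMin % 8 == 0,
              "every operand, and hence the whole min/max, is a multiple of 8");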
15736
15737 // Collect rewrites for LHS and its transitive operands based on the
15738 // condition.
15739 // For min/max expressions, also apply the guard to its operands:
15740 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15741 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15742 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15743 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15744
15745 // We cannot express strict predicates in SCEV, so instead we replace them
15746 // with non-strict ones against plus or minus one of RHS depending on the
15747 // predicate.
15748 const SCEV *One = SE.getOne(RHS->getType());
15749 switch (Predicate) {
15750 case CmpInst::ICMP_ULT:
15751 if (RHS->getType()->isPointerTy())
15752 return;
15753 RHS = SE.getUMaxExpr(RHS, One);
15754 [[fallthrough]];
15755 case CmpInst::ICMP_SLT: {
15756 RHS = SE.getMinusSCEV(RHS, One);
15757 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15758 break;
15759 }
15760 case CmpInst::ICMP_UGT:
15761 case CmpInst::ICMP_SGT:
15762 RHS = SE.getAddExpr(RHS, One);
15763 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15764 break;
15765 case CmpInst::ICMP_ULE:
15766 case CmpInst::ICMP_SLE:
15767 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15768 break;
15769 case CmpInst::ICMP_UGE:
15770 case CmpInst::ICMP_SGE:
15771 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15772 break;
15773 default:
15774 break;
15775 }
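// Illustrative aside (not in the original file): example of the
// strict-to-non-strict conversion combined with the divisor alignment above.
// A guard "%x u< 16" first becomes "%x u<= 15", and if %x is additionally
// known to be a multiple of 8 (DividesBy == 8), the bound is rounded down to 8.
static_assert(16 - 1 == 15 && 15 - 15 % 8 == 8,
              "x u< 16 together with x % 8 == 0 implies x u<= 8");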
15776
15777 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15778 SmallPtrSet<const SCEV *, 16> Visited;
15779
15780 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15781 append_range(Worklist, S->operands());
15782 };
15783
15784 while (!Worklist.empty()) {
15785 const SCEV *From = Worklist.pop_back_val();
15786 if (isa<SCEVConstant>(From))
15787 continue;
15788 if (!Visited.insert(From).second)
15789 continue;
15790 const SCEV *FromRewritten = GetMaybeRewritten(From);
15791 const SCEV *To = nullptr;
15792
15793 switch (Predicate) {
15794 case CmpInst::ICMP_ULT:
15795 case CmpInst::ICMP_ULE:
15796 To = SE.getUMinExpr(FromRewritten, RHS);
15797 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15798 EnqueueOperands(UMax);
15799 break;
15800 case CmpInst::ICMP_SLT:
15801 case CmpInst::ICMP_SLE:
15802 To = SE.getSMinExpr(FromRewritten, RHS);
15803 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15804 EnqueueOperands(SMax);
15805 break;
15806 case CmpInst::ICMP_UGT:
15807 case CmpInst::ICMP_UGE:
15808 To = SE.getUMaxExpr(FromRewritten, RHS);
15809 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15810 EnqueueOperands(UMin);
15811 break;
15812 case CmpInst::ICMP_SGT:
15813 case CmpInst::ICMP_SGE:
15814 To = SE.getSMaxExpr(FromRewritten, RHS);
15815 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15816 EnqueueOperands(SMin);
15817 break;
15818 case CmpInst::ICMP_EQ:
15819 if (isa<SCEVConstant>(RHS))
15820 To = RHS;
15821 break;
15822 case CmpInst::ICMP_NE:
15823 if (match(RHS, m_scev_Zero())) {
15824 const SCEV *OneAlignedUp =
15825 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15826 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15827 }
15828 break;
15829 default:
15830 break;
15831 }
15832
15833 if (To)
15834 AddRewrite(From, FromRewritten, To);
15835 }
15836 };
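// Editor's sketch, not part of ScalarEvolution.cpp: the net effect of
// CollectCondition for an unsigned upper bound. A guard "%x u<= 64" records
// the rewrite %x -> umin(%x, 64), which is exact for every %x the guard
// permits. Modelled on plain integers for one such value:
constexpr auto SketchUMin = [](uint64_t A, uint64_t B) { return A < B ? A : B; };
static_assert(SketchUMin(/*a value of %x respecting the guard*/ 40, 64) == 40,
              "under the guard, umin(%x, 64) evaluates to %x itself");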
15837
15838 SmallVector<std::pair<Value *, bool>> Terms;
15839 // First, collect information from assumptions dominating the loop.
15840 for (auto &AssumeVH : SE.AC.assumptions()) {
15841 if (!AssumeVH)
15842 continue;
15843 auto *AssumeI = cast<CallInst>(AssumeVH);
15844 if (!SE.DT.dominates(AssumeI, Block))
15845 continue;
15846 Terms.emplace_back(AssumeI->getOperand(0), true);
15847 }
15848
15849 // Second, collect information from llvm.experimental.guards dominating the loop.
15850 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15851 SE.F.getParent(), Intrinsic::experimental_guard);
15852 if (GuardDecl)
15853 for (const auto *GU : GuardDecl->users())
15854 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15855 if (Guard->getFunction() == Block->getParent() &&
15856 SE.DT.dominates(Guard, Block))
15857 Terms.emplace_back(Guard->getArgOperand(0), true);
15858
15859 // Third, collect conditions from dominating branches. Starting at the loop
15860 // predecessor, climb up the predecessor chain as long as we can find
15861 // predecessors that have a unique successor leading to the original
15862 // header.
15863 // TODO: share this logic with isLoopEntryGuardedByCond.
15864 unsigned NumCollectedConditions = 0;
15865 VisitedBlocks.insert(Block);
15866 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15867 for (; Pair.first;
15868 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15869 VisitedBlocks.insert(Pair.second);
15870 const BranchInst *LoopEntryPredicate =
15871 dyn_cast<BranchInst>(Pair.first->getTerminator());
15872 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15873 continue;
15874
15875 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15876 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15877 NumCollectedConditions++;
15878
15879 // If we are recursively collecting guards, stop after 2
15880 // conditions to limit compile-time impact for now.
15881 if (Depth > 0 && NumCollectedConditions == 2)
15882 break;
15883 }
15884 // Finally, if we stopped climbing the predecessor chain because
15885 // there wasn't a unique one to continue, try to collect conditions
15886 // for PHINodes by recursively following all of their incoming
15887 // blocks and try to merge the found conditions to build a new one
15888 // for the Phi.
15889 if (Pair.second->hasNPredecessorsOrMore(2) &&
15890 Depth < MaxLoopGuardCollectionDepth) {
15891 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15892 for (auto &Phi : Pair.second->phis())
15893 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15894 }
15895
15896 // Now apply the information from the collected conditions to
15897 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15898 // earliest condition is processed first. This ensures the SCEVs with the
15899 // shortest dependency chains are constructed first.
15900 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15901 SmallVector<Value *, 8> Worklist;
15902 SmallPtrSet<Value *, 8> Visited;
15903 Worklist.push_back(Term);
15904 while (!Worklist.empty()) {
15905 Value *Cond = Worklist.pop_back_val();
15906 if (!Visited.insert(Cond).second)
15907 continue;
15908
15909 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15910 auto Predicate =
15911 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15912 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15913 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15914 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15915 continue;
15916 }
15917
15918 Value *L, *R;
15919 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15920 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15921 Worklist.push_back(L);
15922 Worklist.push_back(R);
15923 }
15924 }
15925 }
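// Illustrative aside (not in the original file): the worklist above splits
// "and" conditions when the guarded edge is the true successor and "or"
// conditions when it is the false successor; entering on the false edge of
// "a || b" means both !a and !b hold, so each operand is processed with the
// inverted predicate. De Morgan on plain bools:
static_assert(!(false || false) == (!false && !false),
              "not (a or b) is equivalent to (not a) and (not b)");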
15926
15927 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15928 // the replacement expressions are contained in the ranges of the replaced
15929 // expressions.
15930 Guards.PreserveNUW = true;
15931 Guards.PreserveNSW = true;
15932 for (const SCEV *Expr : ExprsToRewrite) {
15933 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15934 Guards.PreserveNUW &=
15935 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15936 Guards.PreserveNSW &=
15937 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15938 }
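// Editor's sketch, not part of ScalarEvolution.cpp: why range containment
// justifies keeping the no-wrap flags. If every value a replacement can take
// is a value the original expression could already take, arithmetic that never
// wrapped for the original cannot start wrapping after the rewrite. E.g.
// replacing %x with unsigned range [0, 100] by umin(%x, 16) with range [0, 16]:
static_assert(0 >= 0 && 16 <= 100,
              "[0, 16] is contained in [0, 100], so NUW may be preserved");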
15939
15940 // Now that all rewrite information is collected, rewrite the collected
15941 // expressions with the information in the map. This applies information to
15942 // sub-expressions.
15943 if (ExprsToRewrite.size() > 1) {
15944 for (const SCEV *Expr : ExprsToRewrite) {
15945 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15946 Guards.RewriteMap.erase(Expr);
15947 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15948 }
15949 }
15950}
15951
15952 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15953 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15954 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15955 /// replacement is loop invariant in the loop of the AddRec.
15956 class SCEVLoopGuardRewriter
15957 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15959 const DenseMap<const SCEV *, const SCEV *> &Map;
15960
15961 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15962 public:
15963 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15964 const ScalarEvolution::LoopGuards &Guards)
15965 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15966 if (Guards.PreserveNUW)
15967 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15968 if (Guards.PreserveNSW)
15969 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15970 }
15971
15972 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15973
15974 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15975 return Map.lookup_or(Expr, Expr);
15976 }
15977
15978 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15979 if (const SCEV *S = Map.lookup(Expr))
15980 return S;
15981
15982 // If we didn't find the exact ZExt expr in the map, check if there's
15983 // an entry for a smaller ZExt we can use instead.
15984 Type *Ty = Expr->getType();
15985 const SCEV *Op = Expr->getOperand(0);
15986 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15987 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15988 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15989 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15990 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15991 if (const SCEV *S = Map.lookup(NarrowExt))
15992 return SE.getZeroExtendExpr(S, Ty);
15993 Bitwidth = Bitwidth / 2;
15994 }
15995
15996 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15997 Expr);
15998 }
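  // Illustrative aside (not in the original file): candidate widths for the
  // narrower-ZExt lookup above, for a zext from i16 to i64. The walk starts at
  // 64 / 2 = 32 and keeps halving while the width is a byte multiple, at least
  // 8, and still wider than the 16-bit source, so only an i32 entry is looked
  // up (the next candidate, 16, is not wider than the source).
  static_assert(64 / 2 == 32 && 32 % 8 == 0 && 32 >= 8 && 32 > 16 && !(16 > 16),
                "for zext i16 -> i64 only a 32-bit zext entry is tried");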
15999
16000 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16001 if (const SCEV *S = Map.lookup(Expr))
16002 return S;
16003 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16004 Expr);
16005 }
16006
16007 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16008 if (const SCEV *S = Map.lookup(Expr))
16009 return S;
16010 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16011 }
16012
16013 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16014 if (const SCEV *S = Map.lookup(Expr))
16015 return S;
16016 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16017 }
16018
16019 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16020 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16021 // (Const + A + B). There may be guard info for A + B, and if so, apply
16022 // it.
16023 // TODO: Could more generally apply guards to Add sub-expressions.
16024 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16025 Expr->getNumOperands() == 3) {
16026 if (const SCEV *S = Map.lookup(
16027 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
16028 return SE.getAddExpr(Expr->getOperand(0), S);
16029 }
16030 SmallVector<const SCEV *, 2> Operands;
16031 bool Changed = false;
16032 for (const auto *Op : Expr->operands()) {
16033 Operands.push_back(
16034 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16035 Changed |= Op != Operands.back();
16036 }
16037 // We are only replacing operands with equivalent values, so transfer the
16038 // flags from the original expression.
16039 return !Changed ? Expr
16040 : SE.getAddExpr(Operands,
16041 ScalarEvolution::maskFlags(
16042 Expr->getNoWrapFlags(), FlagMask));
16043 }
16044
16045 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16046 SmallVector<const SCEV *, 2> Operands;
16047 bool Changed = false;
16048 for (const auto *Op : Expr->operands()) {
16049 Operands.push_back(
16050 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16051 Changed |= Op != Operands.back();
16052 }
16053 // We are only replacing operands with equivalent values, so transfer the
16054 // flags from the original expression.
16055 return !Changed ? Expr
16056 : SE.getMulExpr(Operands,
16057 ScalarEvolution::maskFlags(
16058 Expr->getNoWrapFlags(), FlagMask));
16059 }
16060 };
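// Editor's note with a small illustration, not part of ScalarEvolution.cpp:
// when the rewriter above rebuilds an add or mul, it intersects the
// expression's own no-wrap flags with FlagMask, so a rewrite can only drop
// nuw/nsw relative to the original expression, never add them:
static_assert((SCEV::FlagNUW & (SCEV::FlagNUW | SCEV::FlagNSW)) == SCEV::FlagNUW,
              "masking keeps only the flags present in both operands");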
16061
16062 if (RewriteMap.empty())
16063 return Expr;
16064
16065 SCEVLoopGuardRewriter Rewriter(SE, *this);
16066 return Rewriter.visit(Expr);
16067}
16068
16069const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16070 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16071}
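// Editor's sketch, not part of ScalarEvolution.cpp: typical client-side use of
// the guard machinery defined above is to tighten a loop's backedge-taken
// count with the conditions dominating the loop before asking for its range.
// The helper name is hypothetical; the ScalarEvolution calls are existing API.
static APInt sketchGuardedMaxBackedgeTakenCount(ScalarEvolution &SE,
                                                const Loop *L) {
  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  return SE.getUnsignedRangeMax(SE.applyLoopGuards(BTC, L));
}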
16072
16073 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16074 const LoopGuards &Guards) {
16075 return Guards.rewrite(Expr);
16076}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define G(x, y, z)
Definition MD5.cpp:56
mir Rename Register Operands
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:167
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1971
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1540
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1391
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1512
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:206
APInt abs() const
Get the absolute value.
Definition APInt.h:1795
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1201
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:466
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1488
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:219
unsigned countTrailingZeros() const
Definition APInt.h:1647
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:356
unsigned logBase2() const
Definition APInt.h:1761
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:827
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1150
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:873
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:440
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1221
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
iterator end() const
Definition ArrayRef.h:136
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
iterator begin() const
Definition ArrayRef.h:135
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:482
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition Allocator.h:149
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:950
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:678
@ ICMP_SLT
signed less than
Definition InstrTypes.h:707
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:708
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:702
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:701
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:705
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:703
@ ICMP_NE
not equal
Definition InstrTypes.h:700
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:706
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:704
bool isSigned() const
Definition InstrTypes.h:932
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:829
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:944
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:791
bool isUnsigned() const
Definition InstrTypes.h:938
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:928
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1274
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:163
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:154
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:671
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:187
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:165
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:229
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:173
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:161
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:156
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:214
void swap(DenseMap &RHS)
Definition DenseMap.h:744
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:293
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:330
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:570
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:597
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V, bool HasCoroSuspendInst=false) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1077
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:107
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visit(const SCEV *S)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified SCEV is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
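As a rough sketch of how these SCEV base-class queries fit together (inspectSCEV is a hypothetical helper; SE and V are assumed to come from the surrounding pass):
  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/IR/Value.h"
  #include "llvm/Support/raw_ostream.h"

  // Hypothetical helper: print a few basic facts about the SCEV of a value.
  void inspectSCEV(llvm::ScalarEvolution &SE, llvm::Value *V) {
    if (!SE.isSCEVable(V->getType()))
      return;                          // Only integer/pointer values are analyzable.
    const llvm::SCEV *S = SE.getSCEV(V);
    if (S->isZero() || S->isOne())
      return;                          // Constant 0/1, nothing interesting to report.
    llvm::errs() << "expression size: " << S->getExpressionSize() << "\n";
    S->print(llvm::errs());            // Textual form of the canonical expression.
    llvm::errs() << "\n";
  }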
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
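A minimal sketch of how a client pass obtains the analysis result under the new pass manager (MySCEVConsumerPass is a hypothetical pass name; only the ScalarEvolutionAnalysis query is the documented API):
  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/IR/PassManager.h"

  struct MySCEVConsumerPass : llvm::PassInfoMixin<MySCEVConsumerPass> {
    llvm::PreservedAnalyses run(llvm::Function &F,
                                llvm::FunctionAnalysisManager &AM) {
      // Fetch the ScalarEvolution instance computed for this function.
      llvm::ScalarEvolution &SE = AM.getResult<llvm::ScalarEvolutionAnalysis>(F);
      (void)SE; // ... analyze the function's loops with SE here ...
      return llvm::PreservedAnalyses::all();
    }
  };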
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
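A small sketch, assuming an existing ScalarEvolution &SE, a Loop *L, and an expression Expr computed from values in that loop, of collecting and applying loop guards as described above (refineWithGuards is a hypothetical helper; ScalarEvolution::applyLoopGuards performs both steps in one call):
  #include "llvm/Analysis/ScalarEvolution.h"

  const llvm::SCEV *refineWithGuards(llvm::ScalarEvolution &SE,
                                     const llvm::Loop *L,
                                     const llvm::SCEV *Expr) {
    // Collect the rewrite map implied by the loop's guarding conditions once...
    auto Guards = llvm::ScalarEvolution::LoopGuards::collect(L, SE);
    // ...and apply it; the result may be a tighter expression (e.g. a smaller max).
    return Guards.rewrite(Expr);
  }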
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
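As a hedged illustration of querying the backedge-taken count (reportTripCount is a hypothetical helper; SE and L are assumed to exist):
  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/Support/raw_ostream.h"

  void reportTripCount(llvm::ScalarEvolution &SE, const llvm::Loop *L) {
    const llvm::SCEV *BTC = SE.getBackedgeTakenCount(L);
    // getBackedgeTakenCount returns SCEVCouldNotCompute when the count is unknown.
    if (llvm::isa<llvm::SCEVCouldNotCompute>(BTC))
      return;
    // A small, exactly known trip count comes back as a non-zero unsigned value.
    if (unsigned TC = SE.getSmallConstantTripCount(L))
      llvm::errs() << "constant trip count: " << TC << "\n";
  }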
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Returns true if the operation BinOp between LHS and RHS provably does not have signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
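A sketch of building a simple affine recurrence with getAddRecExpr, assuming Ty is a SCEVable integer type and L a valid loop (makeCanonicalIV is a hypothetical helper):
  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/Analysis/ScalarEvolutionExpressions.h"

  const llvm::SCEVAddRecExpr *makeCanonicalIV(llvm::ScalarEvolution &SE,
                                              const llvm::Loop *L,
                                              llvm::Type *Ty) {
    // {0,+,1}<nuw><L>: starts at 0 and steps by 1 on every backedge.
    const llvm::SCEV *AR =
        SE.getAddRecExpr(SE.getZero(Ty), SE.getOne(Ty), L, llvm::SCEV::FlagNUW);
    // getAddRecExpr may fold, but an affine recurrence normally stays an addrec.
    return llvm::dyn_cast<llvm::SCEVAddRecExpr>(AR);
  }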
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
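A small sketch combining isKnownPredicate and evaluatePredicate (knownULT is a hypothetical helper; LHS and RHS are assumed to be SCEVs of the same integer type):
  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/IR/Instructions.h"
  #include <optional>

  bool knownULT(llvm::ScalarEvolution &SE, const llvm::SCEV *LHS,
                const llvm::SCEV *RHS) {
    // isKnownPredicate proves the comparison for all valid inputs, or gives up.
    if (SE.isKnownPredicate(llvm::ICmpInst::ICMP_ULT, LHS, RHS))
      return true;
    // evaluatePredicate can additionally report a definitive "false".
    if (std::optional<bool> Res =
            SE.evaluatePredicate(llvm::ICmpInst::ICMP_ULT, LHS, RHS))
      return *Res;
    return false; // Unknown either way.
  }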
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimized out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
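For illustration, the typical visited-set idiom with SmallPtrSet looks roughly like this (markVisited is a hypothetical helper):
  #include "llvm/ADT/SmallPtrSet.h"
  #include "llvm/Analysis/ScalarEvolution.h"

  // Hypothetical helper: visited-set guard used during a SCEV walk.
  void markVisited(const llvm::SCEV *S,
                   llvm::SmallPtrSet<const llvm::SCEV *, 8> &Visited) {
    // insert() returns {iterator, bool}; the bool is false if S was already present.
    if (!Visited.insert(S).second)
      return;
    // ... process S here ...
  }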
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:623
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:654
TypeSize getSizeInBits() const
Definition DataLayout.h:634
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:295
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:294
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:301
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:292
Use & Op()
Definition User.h:196
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1101
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:169
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2248
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2253
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2258
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2263
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
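A short sketch of the APIntOps helper listed above, assuming both operands are modelled at a 32-bit width (strideGCD is a hypothetical helper):
  #include "llvm/ADT/APInt.h"
  #include <cstdint>

  // Hypothetical helper: GCD of two strides, modelled as 32-bit APInts.
  llvm::APInt strideGCD(uint64_t A, uint64_t B) {
    llvm::APInt X(32, A), Y(32, B);
    // GreatestCommonDivisor takes its operands by value and keeps the bit width.
    return llvm::APIntOps::GreatestCommonDivisor(X, Y);
  }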
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
class_match< const SCEV > m_SCEV()
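A sketch of the SCEV pattern matchers shown above, used here to pull the constant step out of an affine add recurrence (getConstantStep is a hypothetical helper; the matchers live in the llvm::SCEVPatternMatch namespace):
  #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"

  using namespace llvm;
  using namespace llvm::SCEVPatternMatch;

  // Hypothetical helper: return the constant step of {Start,+,C}, if S is one.
  const APInt *getConstantStep(const SCEV *S) {
    const APInt *Step = nullptr;
    if (match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_APInt(Step))))
      return Step; // e.g. {%n,+,4}<%loop> binds Step to 4.
    return nullptr;
  }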
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
Definition MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
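A minimal sketch of a visitor usable with visitAll; the traversal asks the visitor whether to follow each node's operands and whether to stop early (NodeCounter and countNodes are hypothetical):
  #include "llvm/Analysis/ScalarEvolutionExpressions.h"

  // Counts the unique nodes in a SCEV expression tree via the worklist traversal.
  struct NodeCounter {
    unsigned Count = 0;
    // Called once per unique node; return true to also traverse its operands.
    bool follow(const llvm::SCEV *S) { ++Count; return true; }
    // Return true to stop the traversal early; we never do.
    bool isDone() const { return false; }
  };

  unsigned countNodes(const llvm::SCEV *Root) {
    NodeCounter C;
    llvm::visitAll(Root, C);
    return C.Count;
  }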
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:330
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2060
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1727
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:738
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2138
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:252
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2055
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:682
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:157
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:95
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:759
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2130
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1734
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:336
auto reverse(ContainerTy &&C)
Definition STLExtras.h:420
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:288
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:71
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:1956
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2032
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:1963
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:257
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1899
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:851
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:853
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:294
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:101
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:189
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:138
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:122
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:98
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
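A small sketch of the KnownBits helpers above, assuming C and ShAmt have the same bit width (maxAfterLshr is a hypothetical helper):
  #include "llvm/Support/KnownBits.h"

  // Shift a known constant right and read back its unsigned range.
  llvm::APInt maxAfterLshr(const llvm::APInt &C, const llvm::APInt &ShAmt) {
    llvm::KnownBits K = llvm::KnownBits::makeConstant(C);
    llvm::KnownBits S = llvm::KnownBits::makeConstant(ShAmt);
    llvm::KnownBits Res = llvm::KnownBits::lshr(K, S);
    // For fully known inputs, getMinValue() == getMaxValue() == the exact result.
    return Res.getMaxValue();
  }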
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.