1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
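// For instance (an illustrative sketch, not code from this file; `SE` stands
// for a ScalarEvolution instance and A for a SCEVable Value), uniquing means
// structurally equal expressions compare equal by pointer:
//
//   const SCEV *Two = SE.getConstant(A->getType(), 2);
//   const SCEV *X = SE.getAddExpr(SE.getSCEV(A), Two);
//   const SCEV *Y = SE.getAddExpr(Two, SE.getSCEV(A));
//   assert(X == Y && "both orders fold to the same uniqued (2 + A) node");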
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
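// For instance (illustrative), the canonical induction variable
//
//   %iv = phi i64 [ 0, %preheader ], [ %iv.next, %latch ]
//   %iv.next = add i64 %iv, 1
//
// is one such idiom and is represented as an affine recurrence along the lines
// of {0,+,1}<%loop>, while a PHI that cannot be classified is simply wrapped
// in a SCEVUnknown node.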
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
463SCEVCouldNotCompute::SCEVCouldNotCompute() :
464 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
486ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
487 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
491const SCEV *ScalarEvolution::getVScale(Type *Ty) {
492 FoldingSetNodeID ID;
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
503const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
504 SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
511SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
512 const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
522SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
523 SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
529 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
536 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
543 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
548void SCEVUnknown::deleted() {
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
610 GlobalValue::isInternalLinkage(LT));
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
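/// For instance (illustrative), an operand list that arrives as {%x, 2, %x}
/// is reordered so the constant sorts first and the duplicate %x's become
/// adjacent, i.e. {2, %x, %x}; callers such as getAddExpr can then fold the
/// adjacent duplicates into a multiply.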
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768 LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906// but this formula requires less than W + K bits. Also, the first formula requires
907 // a division step, whereas this formula only requires multiplies and shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
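// Worked example (illustrative, not from the original source): for K = 3 and
// W = 32 we have K! = 6 = 2^1 * 3, so T = 1 and the odd factor is 3. The
// product It * (It - 1) * (It - 2) is formed at W + T = 33 bits, divided by
// 2^T = 2, truncated back to 32 bits, and finally multiplied by the
// multiplicative inverse of 3 modulo 2^32 (0xAAAAAAAB), which performs the
// exact division by the odd part of K!.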
922
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
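/// For instance (illustrative), the chrec {0,+,1,+,1} evaluates at iteration
/// It to 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2 =
/// It*(It+1)/2, matching the running sums 0, 1, 3, 6, ... produced by the
/// recurrence.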
981 ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987 const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008 unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operand during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1018 FoldingSetNodeID ID;
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1039 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
1070
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
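  /// For example (illustrative): rewriting ptrtoint((4 * %i) + %base) yields
  /// (4 * %i) + ptrtoint(%base), so the only remaining pointer-typed node is
  /// the SCEVUnknown %base.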
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1077 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1097 SmallVector<const SCEV *, 2> Operands;
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1107 SmallVector<const SCEV *, 2> Operands;
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142 unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1150 FoldingSetNodeID ID;
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1163 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1167 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1171 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
1186 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1188 SmallVector<const SCEV *, 4> Operands;
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1194 isa<SCEVTruncateExpr>(S))
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked in the beginning that ID is not in the cache, it is
1206 // possible that during recursion and different modification ID was inserted
1207 // into the cache. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 SmallVector<const SCEV *, 4> Operands;
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
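// For example (illustrative): for an i8 recurrence whose step is known to lie
// in [1, 4], this returns the limit 127 - 4 = 123 with Pred = ICMP_SLT; as
// long as the recurrence stays signed-less-than 123 before an increment, the
// increment cannot overflow past SCHAR_MAX.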
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1245 SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1250 SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259 ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1265 SE->getUnsignedRangeMax(Step));
1266}
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1302 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note, that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1363 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1365 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1366 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
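// Concretely (an illustrative instantiation of the motivating example): with
// S = 1, X = 4 and T = 1 we have {1,+,4} == {0,+,4} + 1. If the pre-existing
// recurrence {0,+,4} is known <nuw> (condition (2)) and {0,+,4} + 1 cannot
// overflow, e.g. because {0,+,4} is known ult -1 (condition (1)), then
// zext({1,+,4}) == {zext(1),+,zext(4)}.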
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
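// Worked example (illustrative): if C = 58 (0b111010) and the non-constant
// operands x + y + ... are known to have at least 3 trailing zero bits, then
// D = 58 mod 8 = 2. C - D = 56 keeps those 3 trailing zeros, so
// (C - D + x + y + ...) is a multiple of 8 and adding D back only sets low
// bits, meaning the top-level addition cannot wrap in either signedness.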
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502 const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1523 const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1533static void insertFoldCacheEntry(
1534 const ScalarEvolution::FoldID &ID, const SCEV *S,
1535 DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1536 DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1537 &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1569 if (!isa<SCEVCouldNotCompute>(S))
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575 unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1591 FoldingSetNodeID ID;
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1606 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1610 ConstantRange CR = getUnsignedRange(X);
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1633 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the latter case, the analysis code will
1644 // cope with a conservative value, and it will take care to purge
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1660 SCEV::FlagAnyWrap, Depth + 1);
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1662 SCEV::FlagAnyWrap,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1672 SCEV::FlagAnyWrap, Depth + 1),
1673 SCEV::FlagAnyWrap, Depth + 1);
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1689 SCEV::FlagAnyWrap, Depth + 1),
1690 SCEV::FlagAnyWrap, Depth + 1);
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but
1719 // it's one of two possible causes for a change which was
1720 // reverted. Be conservative for the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1730 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731 getSignedRangeMin(Step));
1732 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1733 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1759 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (matchURem(Op, LHS, RHS))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not unsigned-overflow then we can, by definition,
1791 // commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Often address arithmetic contains expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1811 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1814 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not unsigned-overflow then we can, by definition,
1824 // commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
1843 if (SM->getNumOperands() == 2)
1844 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845 if (MulLHS->getAPInt().isPowerOf2())
1846 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848 MulLHS->getAPInt().logBase2();
1849 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850 return getMulExpr(
1851 getZeroExtendExpr(MulLHS, Ty),
1852 getZeroExtendExpr(
1853 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(MinMax))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
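// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// A minimal standalone sketch of the widening trick used above to prove that
// an affine addrec does not unsigned-wrap: evaluate Start + Step * MaxBECount
// in the narrow type and zero-extend the result, then evaluate the same sum
// with every operand widened first; the two agree exactly when no unsigned
// overflow occurred. Plain 8/16-bit integers stand in for SCEV expressions;
// the function and test names are invented for this sketch.
#include <cassert>
#include <cstdint>

static bool addRecHasNoUnsignedWrap(uint8_t Start, uint8_t Step,
                                    uint8_t MaxBECount) {
  uint8_t NarrowFinal = uint8_t(Start + uint8_t(Step * MaxBECount));
  uint16_t ZExtOfNarrow = NarrowFinal;                          // zext(narrow add)
  uint16_t WideFinal =
      uint16_t(Start) + uint16_t(Step) * uint16_t(MaxBECount);  // widened add
  return ZExtOfNarrow == WideFinal;
}

static void testAddRecNoWrap() {
  assert(addRecHasNoUnsignedWrap(0, 1, 99));    // {0,+,1} over 100 trips fits in i8
  assert(!addRecHasNoUnsignedWrap(200, 1, 99)); // 200 + 99 wraps an i8
}
// ---------------------------------------------------------------------------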
1888
1889const SCEV *
1890 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1903 if (!isa<SCEVSignExtendExpr>(S))
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not sign overflow then we can, by definition,
1962 // commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1983 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1986 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 // in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly casted to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2034 SCEV::FlagAnyWrap, Depth + 1);
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2036 SCEV::FlagAnyWrap,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1),
2047 SCEV::FlagAnyWrap, Depth + 1);
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2063 SCEV::FlagAnyWrap, Depth + 1),
2064 SCEV::FlagAnyWrap, Depth + 1);
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but
2090 // it's one of two possible causes for a change which was
2091 // reverted. Be conservative for the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2110 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext, build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(MinMax))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
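// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// A small standalone check of the fold above that replaces sext with zext once
// the operand is known non-negative: when the sign bit is clear, sign- and
// zero-extension produce identical wide values, so the zext is a safe
// canonical form. Plain C++ integers, no SCEV objects; names are invented.
#include <cassert>
#include <cstdint>

static void testSextOfNonNegativeIsZext() {
  for (int V = 0; V <= 127; ++V) {                 // every non-negative i8 value
    int8_t Narrow = static_cast<int8_t>(V);
    uint16_t SExt = uint16_t(int16_t(Narrow));     // sign extension to i16
    uint16_t ZExt = uint16_t(uint8_t(Narrow));     // zero extension to i16
    assert(SExt == ZExt);
  }
}
// ---------------------------------------------------------------------------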
2150
2151 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
2169 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
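// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// Why getAnyExtendExpr may legitimately return either extension: an "extend
// with unspecified bits" only pins down the low bits, and zext/sext disagree
// on the high bits exactly when the source value is negative. Standalone
// plain-integer sketch with invented names; no SCEV objects involved.
#include <cassert>
#include <cstdint>

static void testAnyExtendChoices() {
  int8_t Neg = -1;
  assert(uint16_t(int16_t(Neg)) == 0xFFFF);  // sext fills the high bits with ones
  assert(uint16_t(uint8_t(Neg)) == 0x00FF);  // zext fills the high bits with zeros
  int8_t Pos = 42;
  assert(uint16_t(int16_t(Pos)) == uint16_t(uint8_t(Pos))); // agree when non-negative
}
// ---------------------------------------------------------------------------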
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217 /// the given scale, and update the given map. This is a helper function for
2218/// getAddRecExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228/// This will allow getAddRecExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239static bool
2240 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
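// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// A toy model of the scale collection performed above, with chars standing in
// for SCEVs and plain longs for APInt scales (none of these types or names
// come from the real implementation). Distributing the scales over
//   m + n + 13 + (A * (o + (B * (m + 29))))
// leaves the map with m -> 1 + A*B, n -> 1, o -> A and pulls the constants
// 13 + A*B*29 out, which is exactly the regrouping getAddExpr re-emits.
#include <cassert>
#include <map>

static void testCollectScales() {
  const long A = 3, B = 5;
  std::map<char, long> M;       // variable -> accumulated scale
  long AccumulatedConstant = 0; // buried constants pulled to the outside

  auto Collect = [&M](char Var, long Scale) { M[Var] += Scale; };

  Collect('m', 1);              // m
  Collect('n', 1);              // n
  AccumulatedConstant += 13;    // 13
  Collect('o', A);              // A * o
  Collect('m', A * B);          // A * B * m, merges with the earlier m entry
  AccumulatedConstant += A * B * 29;

  assert(M['m'] == 1 + A * B && M['n'] == 1 && M['o'] == A);
  assert(AccumulatedConstant == 13 + A * B * 29);
}
// ---------------------------------------------------------------------------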
2302
2303 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2323 Signed ? &ScalarEvolution::getSignExtendExpr
2324 : &ScalarEvolution::getZeroExtendExpr;
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 // TODO: Support mul.
2342 if (BinOp == Instruction::Mul)
2343 return false;
2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2345 // TODO: Lift this limitation.
2346 if (!RHSC)
2347 return false;
2348 APInt C = RHSC->getAPInt();
2349 unsigned NumBits = C.getBitWidth();
2350 bool IsSub = (BinOp == Instruction::Sub);
2351 bool IsNegativeConst = (Signed && C.isNegative());
2352 // Compute the direction and magnitude by which we need to check overflow.
2353 bool OverflowDown = IsSub ^ IsNegativeConst;
2354 APInt Magnitude = C;
2355 if (IsNegativeConst) {
2356 if (C == APInt::getSignedMinValue(NumBits))
2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2358 // want to deal with that.
2359 return false;
2360 Magnitude = -C;
2361 }
2362
2363 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2364 if (OverflowDown) {
2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2367 : APInt::getMinValue(NumBits);
2368 APInt Limit = Min + Magnitude;
2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2370 } else {
2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2373 : APInt::getMaxValue(NumBits);
2374 APInt Limit = Max - Magnitude;
2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2376 }
2377}
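// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// The Limit computation above in miniature: for an unsigned add of a constant
// Magnitude, overflow "up" is impossible exactly when
// LHS <= UNSIGNED_MAX - Magnitude. Standalone 8-bit arithmetic with invented
// names, no SCEV predicate queries involved.
#include <cassert>
#include <cstdint>

static bool addWouldOverflowU8(uint8_t LHS, uint8_t Magnitude) {
  uint8_t Limit = uint8_t(UINT8_MAX - Magnitude);
  return LHS > Limit; // LHS <= Limit  <=>  LHS + Magnitude still fits in 8 bits
}

static void testOverflowLimit() {
  assert(!addWouldOverflowU8(200, 55)); // 200 + 55 == 255 still fits
  assert(addWouldOverflowU8(201, 55));  // 201 + 55 == 256 wraps to 0
}
// ---------------------------------------------------------------------------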
2378
2379std::optional<SCEV::NoWrapFlags>
2380 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2381 const OverflowingBinaryOperator *OBO) {
2382 // It cannot be done any better.
2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2384 return std::nullopt;
2385
2386 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2387
2388 if (OBO->hasNoUnsignedWrap())
2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390 if (OBO->hasNoSignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2392
2393 bool Deduced = false;
2394
2395 if (OBO->getOpcode() != Instruction::Add &&
2396 OBO->getOpcode() != Instruction::Sub &&
2397 OBO->getOpcode() != Instruction::Mul)
2398 return std::nullopt;
2399
2400 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2401 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2402
2403 const Instruction *CtxI =
2404 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2405 if (!OBO->hasNoUnsignedWrap() &&
2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2407 /* Signed */ false, LHS, RHS, CtxI)) {
2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409 Deduced = true;
2410 }
2411
2412 if (!OBO->hasNoSignedWrap() &&
2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2414 /* Signed */ true, LHS, RHS, CtxI)) {
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416 Deduced = true;
2417 }
2418
2419 if (Deduced)
2420 return Flags;
2421 return std::nullopt;
2422}
2423
2424// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2425// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2426// can't-overflow flags for the operation if possible.
2427static SCEV::NoWrapFlags
2428 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2429 const ArrayRef<const SCEV *> Ops,
2430 SCEV::NoWrapFlags Flags) {
2431 using namespace std::placeholders;
2432
2433 using OBO = OverflowingBinaryOperator;
2434
2435 bool CanAnalyze =
2436 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2437 (void)CanAnalyze;
2438 assert(CanAnalyze && "don't call from other places!");
2439
2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2441 SCEV::NoWrapFlags SignOrUnsignWrap =
2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2443
2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2445 auto IsKnownNonNegative = [&](const SCEV *S) {
2446 return SE->isKnownNonNegative(S);
2447 };
2448
2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2450 Flags =
2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2452
2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 if (SignOrUnsignWrap != SignOrUnsignMask &&
2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2457 isa<SCEVConstant>(Ops[0])) {
2458
2459 auto Opcode = [&] {
2460 switch (Type) {
2461 case scAddExpr:
2462 return Instruction::Add;
2463 case scMulExpr:
2464 return Instruction::Mul;
2465 default:
2466 llvm_unreachable("Unexpected SCEV op.");
2467 }
2468 }();
2469
2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2471
2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2475 Opcode, C, OBO::NoSignedWrap);
2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2478 }
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoUnsignedWrap);
2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2486 }
2487 }
2488
2489 // <0,+,nonnegative><nw> is also nuw
2490 // TODO: Add corresponding nsw case
2491 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2494 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2495
2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2497 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2498 Ops.size() == 2) {
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2500 if (UDiv->getOperand(1) == Ops[1])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2503 if (UDiv->getOperand(1) == Ops[0])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 }
2506
2507 return Flags;
2508}
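// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// A brute-force sanity check of the first strengthening rule above: an <nsw>
// add whose operands are all known non-negative is also <nuw>, because the
// result then stays within [0, SINT_MAX] and cannot wrap the unsigned range
// either. Exhaustive over 8-bit values; standalone, invented names, no SCEVs.
#include <cassert>

static void testNswPlusNonNegativeImpliesNuw() {
  for (int A = 0; A <= 127; ++A)
    for (int B = 0; B <= 127; ++B) {
      bool SignedWraps = (A + B) > 127;   // would violate <nsw> for i8
      bool UnsignedWraps = (A + B) > 255; // would violate <nuw> for i8
      if (!SignedWraps)
        assert(!UnsignedWraps);
    }
}
// ---------------------------------------------------------------------------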
2509
2510 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2512}
2513
2514/// Get a canonical add expression, or something simpler if possible.
2515 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2516 SCEV::NoWrapFlags OrigFlags,
2517 unsigned Depth) {
2518 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2519 "only nuw or nsw allowed");
2520 assert(!Ops.empty() && "Cannot get empty add!");
2521 if (Ops.size() == 1) return Ops[0];
2522#ifndef NDEBUG
2523 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2524 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2525 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2526 "SCEVAddExpr operand types don't match!");
2527 unsigned NumPtrs = count_if(
2528 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2529 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2530#endif
2531
2532 const SCEV *Folded = constantFoldAndGroupOps(
2533 *this, LI, DT, Ops,
2534 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2535 [](const APInt &C) { return C.isZero(); }, // identity
2536 [](const APInt &C) { return false; }); // absorber
2537 if (Folded)
2538 return Folded;
2539
2540 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2541
2542 // Delay expensive flag strengthening until necessary.
2543 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2544 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2545 };
2546
2547 // Limit recursion calls depth.
2548 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2549 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2550
2551 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2552 // Don't strengthen flags if we have no new information.
2553 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2554 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2555 Add->setNoWrapFlags(ComputeFlags(Ops));
2556 return S;
2557 }
2558
2559 // Okay, check to see if the same value occurs in the operand list more than
2560 // once. If so, merge them together into a multiply expression. Since we
2561 // sorted the list, these values are required to be adjacent.
2562 Type *Ty = Ops[0]->getType();
2563 bool FoundMatch = false;
2564 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2565 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2566 // Scan ahead to count how many equal operands there are.
2567 unsigned Count = 2;
2568 while (i+Count != e && Ops[i+Count] == Ops[i])
2569 ++Count;
2570 // Merge the values into a multiply.
2571 const SCEV *Scale = getConstant(Ty, Count);
2572 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2573 if (Ops.size() == Count)
2574 return Mul;
2575 Ops[i] = Mul;
2576 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2577 --i; e -= Count - 1;
2578 FoundMatch = true;
2579 }
2580 if (FoundMatch)
2581 return getAddExpr(Ops, OrigFlags, Depth + 1);
2582
2583 // Check for truncates. If all the operands are truncated from the same
2584 // type, see if factoring out the truncate would permit the result to be
2585 // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2586 // if the contents of the resulting outer trunc fold to something simple.
2587 auto FindTruncSrcType = [&]() -> Type * {
2588 // We're ultimately looking to fold an addrec of truncs and muls of only
2589 // constants and truncs, so if we find any other types of SCEV
2590 // as operands of the addrec then we bail and return nullptr here.
2591 // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2593 return T->getOperand()->getType();
2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2597 return T->getOperand()->getType();
2598 }
2599 return nullptr;
2600 };
2601 if (auto *SrcType = FindTruncSrcType()) {
2602 SmallVector<const SCEV *, 8> LargeOps;
2603 bool Ok = true;
2604 // Check all the operands to see if they can be represented in the
2605 // source type of the truncate.
2606 for (const SCEV *Op : Ops) {
2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2608 if (T->getOperand()->getType() != SrcType) {
2609 Ok = false;
2610 break;
2611 }
2612 LargeOps.push_back(T->getOperand());
2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2616 SmallVector<const SCEV *, 8> LargeMulOps;
2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2618 if (const SCEVTruncateExpr *T =
2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeMulOps.push_back(T->getOperand());
2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else {
2628 Ok = false;
2629 break;
2630 }
2631 }
2632 if (Ok)
2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2634 } else {
2635 Ok = false;
2636 break;
2637 }
2638 }
2639 if (Ok) {
2640 // Evaluate the expression in the larger type.
2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2642 // If it folds to something simple, use it. Otherwise, don't.
2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2644 return getTruncateExpr(Fold, Ty);
2645 }
2646 }
2647
2648 if (Ops.size() == 2) {
2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2650 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2651 // C1).
2652 const SCEV *A = Ops[0];
2653 const SCEV *B = Ops[1];
2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2655 auto *C = dyn_cast<SCEVConstant>(A);
2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2658 auto C2 = C->getAPInt();
2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2660
2661 APInt ConstAdd = C1 + C2;
2662 auto AddFlags = AddExpr->getNoWrapFlags();
2663 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2665 ConstAdd.ule(C1)) {
2666 PreservedFlags =
2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2668 }
2669
2670 // Adding a constant with the same sign and small magnitude is NSW, if the
2671 // original AddExpr was NSW.
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2673 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2674 ConstAdd.abs().ule(C1.abs())) {
2675 PreservedFlags =
2676 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2677 }
2678
2679 if (PreservedFlags != SCEV::FlagAnyWrap) {
2680 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2681 NewOps[0] = getConstant(ConstAdd);
2682 return getAddExpr(NewOps, PreservedFlags);
2683 }
2684 }
2685
2686 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2687 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2688 const SCEVAddExpr *InnerAdd;
2689 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2690 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2691 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2692 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2693 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2694 SCEV::FlagAnyWrap),
2695 SCEV::FlagNUW)) {
2696 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2697 }
2698 }
2699 }
2700
2701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702 if (Ops.size() == 2) {
2703 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704 if (Mul && Mul->getNumOperands() == 2 &&
2705 Mul->getOperand(0)->isAllOnesValue()) {
2706 const SCEV *X;
2707 const SCEV *Y;
2708 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709 return getMulExpr(Y, getUDivExpr(X, Y));
2710 }
2711 }
2712 }
2713
2714 // Skip past any other cast SCEVs.
2715 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2716 ++Idx;
2717
2718 // If there are add operands they would be next.
2719 if (Idx < Ops.size()) {
2720 bool DeletedAdd = false;
2721 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2722 // common NUW flag for expression after inlining. Other flags cannot be
2723 // preserved, because they may depend on the original order of operations.
2724 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2725 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2726 if (Ops.size() > AddOpsInlineThreshold ||
2727 Add->getNumOperands() > AddOpsInlineThreshold)
2728 break;
2729 // If we have an add, expand the add operands onto the end of the operands
2730 // list.
2731 Ops.erase(Ops.begin()+Idx);
2732 append_range(Ops, Add->operands());
2733 DeletedAdd = true;
2734 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2735 }
2736
2737 // If we deleted at least one add, we added operands to the end of the list,
2738 // and they are not necessarily sorted. Recurse to resort and resimplify
2739 // any operands we just acquired.
2740 if (DeletedAdd)
2741 return getAddExpr(Ops, CommonFlags, Depth + 1);
2742 }
2743
2744 // Skip over the add expression until we get to a multiply.
2745 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2746 ++Idx;
2747
2748 // Check to see if there are any folding opportunities present with
2749 // operands multiplied by constant values.
2750 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 uint64_t BitWidth = getTypeSizeInBits(Ty);
2752 DenseMap<const SCEV *, APInt> M;
2753 SmallVector<const SCEV *, 8> NewOps;
2754 APInt AccumulatedConstant(BitWidth, 0);
2755 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2756 Ops, APInt(BitWidth, 1), *this)) {
2757 struct APIntCompare {
2758 bool operator()(const APInt &LHS, const APInt &RHS) const {
2759 return LHS.ult(RHS);
2760 }
2761 };
2762
2763 // Some interesting folding opportunity is present, so it's worthwhile to
2764 // re-generate the operands list. Group the operands by constant scale,
2765 // to avoid multiplying by the same constant scale multiple times.
2766 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2767 for (const SCEV *NewOp : NewOps)
2768 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2769 // Re-generate the operands list.
2770 Ops.clear();
2771 if (AccumulatedConstant != 0)
2772 Ops.push_back(getConstant(AccumulatedConstant));
2773 for (auto &MulOp : MulOpLists) {
2774 if (MulOp.first == 1) {
2775 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2776 } else if (MulOp.first != 0) {
2777 Ops.push_back(getMulExpr(
2778 getConstant(MulOp.first),
2779 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2780 SCEV::FlagAnyWrap, Depth + 1));
2781 }
2782 }
2783 if (Ops.empty())
2784 return getZero(Ty);
2785 if (Ops.size() == 1)
2786 return Ops[0];
2787 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2788 }
2789 }
2790
2791 // If we are adding something to a multiply expression, make sure the
2792 // something is not already an operand of the multiply. If so, merge it into
2793 // the multiply.
2794 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2795 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2796 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2797 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2798 if (isa<SCEVConstant>(MulOpSCEV))
2799 continue;
2800 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2801 if (MulOpSCEV == Ops[AddOp]) {
2802 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2803 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2804 if (Mul->getNumOperands() != 2) {
2805 // If the multiply has more than two operands, we must get the
2806 // Y*Z term.
2807 SmallVector<const SCEV *, 4> MulOps(
2808 Mul->operands().take_front(MulOp));
2809 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2810 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2811 }
2812 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2813 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2814 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2815 SCEV::FlagAnyWrap, Depth + 1);
2816 if (Ops.size() == 2) return OuterMul;
2817 if (AddOp < Idx) {
2818 Ops.erase(Ops.begin()+AddOp);
2819 Ops.erase(Ops.begin()+Idx-1);
2820 } else {
2821 Ops.erase(Ops.begin()+Idx);
2822 Ops.erase(Ops.begin()+AddOp-1);
2823 }
2824 Ops.push_back(OuterMul);
2825 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2826 }
2827
2828 // Check this multiply against other multiplies being added together.
2829 for (unsigned OtherMulIdx = Idx+1;
2830 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 ++OtherMulIdx) {
2832 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2833 // If MulOp occurs in OtherMul, we can fold the two multiplies
2834 // together.
2835 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2836 OMulOp != e; ++OMulOp)
2837 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2838 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2839 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2840 if (Mul->getNumOperands() != 2) {
2841 SmallVector<const SCEV *, 4> MulOps(
2842 Mul->operands().take_front(MulOp));
2843 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2844 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2845 }
2846 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2847 if (OtherMul->getNumOperands() != 2) {
2848 SmallVector<const SCEV *, 4> MulOps(
2849 OtherMul->operands().take_front(OMulOp));
2850 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2851 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2852 }
2853 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2854 const SCEV *InnerMulSum =
2855 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2856 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2857 SCEV::FlagAnyWrap, Depth + 1);
2858 if (Ops.size() == 2) return OuterMul;
2859 Ops.erase(Ops.begin()+Idx);
2860 Ops.erase(Ops.begin()+OtherMulIdx-1);
2861 Ops.push_back(OuterMul);
2862 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2863 }
2864 }
2865 }
2866 }
2867
2868 // If there are any add recurrences in the operands list, see if any other
2869 // added values are loop invariant. If so, we can fold them into the
2870 // recurrence.
2871 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2872 ++Idx;
2873
2874 // Scan over all recurrences, trying to fold loop invariants into them.
2875 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2876 // Scan all of the other operands to this add and add them to the vector if
2877 // they are loop invariant w.r.t. the recurrence.
2878 SmallVector<const SCEV *, 8> LIOps;
2879 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2880 const Loop *AddRecLoop = AddRec->getLoop();
2881 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2882 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2883 LIOps.push_back(Ops[i]);
2884 Ops.erase(Ops.begin()+i);
2885 --i; --e;
2886 }
2887
2888 // If we found some loop invariants, fold them into the recurrence.
2889 if (!LIOps.empty()) {
2890 // Compute nowrap flags for the addition of the loop-invariant ops and
2891 // the addrec. Temporarily push it as an operand for that purpose. These
2892 // flags are valid in the scope of the addrec only.
2893 LIOps.push_back(AddRec);
2894 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2895 LIOps.pop_back();
2896
2897 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2898 LIOps.push_back(AddRec->getStart());
2899
2900 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2901
2902 // It is not in general safe to propagate flags valid on an add within
2903 // the addrec scope to one outside it. We must prove that the inner
2904 // scope is guaranteed to execute if the outer one does to be able to
2905 // safely propagate. We know the program is undefined if poison is
2906 // produced on the inner scoped addrec. We also know that *for this use*
2907 // the outer scoped add can't overflow (because of the flags we just
2908 // computed for the inner scoped add) without the program being undefined.
2909 // Proving that entry to the outer scope necessitates entry to the inner
2910 // scope, thus proves the program undefined if the flags would be violated
2911 // in the outer scope.
2912 SCEV::NoWrapFlags AddFlags = Flags;
2913 if (AddFlags != SCEV::FlagAnyWrap) {
2914 auto *DefI = getDefiningScopeBound(LIOps);
2915 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2916 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2917 AddFlags = SCEV::FlagAnyWrap;
2918 }
2919 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2920
2921 // Build the new addrec. Propagate the NUW and NSW flags if both the
2922 // outer add and the inner addrec are guaranteed to have no overflow.
2923 // Always propagate NW.
2924 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2925 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2926
2927 // If all of the other operands were loop invariant, we are done.
2928 if (Ops.size() == 1) return NewRec;
2929
2930 // Otherwise, add the folded AddRec by the non-invariant parts.
2931 for (unsigned i = 0;; ++i)
2932 if (Ops[i] == AddRec) {
2933 Ops[i] = NewRec;
2934 break;
2935 }
2936 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2937 }
2938
2939 // Okay, if there weren't any loop invariants to be folded, check to see if
2940 // there are multiple AddRec's with the same loop induction variable being
2941 // added together. If so, we can fold them.
2942 for (unsigned OtherIdx = Idx+1;
2943 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2944 ++OtherIdx) {
2945 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2946 // so that the 1st found AddRecExpr is dominated by all others.
2947 assert(DT.dominates(
2948 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2949 AddRec->getLoop()->getHeader()) &&
2950 "AddRecExprs are not sorted in reverse dominance order?");
2951 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2952 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2953 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2954 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 ++OtherIdx) {
2956 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2957 if (OtherAddRec->getLoop() == AddRecLoop) {
2958 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2959 i != e; ++i) {
2960 if (i >= AddRecOps.size()) {
2961 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2962 break;
2963 }
2964 SmallVector<const SCEV *, 2> TwoOps = {
2965 AddRecOps[i], OtherAddRec->getOperand(i)};
2966 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2969 }
2970 }
2971 // Step size has changed, so we cannot guarantee no self-wraparound.
2972 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2973 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2974 }
2975 }
2976
2977 // Otherwise couldn't fold anything into this recurrence. Move onto the
2978 // next one.
2979 }
2980
2981 // Okay, it looks like we really DO need an add expr. Check to see if we
2982 // already have one, otherwise create a new one.
2983 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2984}
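// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// An exhaustive 8-bit check of the flag-preservation rule getAddExpr applies
// to ((X + C1)<nuw> + C2): when the folded constant C1 + C2 is unsigned-<= C1,
// the rewritten add X + (C1 + C2) inherits <nuw> from the original, since
// X + (C1 + C2) <= X + C1 < 2^8. Standalone plain integers, invented names.
#include <cassert>

static void testFoldedConstantKeepsNUW() {
  for (int X = 0; X <= 255; ++X)
    for (int C1 = 0; C1 <= 255; ++C1)
      for (int C2 = 0; C2 <= 255; ++C2) {
        bool OrigNUW = (X + C1) <= 255;  // X + C1 carries <nuw>
        int ConstAdd = (C1 + C2) & 0xFF; // constant folded modulo 2^8
        if (OrigNUW && ConstAdd <= C1)
          assert((X + ConstAdd) <= 255); // folded add is <nuw> as well
      }
}
// ---------------------------------------------------------------------------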
2985
2986const SCEV *
2987ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2988 SCEV::NoWrapFlags Flags) {
2989 FoldingSetNodeID ID;
2990 ID.AddInteger(scAddExpr);
2991 for (const SCEV *Op : Ops)
2992 ID.AddPointer(Op);
2993 void *IP = nullptr;
2994 SCEVAddExpr *S =
2995 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2996 if (!S) {
2997 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2998 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2999 S = new (SCEVAllocator)
3000 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3001 UniqueSCEVs.InsertNode(S, IP);
3002 registerUser(S, Ops);
3003 }
3004 S->setNoWrapFlags(Flags);
3005 return S;
3006}
3007
3008const SCEV *
3009ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3010 const Loop *L, SCEV::NoWrapFlags Flags) {
3011 FoldingSetNodeID ID;
3012 ID.AddInteger(scAddRecExpr);
3013 for (const SCEV *Op : Ops)
3014 ID.AddPointer(Op);
3015 ID.AddPointer(L);
3016 void *IP = nullptr;
3017 SCEVAddRecExpr *S =
3018 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3019 if (!S) {
3020 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3021 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3022 S = new (SCEVAllocator)
3023 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3024 UniqueSCEVs.InsertNode(S, IP);
3025 LoopUsers[L].push_back(S);
3026 registerUser(S, Ops);
3027 }
3028 setNoWrapFlags(S, Flags);
3029 return S;
3030}
3031
3032const SCEV *
3033ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3034 SCEV::NoWrapFlags Flags) {
3035 FoldingSetNodeID ID;
3036 ID.AddInteger(scMulExpr);
3037 for (const SCEV *Op : Ops)
3038 ID.AddPointer(Op);
3039 void *IP = nullptr;
3040 SCEVMulExpr *S =
3041 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3042 if (!S) {
3043 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3044 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3045 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3046 O, Ops.size());
3047 UniqueSCEVs.InsertNode(S, IP);
3048 registerUser(S, Ops);
3049 }
3050 S->setNoWrapFlags(Flags);
3051 return S;
3052}
3053
3054static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3055 uint64_t k = i*j;
3056 if (j > 1 && k / j != i) Overflow = true;
3057 return k;
3058}
3059
3060/// Compute the result of "n choose k", the binomial coefficient. If an
3061/// intermediate computation overflows, Overflow will be set and the return will
3062/// be garbage. Overflow is not cleared on absence of overflow.
3063static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3064 // We use the multiplicative formula:
3065 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3066 // At each iteration, we take the n-th term of the numerator and divide by the
3067 // (k-n)th term of the denominator. This division will always produce an
3068 // integral result, and helps reduce the chance of overflow in the
3069 // intermediate computations. However, we can still overflow even when the
3070 // final result would fit.
3071
3072 if (n == 0 || n == k) return 1;
3073 if (k > n) return 0;
3074
3075 if (k > n/2)
3076 k = n-k;
3077
3078 uint64_t r = 1;
3079 for (uint64_t i = 1; i <= k; ++i) {
3080 r = umul_ov(r, n-(i-1), Overflow);
3081 r /= i;
3082 }
3083 return r;
3084}
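// --- Editor's illustration (not part of ScalarEvolution.cpp) ---------------
// A hedged usage sketch for the Choose/umul_ov helpers above, assuming it is
// compiled into the same translation unit so the static functions are visible.
// It shows a few expected values and that the Overflow flag is sticky: it is
// set on overflow and never cleared by later successful calls.
#include <cassert>

static void exerciseChoose() {
  bool Overflow = false;
  assert(Choose(5, 2, Overflow) == 10); // 5*4 / (1*2)
  assert(Choose(6, 3, Overflow) == 20); // 6*5*4 / (1*2*3)
  assert(Choose(3, 5, Overflow) == 0);  // k > n
  assert(!Overflow);
  (void)Choose(70, 35, Overflow);       // intermediate products exceed 64 bits
  assert(Overflow);                     // flag remains set for the caller to see
}
// ---------------------------------------------------------------------------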
3085
3086/// Determine if any of the operands in this SCEV are a constant or if
3087/// any of the add or multiply expressions in this SCEV contain a constant.
3088static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3089 struct FindConstantInAddMulChain {
3090 bool FoundConstant = false;
3091
3092 bool follow(const SCEV *S) {
3093 FoundConstant |= isa<SCEVConstant>(S);
3094 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3095 }
3096
3097 bool isDone() const {
3098 return FoundConstant;
3099 }
3100 };
3101
3102 FindConstantInAddMulChain F;
3103 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3104 ST.visitAll(StartExpr);
3105 return F.FoundConstant;
3106}
3107
3108/// Get a canonical multiply expression, or something simpler if possible.
3109 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3110 SCEV::NoWrapFlags OrigFlags,
3111 unsigned Depth) {
3112 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3113 "only nuw or nsw allowed");
3114 assert(!Ops.empty() && "Cannot get empty mul!");
3115 if (Ops.size() == 1) return Ops[0];
3116#ifndef NDEBUG
3117 Type *ETy = Ops[0]->getType();
3118 assert(!ETy->isPointerTy());
3119 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3120 assert(Ops[i]->getType() == ETy &&
3121 "SCEVMulExpr operand types don't match!");
3122#endif
3123
3124 const SCEV *Folded = constantFoldAndGroupOps(
3125 *this, LI, DT, Ops,
3126 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3127 [](const APInt &C) { return C.isOne(); }, // identity
3128 [](const APInt &C) { return C.isZero(); }); // absorber
3129 if (Folded)
3130 return Folded;
3131
3132 // Delay expensive flag strengthening until necessary.
3133 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3134 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3135 };
3136
3137 // Limit recursion calls depth.
3138 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3139 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3140
3141 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3142 // Don't strengthen flags if we have no new information.
3143 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3144 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3145 Mul->setNoWrapFlags(ComputeFlags(Ops));
3146 return S;
3147 }
3148
3149 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3150 if (Ops.size() == 2) {
3151 // C1*(C2+V) -> C1*C2 + C1*V
3152 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3153 // If any of Add's ops are Adds or Muls with a constant, apply this
3154 // transformation as well.
3155 //
3156 // TODO: There are some cases where this transformation is not
3157 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3158 // this transformation should be narrowed down.
3159 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3160 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3161 SCEV::FlagAnyWrap, Depth + 1);
3162 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3163 SCEV::FlagAnyWrap, Depth + 1);
3164 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3165 }
3166
3167 if (Ops[0]->isAllOnesValue()) {
3168 // If we have a mul by -1 of an add, try distributing the -1 among the
3169 // add operands.
3170 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3171 SmallVector<const SCEV *, 4> NewOps;
3172 bool AnyFolded = false;
3173 for (const SCEV *AddOp : Add->operands()) {
3174 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3175 Depth + 1);
3176 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3177 NewOps.push_back(Mul);
3178 }
3179 if (AnyFolded)
3180 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3181 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3182 // Negation preserves a recurrence's no self-wrap property.
3183 SmallVector<const SCEV *, 4> Operands;
3184 for (const SCEV *AddRecOp : AddRec->operands())
3185 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3186 Depth + 1));
3187 // Let M be the minimum representable signed value. AddRec with nsw
3188 // multiplied by -1 can have signed overflow if and only if it takes a
3189 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3190 // maximum signed value. In all other cases signed overflow is
3191 // impossible.
3192 auto FlagsMask = SCEV::FlagNW;
3193 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3194 auto MinInt =
3195 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3196 if (getSignedRangeMin(AddRec) != MinInt)
3197 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3198 }
3199 return getAddRecExpr(Operands, AddRec->getLoop(),
3200 AddRec->getNoWrapFlags(FlagsMask));
3201 }
3202 }
3203
3204 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3205 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
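// Illustrative example, assuming an i8 inner add: 6 * (zext i8 (1 + %a) to
// i32) --> zext i8 (6 + 6 * %a) to i32, which is only valid when the narrow
// multiply cannot unsigned-wrap.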
3206 const SCEVAddExpr *InnerAdd;
3207 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3208 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3209 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3210 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3211 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3212 SCEV::FlagAnyWrap),
3213 SCEV::FlagNUW)) {
3214 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3215 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3216 };
3217 }
3218
3219 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3220 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3221 // of C1, fold to (D /u (C2 /u C1)).
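// Illustrative example with made-up values: (4 * (%d /u 8)) --> %d /u 2 when
// %d is known to be a multiple of 8 (the second case, C2 a multiple of C1).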
3222 const SCEV *D;
3223 APInt C1V = LHSC->getAPInt();
3224 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3225 // as -1 * 1, as it won't enable additional folds.
3226 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3227 C1V = C1V.abs();
3228 const SCEVConstant *C2;
3229 if (C1V.isPowerOf2() &&
3230 match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3231 C2->getAPInt().isPowerOf2() &&
3232 C1V.logBase2() <= getMinTrailingZeros(D)) {
3233 const SCEV *NewMul = nullptr;
3234 if (C1V.uge(C2->getAPInt())) {
3235 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3236 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3237 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3238 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3239 }
3240 if (NewMul)
3241 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3242 }
3243 }
3244 }
3245
3246 // Skip over the add expression until we get to a multiply.
3247 unsigned Idx = 0;
3248 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3249 ++Idx;
3250
3251 // If there are mul operands inline them all into this expression.
3252 if (Idx < Ops.size()) {
3253 bool DeletedMul = false;
3254 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3255 if (Ops.size() > MulOpsInlineThreshold)
3256 break;
3257 // If we have a mul, expand the mul operands onto the end of the
3258 // operands list.
3259 Ops.erase(Ops.begin()+Idx);
3260 append_range(Ops, Mul->operands());
3261 DeletedMul = true;
3262 }
3263
3264 // If we deleted at least one mul, we added operands to the end of the
3265 // list, and they are not necessarily sorted. Recurse to resort and
3266 // resimplify any operands we just acquired.
3267 if (DeletedMul)
3268 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3269 }
3270
3271 // If there are any add recurrences in the operands list, see if any other
3272 // multiplied values are loop invariant. If so, we can fold them into the
3273 // recurrence.
3274 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3275 ++Idx;
3276
3277 // Scan over all recurrences, trying to fold loop invariants into them.
3278 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3279 // Scan all of the other operands to this mul and add them to the vector
3280 // if they are loop invariant w.r.t. the recurrence.
3281 SmallVector<const SCEV *, 8> LIOps;
3282 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3283 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3284 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3285 LIOps.push_back(Ops[i]);
3286 Ops.erase(Ops.begin()+i);
3287 --i; --e;
3288 }
3289
3290 // If we found some loop invariants, fold them into the recurrence.
3291 if (!LIOps.empty()) {
3292 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3293 SmallVector<const SCEV *, 4> NewOps;
3294 NewOps.reserve(AddRec->getNumOperands());
3295 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3296
3297 // If both the mul and addrec are nuw, we can preserve nuw.
3298 // If both the mul and addrec are nsw, we can only preserve nsw if either
3299 // a) they are also nuw, or
3300 // b) all multiplications of addrec operands with scale are nsw.
3301 SCEV::NoWrapFlags Flags =
3302 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3303
3304 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3305 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3306 SCEV::FlagAnyWrap, Depth + 1));
3307
3308 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3309 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3310 Instruction::Mul, getSignedRange(Scale),
3311 OverflowingBinaryOperator::NoSignedWrap);
3312 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3313 Flags = clearFlags(Flags, SCEV::FlagNSW);
3314 }
3315 }
3316
3317 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3318
3319 // If all of the other operands were loop invariant, we are done.
3320 if (Ops.size() == 1) return NewRec;
3321
3322 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3323 for (unsigned i = 0;; ++i)
3324 if (Ops[i] == AddRec) {
3325 Ops[i] = NewRec;
3326 break;
3327 }
3328 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3329 }
3330
3331 // Okay, if there weren't any loop invariants to be folded, check to see
3332 // if there are multiple AddRec's with the same loop induction variable
3333 // being multiplied together. If so, we can fold them.
3334
3335 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3336 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3337 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3338 // ]]],+,...up to x=2n}.
3339 // Note that the arguments to choose() are always integers with values
3340 // known at compile time, never SCEV objects.
3341 //
3342 // The implementation avoids pointless extra computations when the two
3343 // addrec's are of different length (mathematically, it's equivalent to
3344 // an infinite stream of zeros on the right).
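// Small worked example of the formula above:
// {1,+,1}<L> * {1,+,1}<L> = {1,+,3,+,2}<L>, i.e. (n+1)^2 = n^2 + 2n + 1
// evaluated as a chain of recurrences.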
3345 bool OpsModified = false;
3346 for (unsigned OtherIdx = Idx+1;
3347 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3348 ++OtherIdx) {
3349 const SCEVAddRecExpr *OtherAddRec =
3350 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3351 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3352 continue;
3353
3354 // Limit max number of arguments to avoid creation of unreasonably big
3355 // SCEVAddRecs with very complex operands.
3356 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3357 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3358 continue;
3359
3360 bool Overflow = false;
3361 Type *Ty = AddRec->getType();
3362 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3363 SmallVector<const SCEV *, 7> AddRecOps;
3364 for (int x = 0, xe = AddRec->getNumOperands() +
3365 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3366 SmallVector<const SCEV *, 7> SumOps;
3367 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3368 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3369 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3370 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3371 z < ze && !Overflow; ++z) {
3372 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3373 uint64_t Coeff;
3374 if (LargerThan64Bits)
3375 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3376 else
3377 Coeff = Coeff1*Coeff2;
3378 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3379 const SCEV *Term1 = AddRec->getOperand(y-z);
3380 const SCEV *Term2 = OtherAddRec->getOperand(z);
3381 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3382 SCEV::FlagAnyWrap, Depth + 1));
3383 }
3384 }
3385 if (SumOps.empty())
3386 SumOps.push_back(getZero(Ty));
3387 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3388 }
3389 if (!Overflow) {
3390 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3391 SCEV::FlagAnyWrap);
3392 if (Ops.size() == 2) return NewAddRec;
3393 Ops[Idx] = NewAddRec;
3394 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3395 OpsModified = true;
3396 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3397 if (!AddRec)
3398 break;
3399 }
3400 }
3401 if (OpsModified)
3402 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3403
3404 // Otherwise couldn't fold anything into this recurrence. Move onto the
3405 // next one.
3406 }
3407
3408 // Okay, it looks like we really DO need a mul expr. Check to see if we
3409 // already have one, otherwise create a new one.
3410 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3411}
3412
3413/// Represents an unsigned remainder expression based on unsigned division.
3414const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3415 const SCEV *RHS) {
3416 assert(getEffectiveSCEVType(LHS->getType()) ==
3417 getEffectiveSCEVType(RHS->getType()) &&
3418 "SCEVURemExpr operand types don't match!");
3419
3420 // Short-circuit easy cases
3421 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3422 // If constant is one, the result is trivial
3423 if (RHSC->getValue()->isOne())
3424 return getZero(LHS->getType()); // X urem 1 --> 0
3425
3426 // If constant is a power of two, fold into a zext(trunc(LHS)).
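// e.g. for i32 %x: %x urem 8 --> zext i3 (trunc i32 %x to i3) to i32,
// keeping just the low log2(8) = 3 bits.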
3427 if (RHSC->getAPInt().isPowerOf2()) {
3428 Type *FullTy = LHS->getType();
3429 Type *TruncTy =
3430 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3431 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3432 }
3433 }
3434
3435 // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3436 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3437 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3438 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3439}
3440
3441/// Get a canonical unsigned division expression, or something simpler if
3442/// possible.
3443const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3444 const SCEV *RHS) {
3445 assert(!LHS->getType()->isPointerTy() &&
3446 "SCEVUDivExpr operand can't be pointer!");
3447 assert(LHS->getType() == RHS->getType() &&
3448 "SCEVUDivExpr operand types don't match!");
3449
3450 FoldingSetNodeID ID;
3451 ID.AddInteger(scUDivExpr);
3452 ID.AddPointer(LHS);
3453 ID.AddPointer(RHS);
3454 void *IP = nullptr;
3455 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3456 return S;
3457
3458 // 0 udiv Y == 0
3459 if (match(LHS, m_scev_Zero()))
3460 return LHS;
3461
3462 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3463 if (RHSC->getValue()->isOne())
3464 return LHS; // X udiv 1 --> x
3465 // If the denominator is zero, the result of the udiv is undefined. Don't
3466 // try to analyze it, because the resolution chosen here may differ from
3467 // the resolution chosen in other parts of the compiler.
3468 if (!RHSC->getValue()->isZero()) {
3469 // Determine if the division can be folded into the operands of
3470 // its operands.
3471 // TODO: Generalize this to non-constants by using known-bits information.
3472 Type *Ty = LHS->getType();
3473 unsigned LZ = RHSC->getAPInt().countl_zero();
3474 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3475 // For non-power-of-two values, effectively round the value up to the
3476 // nearest power of two.
3477 if (!RHSC->getAPInt().isPowerOf2())
3478 ++MaxShiftAmt;
3479 IntegerType *ExtTy =
3480 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3481 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3482 if (const SCEVConstant *Step =
3483 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3484 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3485 const APInt &StepInt = Step->getAPInt();
3486 const APInt &DivInt = RHSC->getAPInt();
3487 if (!StepInt.urem(DivInt) &&
3488 getZeroExtendExpr(AR, ExtTy) ==
3489 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3490 getZeroExtendExpr(Step, ExtTy),
3491 AR->getLoop(), SCEV::FlagAnyWrap)) {
3492 SmallVector<const SCEV *, 4> Operands;
3493 for (const SCEV *Op : AR->operands())
3494 Operands.push_back(getUDivExpr(Op, RHS));
3495 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3496 }
3497 /// Get a canonical UDivExpr for a recurrence.
3498 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3499 // We can currently only fold X%N if X is constant.
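// Worked example of the rewrite above: {5,+,4}<L> /u 4 --> {4,+,4}<L> /u 4,
// since 5 - (5 urem 4) = 4; both sides produce 1, 2, 3, ...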
3500 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3501 if (StartC && !DivInt.urem(StepInt) &&
3502 getZeroExtendExpr(AR, ExtTy) ==
3503 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3504 getZeroExtendExpr(Step, ExtTy),
3505 AR->getLoop(), SCEV::FlagAnyWrap)) {
3506 const APInt &StartInt = StartC->getAPInt();
3507 const APInt &StartRem = StartInt.urem(StepInt);
3508 if (StartRem != 0) {
3509 const SCEV *NewLHS =
3510 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3511 AR->getLoop(), SCEV::FlagNW);
3512 if (LHS != NewLHS) {
3513 LHS = NewLHS;
3514
3515 // Reset the ID to include the new LHS, and check if it is
3516 // already cached.
3517 ID.clear();
3518 ID.AddInteger(scUDivExpr);
3519 ID.AddPointer(LHS);
3520 ID.AddPointer(RHS);
3521 IP = nullptr;
3522 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3523 return S;
3524 }
3525 }
3526 }
3527 }
3528 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3529 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3530 SmallVector<const SCEV *, 4> Operands;
3531 for (const SCEV *Op : M->operands())
3532 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3533 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3534 // Find an operand that's safely divisible.
3535 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3536 const SCEV *Op = M->getOperand(i);
3537 const SCEV *Div = getUDivExpr(Op, RHSC);
3538 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3539 Operands = SmallVector<const SCEV *, 4>(M->operands());
3540 Operands[i] = Div;
3541 return getMulExpr(Operands);
3542 }
3543 }
3544 }
3545
3546 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3547 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3548 if (auto *DivisorConstant =
3549 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3550 bool Overflow = false;
3551 APInt NewRHS =
3552 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3553 if (Overflow) {
3554 return getConstant(RHSC->getType(), 0, false);
3555 }
3556 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3557 }
3558 }
3559
3560 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3561 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3562 SmallVector<const SCEV *, 4> Operands;
3563 for (const SCEV *Op : A->operands())
3564 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3565 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3566 Operands.clear();
3567 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3568 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3569 if (isa<SCEVUDivExpr>(Op) ||
3570 getMulExpr(Op, RHS) != A->getOperand(i))
3571 break;
3572 Operands.push_back(Op);
3573 }
3574 if (Operands.size() == A->getNumOperands())
3575 return getAddExpr(Operands);
3576 }
3577 }
3578
3579 // Fold if both operands are constant.
3580 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3581 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3582 }
3583 }
3584
3585 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
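// Example with C == 4: when %x s> 4 the numerator is %x - 4, which is u< %x;
// otherwise the smax picks 4 and the numerator is 0. Either way the unsigned
// division yields 0.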
3586 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3587 AE && AE->getNumOperands() == 2) {
3588 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3589 const APInt &NegC = VC->getAPInt();
3590 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3591 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3592 if (MME && MME->getNumOperands() == 2 &&
3593 isa<SCEVConstant>(MME->getOperand(0)) &&
3594 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3595 MME->getOperand(1) == RHS)
3596 return getZero(LHS->getType());
3597 }
3598 }
3599 }
3600
3601 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3602 // changes). Make sure we get a new one.
3603 IP = nullptr;
3604 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3605 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3606 LHS, RHS);
3607 UniqueSCEVs.InsertNode(S, IP);
3608 registerUser(S, {LHS, RHS});
3609 return S;
3610}
3611
3612APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3613 APInt A = C1->getAPInt().abs();
3614 APInt B = C2->getAPInt().abs();
3615 uint32_t ABW = A.getBitWidth();
3616 uint32_t BBW = B.getBitWidth();
3617
3618 if (ABW > BBW)
3619 B = B.zext(ABW);
3620 else if (ABW < BBW)
3621 A = A.zext(BBW);
3622
3623 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3624}
3625
3626/// Get a canonical unsigned division expression, or something simpler if
3627/// possible. There is no representation for an exact udiv in SCEV IR, but we
3628/// can attempt to remove factors from the LHS and RHS. We can't do this when
3629/// it's not exact because the udiv may be clearing bits.
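/// For example, ((4 * %x)<nuw>) /u 2 can be rewritten to 2 * %x here, a fold
/// that is only sound because the division is known to be exact.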
3630const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3631 const SCEV *RHS) {
3632 // TODO: we could try to find factors in all sorts of things, but for now we
3633 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3634 // end of this file for inspiration.
3635
3636 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3637 if (!Mul || !Mul->hasNoUnsignedWrap())
3638 return getUDivExpr(LHS, RHS);
3639
3640 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3641 // If the mulexpr multiplies by a constant, then that constant must be the
3642 // first element of the mulexpr.
3643 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3644 if (LHSCst == RHSCst) {
3645 SmallVector<const SCEV *, 2> Operands(Mul->operands().drop_front());
3646 return getMulExpr(Operands);
3647 }
3648
3649 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3650 // that there's a factor provided by one of the other terms. We need to
3651 // check.
3652 APInt Factor = gcd(LHSCst, RHSCst);
3653 if (!Factor.isIntN(1)) {
3654 LHSCst =
3655 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3656 RHSCst =
3657 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3658 SmallVector<const SCEV *, 2> Operands;
3659 Operands.push_back(LHSCst);
3660 append_range(Operands, Mul->operands().drop_front());
3661 LHS = getMulExpr(Operands);
3662 RHS = RHSCst;
3663 Mul = dyn_cast<SCEVMulExpr>(LHS);
3664 if (!Mul)
3665 return getUDivExactExpr(LHS, RHS);
3666 }
3667 }
3668 }
3669
3670 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3671 if (Mul->getOperand(i) == RHS) {
3672 SmallVector<const SCEV *, 2> Operands;
3673 append_range(Operands, Mul->operands().take_front(i));
3674 append_range(Operands, Mul->operands().drop_front(i + 1));
3675 return getMulExpr(Operands);
3676 }
3677 }
3678
3679 return getUDivExpr(LHS, RHS);
3680}
3681
3682/// Get an add recurrence expression for the specified loop. Simplify the
3683/// expression as much as possible.
3684const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3685 const Loop *L,
3686 SCEV::NoWrapFlags Flags) {
3687 SmallVector<const SCEV *, 4> Operands;
3688 Operands.push_back(Start);
3689 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3690 if (StepChrec->getLoop() == L) {
3691 append_range(Operands, StepChrec->operands());
3692 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3693 }
3694
3695 Operands.push_back(Step);
3696 return getAddRecExpr(Operands, L, Flags);
3697}
3698
3699/// Get an add recurrence expression for the specified loop. Simplify the
3700/// expression as much as possible.
3701const SCEV *
3702ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3703 const Loop *L, SCEV::NoWrapFlags Flags) {
3704 if (Operands.size() == 1) return Operands[0];
3705#ifndef NDEBUG
3706 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3707 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3708 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3709 "SCEVAddRecExpr operand types don't match!");
3710 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3711 }
3712 for (const SCEV *Op : Operands)
3714 "SCEVAddRecExpr operand is not available at loop entry!");
3715#endif
3716
3717 if (Operands.back()->isZero()) {
3718 Operands.pop_back();
3719 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3720 }
3721
3722 // It's tempting to call getConstantMaxBackedgeTakenCount here and
3723 // use that information to infer NUW and NSW flags. However, computing a
3724 // BE count requires calling getAddRecExpr, so we may not yet have a
3725 // meaningful BE count at this point (and if we don't, we'd be stuck
3726 // with a SCEVCouldNotCompute as the cached BE count).
3727
3728 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3729
3730 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3731 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3732 const Loop *NestedLoop = NestedAR->getLoop();
3733 if (L->contains(NestedLoop)
3734 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3735 : (!NestedLoop->contains(L) &&
3736 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3737 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3738 Operands[0] = NestedAR->getStart();
3739 // AddRecs require their operands be loop-invariant with respect to their
3740 // loops. Don't perform this transformation if it would break this
3741 // requirement.
3742 bool AllInvariant = all_of(
3743 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3744
3745 if (AllInvariant) {
3746 // Create a recurrence for the outer loop with the same step size.
3747 //
3748 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3749 // inner recurrence has the same property.
3750 SCEV::NoWrapFlags OuterFlags =
3751 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3752
3753 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3754 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3755 return isLoopInvariant(Op, NestedLoop);
3756 });
3757
3758 if (AllInvariant) {
3759 // Ok, both add recurrences are valid after the transformation.
3760 //
3761 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3762 // the outer recurrence has the same property.
3763 SCEV::NoWrapFlags InnerFlags =
3764 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3765 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3766 }
3767 }
3768 // Reset Operands to its original state.
3769 Operands[0] = NestedAR;
3770 }
3771 }
3772
3773 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3774 // already have one, otherwise create a new one.
3775 return getOrCreateAddRecExpr(Operands, L, Flags);
3776}
3777
3778const SCEV *
3779ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3780 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3781 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3782 // getSCEV(Base)->getType() has the same address space as Base->getType()
3783 // because SCEV::getType() preserves the address space.
3784 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3785 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3786 if (NW != GEPNoWrapFlags::none()) {
3787 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3788 // but to do that, we have to ensure that said flag is valid in the entire
3789 // defined scope of the SCEV.
3790 // TODO: non-instructions have global scope. We might be able to prove
3791 // some global scope cases
3792 auto *GEPI = dyn_cast<Instruction>(GEP);
3793 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3794 NW = GEPNoWrapFlags::none();
3795 }
3796
3797 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3798 if (NW.hasNoUnsignedSignedWrap())
3799 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3800 if (NW.hasNoUnsignedWrap())
3801 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3802
3803 Type *CurTy = GEP->getType();
3804 bool FirstIter = true;
3805 SmallVector<const SCEV *, 4> Offsets;
3806 for (const SCEV *IndexExpr : IndexExprs) {
3807 // Compute the (potentially symbolic) offset in bytes for this index.
3808 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3809 // For a struct, add the member offset.
3810 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3811 unsigned FieldNo = Index->getZExtValue();
3812 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3813 Offsets.push_back(FieldOffset);
3814
3815 // Update CurTy to the type of the field at Index.
3816 CurTy = STy->getTypeAtIndex(Index);
3817 } else {
3818 // Update CurTy to its element type.
3819 if (FirstIter) {
3820 assert(isa<PointerType>(CurTy) &&
3821 "The first index of a GEP indexes a pointer");
3822 CurTy = GEP->getSourceElementType();
3823 FirstIter = false;
3824 } else {
3825 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3826 }
3827 // For an array, add the element offset, explicitly scaled.
3828 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3829 // Getelementptr indices are signed.
3830 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3831
3832 // Multiply the index by the element size to compute the element offset.
3833 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3834 Offsets.push_back(LocalOffset);
3835 }
3836 }
3837
3838 // Handle degenerate case of GEP without offsets.
3839 if (Offsets.empty())
3840 return BaseExpr;
3841
3842 // Add the offsets together, assuming nsw if inbounds.
3843 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3844 // Add the base address and the offset. We cannot use the nsw flag, as the
3845 // base address is unsigned. However, if we know that the offset is
3846 // non-negative, we can use nuw.
3847 bool NUW = NW.hasNoUnsignedWrap() ||
3848 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3849 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3850 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3851 assert(BaseExpr->getType() == GEPExpr->getType() &&
3852 "GEP should not change type mid-flight.");
3853 return GEPExpr;
3854}
3855
3856SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3857 ArrayRef<const SCEV *> Ops) {
3858 FoldingSetNodeID ID;
3859 ID.AddInteger(SCEVType);
3860 for (const SCEV *Op : Ops)
3861 ID.AddPointer(Op);
3862 void *IP = nullptr;
3863 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3864}
3865
3866const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3867 const SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3868 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3869}
3870
3871const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3872 SmallVectorImpl<const SCEV *> &Ops) {
3873 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3874 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3875 if (Ops.size() == 1) return Ops[0];
3876#ifndef NDEBUG
3877 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3878 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3879 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3880 "Operand types don't match!");
3881 assert(Ops[0]->getType()->isPointerTy() ==
3882 Ops[i]->getType()->isPointerTy() &&
3883 "min/max should be consistently pointerish");
3884 }
3885#endif
3886
3887 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3888 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3889
3890 const SCEV *Folded = constantFoldAndGroupOps(
3891 *this, LI, DT, Ops,
3892 [&](const APInt &C1, const APInt &C2) {
3893 switch (Kind) {
3894 case scSMaxExpr:
3895 return APIntOps::smax(C1, C2);
3896 case scSMinExpr:
3897 return APIntOps::smin(C1, C2);
3898 case scUMaxExpr:
3899 return APIntOps::umax(C1, C2);
3900 case scUMinExpr:
3901 return APIntOps::umin(C1, C2);
3902 default:
3903 llvm_unreachable("Unknown SCEV min/max opcode");
3904 }
3905 },
3906 [&](const APInt &C) {
3907 // identity
3908 if (IsMax)
3909 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3910 else
3911 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3912 },
3913 [&](const APInt &C) {
3914 // absorber
3915 if (IsMax)
3916 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3917 else
3918 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3919 });
3920 if (Folded)
3921 return Folded;
3922
3923 // Check if we have created the same expression before.
3924 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3925 return S;
3926 }
3927
3928 // Find the first operation of the same kind
3929 unsigned Idx = 0;
3930 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3931 ++Idx;
3932
3933 // Check to see if one of the operands is of the same kind. If so, expand its
3934 // operands onto our operand list, and recurse to simplify.
3935 if (Idx < Ops.size()) {
3936 bool DeletedAny = false;
3937 while (Ops[Idx]->getSCEVType() == Kind) {
3938 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3939 Ops.erase(Ops.begin()+Idx);
3940 append_range(Ops, SMME->operands());
3941 DeletedAny = true;
3942 }
3943
3944 if (DeletedAny)
3945 return getMinMaxExpr(Kind, Ops);
3946 }
3947
3948 // Okay, check to see if the same value occurs in the operand list twice. If
3949 // so, delete one. Since we sorted the list, these values are required to
3950 // be adjacent.
3951 llvm::CmpInst::Predicate GEPred =
3952 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3953 llvm::CmpInst::Predicate LEPred =
3954 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3955 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3956 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3957 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3958 if (Ops[i] == Ops[i + 1] ||
3959 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3960 // X op Y op Y --> X op Y
3961 // X op Y --> X, if we know X, Y are ordered appropriately
3962 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3963 --i;
3964 --e;
3965 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3966 Ops[i + 1])) {
3967 // X op Y --> Y, if we know X, Y are ordered appropriately
3968 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3969 --i;
3970 --e;
3971 }
3972 }
3973
3974 if (Ops.size() == 1) return Ops[0];
3975
3976 assert(!Ops.empty() && "Reduced smax down to nothing!");
3977
3978 // Okay, it looks like we really DO need an expr. Check to see if we
3979 // already have one, otherwise create a new one.
3980 FoldingSetNodeID ID;
3981 ID.AddInteger(Kind);
3982 for (const SCEV *Op : Ops)
3983 ID.AddPointer(Op);
3984 void *IP = nullptr;
3985 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3986 if (ExistingSCEV)
3987 return ExistingSCEV;
3988 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3989 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3990 SCEV *S = new (SCEVAllocator)
3991 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3992
3993 UniqueSCEVs.InsertNode(S, IP);
3994 registerUser(S, Ops);
3995 return S;
3996}
3997
3998namespace {
3999
4000class SCEVSequentialMinMaxDeduplicatingVisitor final
4001 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4002 std::optional<const SCEV *>> {
4003 using RetVal = std::optional<const SCEV *>;
4004 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
4005
4006 ScalarEvolution &SE;
4007 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4008 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4009 SmallPtrSet<const SCEV *, 16> SeenOps;
4010
4011 bool canRecurseInto(SCEVTypes Kind) const {
4012 // We can only recurse into the SCEV expression of the same effective type
4013 // as the type of our root SCEV expression.
4014 return RootKind == Kind || NonSequentialRootKind == Kind;
4015 };
4016
4017 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4019 "Only for min/max expressions.");
4020 SCEVTypes Kind = S->getSCEVType();
4021
4022 if (!canRecurseInto(Kind))
4023 return S;
4024
4025 auto *NAry = cast<SCEVNAryExpr>(S);
4026 SmallVector<const SCEV *> NewOps;
4027 bool Changed = visit(Kind, NAry->operands(), NewOps);
4028
4029 if (!Changed)
4030 return S;
4031 if (NewOps.empty())
4032 return std::nullopt;
4033
4034 return isa<SCEVSequentialMinMaxExpr>(S)
4035 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4036 : SE.getMinMaxExpr(Kind, NewOps);
4037 }
4038
4039 RetVal visit(const SCEV *S) {
4040 // Has the whole operand been seen already?
4041 if (!SeenOps.insert(S).second)
4042 return std::nullopt;
4043 return Base::visit(S);
4044 }
4045
4046public:
4047 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4048 SCEVTypes RootKind)
4049 : SE(SE), RootKind(RootKind),
4050 NonSequentialRootKind(
4051 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4052 RootKind)) {}
4053
4054 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4055 SmallVectorImpl<const SCEV *> &NewOps) {
4056 bool Changed = false;
4057 SmallVector<const SCEV *> Ops;
4058 Ops.reserve(OrigOps.size());
4059
4060 for (const SCEV *Op : OrigOps) {
4061 RetVal NewOp = visit(Op);
4062 if (NewOp != Op)
4063 Changed = true;
4064 if (NewOp)
4065 Ops.emplace_back(*NewOp);
4066 }
4067
4068 if (Changed)
4069 NewOps = std::move(Ops);
4070 return Changed;
4071 }
4072
4073 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4074
4075 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4076
4077 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4078
4079 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4080
4081 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4082
4083 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4084
4085 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4086
4087 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4088
4089 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4090
4091 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4092
4093 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4094 return visitAnyMinMaxExpr(Expr);
4095 }
4096
4097 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4098 return visitAnyMinMaxExpr(Expr);
4099 }
4100
4101 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4102 return visitAnyMinMaxExpr(Expr);
4103 }
4104
4105 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4106 return visitAnyMinMaxExpr(Expr);
4107 }
4108
4109 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4110 return visitAnyMinMaxExpr(Expr);
4111 }
4112
4113 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4114
4115 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4116};
4117
4118} // namespace
4119
4120static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4121 switch (Kind) {
4122 case scConstant:
4123 case scVScale:
4124 case scTruncate:
4125 case scZeroExtend:
4126 case scSignExtend:
4127 case scPtrToInt:
4128 case scAddExpr:
4129 case scMulExpr:
4130 case scUDivExpr:
4131 case scAddRecExpr:
4132 case scUMaxExpr:
4133 case scSMaxExpr:
4134 case scUMinExpr:
4135 case scSMinExpr:
4136 case scUnknown:
4137 // If any operand is poison, the whole expression is poison.
4138 return true;
4139 case scSequentialUMinExpr:
4140 // FIXME: if the *first* operand is poison, the whole expression is poison.
4141 return false; // Pessimistically, say that it does not propagate poison.
4142 case scCouldNotCompute:
4143 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4144 }
4145 llvm_unreachable("Unknown SCEV kind!");
4146}
4147
4148namespace {
4149// The only way poison may be introduced in a SCEV expression is from a
4150// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4151// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4152// introduce poison -- they encode guaranteed, non-speculated knowledge.
4153//
4154// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4155// with the notable exception of umin_seq, where only poison from the first
4156// operand is (unconditionally) propagated.
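// For instance, in (%a umin_seq %b) with %a known to be 0, the result is 0
// regardless of %b, so poison in %b must not be propagated.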
4157struct SCEVPoisonCollector {
4158 bool LookThroughMaybePoisonBlocking;
4159 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4160 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4161 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4162
4163 bool follow(const SCEV *S) {
4164 if (!LookThroughMaybePoisonBlocking &&
4165 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4166 return false;
4167
4168 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4169 if (!isGuaranteedNotToBePoison(SU->getValue()))
4170 MaybePoison.insert(SU);
4171 }
4172 return true;
4173 }
4174 bool isDone() const { return false; }
4175};
4176} // namespace
4177
4178/// Return true if V is poison given that AssumedPoison is already poison.
4179static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4180 // First collect all SCEVs that might result in AssumedPoison to be poison.
4181 // We need to look through potentially poison-blocking operations here,
4182 // because we want to find all SCEVs that *might* result in poison, not only
4183 // those that are *required* to.
4184 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4185 visitAll(AssumedPoison, PC1);
4186
4187 // AssumedPoison is never poison. As the assumption is false, the implication
4188 // is true. Don't bother walking the other SCEV in this case.
4189 if (PC1.MaybePoison.empty())
4190 return true;
4191
4192 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4193 // as well. We cannot look through potentially poison-blocking operations
4194 // here, as their arguments only *may* make the result poison.
4195 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4196 visitAll(S, PC2);
4197
4198 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4199 // it will also make S poison by being part of PC2.MaybePoison.
4200 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4201}
4202
4203void ScalarEvolution::getPoisonGeneratingValues(
4204 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4205 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4206 visitAll(S, PC);
4207 for (const SCEVUnknown *SU : PC.MaybePoison)
4208 Result.insert(SU->getValue());
4209}
4210
4211bool ScalarEvolution::canReuseInstruction(
4212 const SCEV *S, Instruction *I,
4213 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4214 // If the instruction cannot be poison, it's always safe to reuse.
4215 if (isGuaranteedNotToBePoison(I))
4216 return true;
4217
4218 // Otherwise, it is possible that I is more poisonous than S. Collect the
4219 // poison-contributors of S, and then check whether I has any additional
4220 // poison-contributors. Poison that is contributed through poison-generating
4221 // flags is handled by dropping those flags instead.
4222 SmallPtrSet<const Value *, 8> PoisonVals;
4223 getPoisonGeneratingValues(PoisonVals, S);
4224
4225 SmallVector<Value *> Worklist;
4226 SmallPtrSet<Value *, 8> Visited;
4227 Worklist.push_back(I);
4228 while (!Worklist.empty()) {
4229 Value *V = Worklist.pop_back_val();
4230 if (!Visited.insert(V).second)
4231 continue;
4232
4233 // Avoid walking large instruction graphs.
4234 if (Visited.size() > 16)
4235 return false;
4236
4237 // Either the value can't be poison, or the S would also be poison if it
4238 // is.
4239 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4240 continue;
4241
4242 auto *I = dyn_cast<Instruction>(V);
4243 if (!I)
4244 return false;
4245
4246 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4247 // can't replace an arbitrary add with disjoint or, even if we drop the
4248 // flag. We would need to convert the or into an add.
4249 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4250 if (PDI->isDisjoint())
4251 return false;
4252
4253 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4254 // because SCEV currently assumes it can't be poison. Remove this special
4255 // case once we properly model when vscale can be poison.
4256 if (auto *II = dyn_cast<IntrinsicInst>(I);
4257 II && II->getIntrinsicID() == Intrinsic::vscale)
4258 continue;
4259
4260 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4261 return false;
4262
4263 // If the instruction can't create poison, we can recurse to its operands.
4264 if (I->hasPoisonGeneratingAnnotations())
4265 DropPoisonGeneratingInsts.push_back(I);
4266
4267 llvm::append_range(Worklist, I->operands());
4268 }
4269 return true;
4270}
4271
4272const SCEV *
4273ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4274 SmallVectorImpl<const SCEV *> &Ops) {
4275 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4276 "Not a SCEVSequentialMinMaxExpr!");
4277 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4278 if (Ops.size() == 1)
4279 return Ops[0];
4280#ifndef NDEBUG
4281 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4282 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4283 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4284 "Operand types don't match!");
4285 assert(Ops[0]->getType()->isPointerTy() ==
4286 Ops[i]->getType()->isPointerTy() &&
4287 "min/max should be consistently pointerish");
4288 }
4289#endif
4290
4291 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4292 // so we can *NOT* do any kind of sorting of the expressions!
4293
4294 // Check if we have created the same expression before.
4295 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4296 return S;
4297
4298 // FIXME: there are *some* simplifications that we can do here.
4299
4300 // Keep only the first instance of an operand.
4301 {
4302 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4303 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4304 if (Changed)
4305 return getSequentialMinMaxExpr(Kind, Ops);
4306 }
4307
4308 // Check to see if one of the operands is of the same kind. If so, expand its
4309 // operands onto our operand list, and recurse to simplify.
4310 {
4311 unsigned Idx = 0;
4312 bool DeletedAny = false;
4313 while (Idx < Ops.size()) {
4314 if (Ops[Idx]->getSCEVType() != Kind) {
4315 ++Idx;
4316 continue;
4317 }
4318 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4319 Ops.erase(Ops.begin() + Idx);
4320 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4321 SMME->operands().end());
4322 DeletedAny = true;
4323 }
4324
4325 if (DeletedAny)
4326 return getSequentialMinMaxExpr(Kind, Ops);
4327 }
4328
4329 const SCEV *SaturationPoint;
4330 ICmpInst::Predicate Pred;
4331 switch (Kind) {
4332 case scSequentialUMinExpr:
4333 SaturationPoint = getZero(Ops[0]->getType());
4334 Pred = ICmpInst::ICMP_ULE;
4335 break;
4336 default:
4337 llvm_unreachable("Not a sequential min/max type.");
4338 }
4339
4340 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4341 if (!isGuaranteedNotToCauseUB(Ops[i]))
4342 continue;
4343 // We can replace %x umin_seq %y with %x umin %y if either:
4344 // * %y being poison implies %x is also poison.
4345 // * %x cannot be the saturating value (e.g. zero for umin).
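// e.g. %x umin_seq %y --> %x umin %y when %x is known to be non-zero, which
// is the second bullet above.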
4346 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4347 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4348 SaturationPoint)) {
4349 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4350 Ops[i - 1] = getMinMaxExpr(
4351 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4352 SeqOps);
4353 Ops.erase(Ops.begin() + i);
4354 return getSequentialMinMaxExpr(Kind, Ops);
4355 }
4356 // Fold %x umin_seq %y to %x if %x ule %y.
4357 // TODO: We might be able to prove the predicate for a later operand.
4358 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4359 Ops.erase(Ops.begin() + i);
4360 return getSequentialMinMaxExpr(Kind, Ops);
4361 }
4362 }
4363
4364 // Okay, it looks like we really DO need an expr. Check to see if we
4365 // already have one, otherwise create a new one.
4366 FoldingSetNodeID ID;
4367 ID.AddInteger(Kind);
4368 for (const SCEV *Op : Ops)
4369 ID.AddPointer(Op);
4370 void *IP = nullptr;
4371 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4372 if (ExistingSCEV)
4373 return ExistingSCEV;
4374
4375 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4376 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4377 SCEV *S = new (SCEVAllocator)
4378 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4379
4380 UniqueSCEVs.InsertNode(S, IP);
4381 registerUser(S, Ops);
4382 return S;
4383}
4384
4385const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4386 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4387 return getSMaxExpr(Ops);
4388}
4389
4390const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4391 return getMinMaxExpr(scSMaxExpr, Ops);
4392}
4393
4394const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4395 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4396 return getUMaxExpr(Ops);
4397}
4398
4399const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4400 return getMinMaxExpr(scUMaxExpr, Ops);
4401}
4402
4403const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4404 const SCEV *RHS) {
4405 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4406 return getSMinExpr(Ops);
4407}
4408
4409const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4410 return getMinMaxExpr(scSMinExpr, Ops);
4411}
4412
4413const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4414 bool Sequential) {
4415 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4416 return getUMinExpr(Ops, Sequential);
4417}
4418
4419const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4420 bool Sequential) {
4421 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4422 : getMinMaxExpr(scUMinExpr, Ops);
4423}
4424
4425const SCEV *
4426ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4427 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4428 if (Size.isScalable())
4429 Res = getMulExpr(Res, getVScale(IntTy));
4430 return Res;
4431}
4432
4433const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4434 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4435}
4436
4437const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4438 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4439}
4440
4441const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4442 StructType *STy,
4443 unsigned FieldNo) {
4444 // We can bypass creating a target-independent constant expression and then
4445 // folding it back into a ConstantInt. This is just a compile-time
4446 // optimization.
4447 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4448 assert(!SL->getSizeInBits().isScalable() &&
4449 "Cannot get offset for structure containing scalable vector types");
4450 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4451}
4452
4453const SCEV *ScalarEvolution::getUnknown(Value *V) {
4454 // Don't attempt to do anything other than create a SCEVUnknown object
4455 // here. createSCEV only calls getUnknown after checking for all other
4456 // interesting possibilities, and any other code that calls getUnknown
4457 // is doing so in order to hide a value from SCEV canonicalization.
4458
4459 FoldingSetNodeID ID;
4460 ID.AddInteger(scUnknown);
4461 ID.AddPointer(V);
4462 void *IP = nullptr;
4463 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4464 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4465 "Stale SCEVUnknown in uniquing map!");
4466 return S;
4467 }
4468 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4469 FirstUnknown);
4470 FirstUnknown = cast<SCEVUnknown>(S);
4471 UniqueSCEVs.InsertNode(S, IP);
4472 return S;
4473}
4474
4475//===----------------------------------------------------------------------===//
4476// Basic SCEV Analysis and PHI Idiom Recognition Code
4477//
4478
4479/// Test if values of the given type are analyzable within the SCEV
4480/// framework. This primarily includes integer types, and it can optionally
4481/// include pointer types if the ScalarEvolution class has access to
4482/// target-specific information.
4483bool ScalarEvolution::isSCEVable(Type *Ty) const {
4484 // Integers and pointers are always SCEVable.
4485 return Ty->isIntOrPtrTy();
4486}
4487
4488/// Return the size in bits of the specified type, for which isSCEVable must
4489/// return true.
4491 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4492 if (Ty->isPointerTy())
4493 return getDataLayout().getIndexTypeSizeInBits(Ty);
4494 return getDataLayout().getTypeSizeInBits(Ty);
4495}
4496
4497/// Return a type with the same bitwidth as the given type and which represents
4498/// how SCEV will treat the given type, for which isSCEVable must return
4499/// true. For pointer types, this is the pointer index sized integer type.
4501 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4502
4503 if (Ty->isIntegerTy())
4504 return Ty;
4505
4506 // The only other supported type is pointer.
4507 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4508 return getDataLayout().getIndexType(Ty);
4509}
4510
4511Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4512 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4513}
4514
4515bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4516 const SCEV *B) {
4517 /// For a valid use point to exist, the defining scope of one operand
4518 /// must dominate the other.
4519 bool PreciseA, PreciseB;
4520 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4521 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4522 if (!PreciseA || !PreciseB)
4523 // Can't tell.
4524 return false;
4525 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4526 DT.dominates(ScopeB, ScopeA);
4527}
4528
4529const SCEV *ScalarEvolution::getCouldNotCompute() {
4530 return CouldNotCompute.get();
4531}
4532
4533bool ScalarEvolution::checkValidity(const SCEV *S) const {
4534 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4535 auto *SU = dyn_cast<SCEVUnknown>(S);
4536 return SU && SU->getValue() == nullptr;
4537 });
4538
4539 return !ContainsNulls;
4540}
4541
4542bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4543 HasRecMapType::iterator I = HasRecMap.find(S);
4544 if (I != HasRecMap.end())
4545 return I->second;
4546
4547 bool FoundAddRec =
4548 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4549 HasRecMap.insert({S, FoundAddRec});
4550 return FoundAddRec;
4551}
4552
4553/// Return the ValueOffsetPair set for \p S. \p S can be represented
4554/// by the value and offset from any ValueOffsetPair in the set.
4555ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4556 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4557 if (SI == ExprValueMap.end())
4558 return {};
4559 return SI->second.getArrayRef();
4560}
4561
4562/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4563/// cannot be used separately. eraseValueFromMap should be used to remove
4564/// V from ValueExprMap and ExprValueMap at the same time.
4565void ScalarEvolution::eraseValueFromMap(Value *V) {
4566 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4567 if (I != ValueExprMap.end()) {
4568 auto EVIt = ExprValueMap.find(I->second);
4569 bool Removed = EVIt->second.remove(V);
4570 (void) Removed;
4571 assert(Removed && "Value not in ExprValueMap?");
4572 ValueExprMap.erase(I);
4573 }
4574}
4575
4576void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4577 // A recursive query may have already computed the SCEV. It should be
4578 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4579 // inferred nowrap flags.
4580 auto It = ValueExprMap.find_as(V);
4581 if (It == ValueExprMap.end()) {
4582 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4583 ExprValueMap[S].insert(V);
4584 }
4585}
4586
4587/// Return an existing SCEV if it exists, otherwise analyze the expression and
4588/// create a new one.
4590 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4591
4592 if (const SCEV *S = getExistingSCEV(V))
4593 return S;
4594 return createSCEVIter(V);
4595}
4596
4598 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4599
4600 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4601 if (I != ValueExprMap.end()) {
4602 const SCEV *S = I->second;
4603 assert(checkValidity(S) &&
4604 "existing SCEV has not been properly invalidated");
4605 return S;
4606 }
4607 return nullptr;
4608}
4609
4610/// Return a SCEV corresponding to -V = -1*V
4611const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4612 SCEV::NoWrapFlags Flags) {
4613 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4614 return getConstant(
4615 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4616
4617 Type *Ty = V->getType();
4618 Ty = getEffectiveSCEVType(Ty);
4619 return getMulExpr(V, getMinusOne(Ty), Flags);
4620}
4621
4622/// If Expr computes ~A, return A else return nullptr
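/// For reference, getNotSCEV builds ~A as -1 - A, which appears here in the
/// canonical form (-1 + (-1 * A)); that is the shape matched below.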
4623static const SCEV *MatchNotExpr(const SCEV *Expr) {
4624 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4625 if (!Add || Add->getNumOperands() != 2 ||
4626 !Add->getOperand(0)->isAllOnesValue())
4627 return nullptr;
4628
4629 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4630 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4631 !AddRHS->getOperand(0)->isAllOnesValue())
4632 return nullptr;
4633
4634 return AddRHS->getOperand(1);
4635}
4636
4637/// Return a SCEV corresponding to ~V = -1-V
4639 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4640
4641 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4642 return getConstant(
4643 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4644
4645 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4646 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4647 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4648 SmallVector<const SCEV *, 2> MatchedOperands;
4649 for (const SCEV *Operand : MME->operands()) {
4650 const SCEV *Matched = MatchNotExpr(Operand);
4651 if (!Matched)
4652 return (const SCEV *)nullptr;
4653 MatchedOperands.push_back(Matched);
4654 }
4655 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4656 MatchedOperands);
4657 };
4658 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4659 return Replaced;
4660 }
4661
4662 Type *Ty = V->getType();
4663 Ty = getEffectiveSCEVType(Ty);
4664 return getMinusSCEV(getMinusOne(Ty), V);
4665}
4666
4667const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4668 assert(P->getType()->isPointerTy());
4669
4670 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4671 // The base of an AddRec is the first operand.
4672 SmallVector<const SCEV *> Ops{AddRec->operands()};
4673 Ops[0] = removePointerBase(Ops[0]);
4674 // Don't try to transfer nowrap flags for now. We could in some cases
4675 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4676 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4677 }
4678 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4679 // The base of an Add is the pointer operand.
4680 SmallVector<const SCEV *> Ops{Add->operands()};
4681 const SCEV **PtrOp = nullptr;
4682 for (const SCEV *&AddOp : Ops) {
4683 if (AddOp->getType()->isPointerTy()) {
4684 assert(!PtrOp && "Cannot have multiple pointer ops");
4685 PtrOp = &AddOp;
4686 }
4687 }
4688 *PtrOp = removePointerBase(*PtrOp);
4689 // Don't try to transfer nowrap flags for now. We could in some cases
4690 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4691 return getAddExpr(Ops);
4692 }
4693 // Any other expression must be a pointer base.
4694 return getZero(P->getType());
4695}
4696
4697const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4698 SCEV::NoWrapFlags Flags,
4699 unsigned Depth) {
4700 // Fast path: X - X --> 0.
4701 if (LHS == RHS)
4702 return getZero(LHS->getType());
4703
4704 // If we subtract two pointers with different pointer bases, bail.
4705 // Eventually, we're going to add an assertion to getMulExpr that we
4706 // can't multiply by a pointer.
4707 if (RHS->getType()->isPointerTy()) {
4708 if (!LHS->getType()->isPointerTy() ||
4709 getPointerBase(LHS) != getPointerBase(RHS))
4710 return getCouldNotCompute();
4711 LHS = removePointerBase(LHS);
4712 RHS = removePointerBase(RHS);
4713 }
4714
4715 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4716 // makes it so that we cannot make much use of NUW.
4717 auto AddFlags = SCEV::FlagAnyWrap;
4718 const bool RHSIsNotMinSigned =
4719 !getSignedRangeMin(RHS).isMinSignedValue();
4720 if (hasFlags(Flags, SCEV::FlagNSW)) {
4721 // Let M be the minimum representable signed value. Then (-1)*RHS
4722 // signed-wraps if and only if RHS is M. That can happen even for
4723 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4724 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4725 // (-1)*RHS, we need to prove that RHS != M.
4726 //
4727 // If LHS is non-negative and we know that LHS - RHS does not
4728 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4729 // either by proving that RHS > M or that LHS >= 0.
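// Concrete i8 example of the caveat above: with RHS == -128, (-1) * (-128)
// signed-wraps even though -1 - (-128) == 127 does not.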
4730 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4731 AddFlags = SCEV::FlagNSW;
4732 }
4733 }
4734
4735 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4736 // RHS is NSW and LHS >= 0.
4737 //
4738 // The difficulty here is that the NSW flag may have been proven
4739 // relative to a loop that is to be found in a recurrence in LHS and
4740 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4741 // larger scope than intended.
4742 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4743
4744 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4745}
4746
4747const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4748 unsigned Depth) {
4749 Type *SrcTy = V->getType();
4750 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4751 "Cannot truncate or zero extend with non-integer arguments!");
4752 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4753 return V; // No conversion
4754 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4755 return getTruncateExpr(V, Ty, Depth);
4756 return getZeroExtendExpr(V, Ty, Depth);
4757}
4758
4759const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4760 unsigned Depth) {
4761 Type *SrcTy = V->getType();
4762 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4763 "Cannot truncate or zero extend with non-integer arguments!");
4764 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4765 return V; // No conversion
4766 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4767 return getTruncateExpr(V, Ty, Depth);
4768 return getSignExtendExpr(V, Ty, Depth);
4769}
4770
4771const SCEV *
4772ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4773 Type *SrcTy = V->getType();
4774 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4775 "Cannot noop or zero extend with non-integer arguments!");
4777 "getNoopOrZeroExtend cannot truncate!");
4778 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4779 return V; // No conversion
4780 return getZeroExtendExpr(V, Ty);
4781}
4782
4783const SCEV *
4784ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4785 Type *SrcTy = V->getType();
4786 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4787 "Cannot noop or sign extend with non-integer arguments!");
4789 "getNoopOrSignExtend cannot truncate!");
4790 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4791 return V; // No conversion
4792 return getSignExtendExpr(V, Ty);
4793}
4794
4795const SCEV *
4796ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4797 Type *SrcTy = V->getType();
4798 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4799 "Cannot noop or any extend with non-integer arguments!");
4801 "getNoopOrAnyExtend cannot truncate!");
4802 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4803 return V; // No conversion
4804 return getAnyExtendExpr(V, Ty);
4805}
4806
4807const SCEV *
4808ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4809 Type *SrcTy = V->getType();
4810 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4811 "Cannot truncate or noop with non-integer arguments!");
4813 "getTruncateOrNoop cannot extend!");
4814 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4815 return V; // No conversion
4816 return getTruncateExpr(V, Ty);
4817}
4818
4819const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4820 const SCEV *RHS) {
4821 const SCEV *PromotedLHS = LHS;
4822 const SCEV *PromotedRHS = RHS;
4823
4824 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4825 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4826 else
4827 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4828
4829 return getUMaxExpr(PromotedLHS, PromotedRHS);
4830}
4831
4832const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4833 const SCEV *RHS,
4834 bool Sequential) {
4835 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4836 return getUMinFromMismatchedTypes(Ops, Sequential);
4837}
4838
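// Usage sketch (illustrative operands): given %a of type i32 and %b of type
// i64, getUMinFromMismatchedTypes({%a, %b}) first widens %a and then returns
//   ((zext i32 %a to i64) umin %b)
// so that all operands share the widest type before the umin is built.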
4839const SCEV *
4840ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4841 bool Sequential) {
4842 assert(!Ops.empty() && "At least one operand must be!");
4843 // Trivial case.
4844 if (Ops.size() == 1)
4845 return Ops[0];
4846
4847 // Find the max type first.
4848 Type *MaxType = nullptr;
4849 for (const auto *S : Ops)
4850 if (MaxType)
4851 MaxType = getWiderType(MaxType, S->getType());
4852 else
4853 MaxType = S->getType();
4854 assert(MaxType && "Failed to find maximum type!");
4855
4856 // Extend all ops to max type.
4857 SmallVector<const SCEV *, 2> PromotedOps;
4858 for (const auto *S : Ops)
4859 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4860
4861 // Generate umin.
4862 return getUMinExpr(PromotedOps, Sequential);
4863}
4864
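// For example (illustrative): for the SCEV ({(%base + %off),+,4}<%loop>),
// where %base is the only pointer-typed operand, the loop below follows the
// AddRec start and then the pointer operand of the add, returning %base.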
4865const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4866 // A pointer operand may evaluate to a nonpointer expression, such as null.
4867 if (!V->getType()->isPointerTy())
4868 return V;
4869
4870 while (true) {
4871 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4872 V = AddRec->getStart();
4873 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4874 const SCEV *PtrOp = nullptr;
4875 for (const SCEV *AddOp : Add->operands()) {
4876 if (AddOp->getType()->isPointerTy()) {
4877 assert(!PtrOp && "Cannot have multiple pointer ops");
4878 PtrOp = AddOp;
4879 }
4880 }
4881 assert(PtrOp && "Must have pointer op");
4882 V = PtrOp;
4883 } else // Not something we can look further into.
4884 return V;
4885 }
4886}
4887
4888/// Push users of the given Instruction onto the given Worklist.
4889static void PushDefUseChildren(Instruction *I,
4890 SmallVectorImpl<Instruction *> &Worklist,
4891 SmallPtrSetImpl<Instruction *> &Visited) {
4892 // Push the def-use children onto the Worklist stack.
4893 for (User *U : I->users()) {
4894 auto *UserInsn = cast<Instruction>(U);
4895 if (Visited.insert(UserInsn).second)
4896 Worklist.push_back(UserInsn);
4897 }
4898}
4899
4900namespace {
4901
4902/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4903/// use its start expression. For an AddRec of a different loop, use the
4904/// AddRec itself if IgnoreOtherLoops is true; otherwise the rewrite cannot be
4905/// done. If S contains a SCEVUnknown that is not invariant in L, the rewrite
4906/// cannot be done either.
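/// For example (illustrative): with L = %loop, rewriting
/// ({%start,+,%step}<%loop> + %n) yields (%start + %n); an AddRec of another
/// loop is either kept as-is (IgnoreOtherLoops) or makes the rewrite fail.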
4907class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4908public:
4909 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4910 bool IgnoreOtherLoops = true) {
4911 SCEVInitRewriter Rewriter(L, SE);
4912 const SCEV *Result = Rewriter.visit(S);
4913 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4914 return SE.getCouldNotCompute();
4915 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4916 ? SE.getCouldNotCompute()
4917 : Result;
4918 }
4919
4920 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4921 if (!SE.isLoopInvariant(Expr, L))
4922 SeenLoopVariantSCEVUnknown = true;
4923 return Expr;
4924 }
4925
4926 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4927 // Only re-write AddRecExprs for this loop.
4928 if (Expr->getLoop() == L)
4929 return Expr->getStart();
4930 SeenOtherLoops = true;
4931 return Expr;
4932 }
4933
4934 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4935
4936 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4937
4938private:
4939 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4940 : SCEVRewriteVisitor(SE), L(L) {}
4941
4942 const Loop *L;
4943 bool SeenLoopVariantSCEVUnknown = false;
4944 bool SeenOtherLoops = false;
4945};
4946
4947/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4948/// use its post-increment expression. For an AddRec of a different loop, use
4949/// the AddRec itself.
4950/// If S contains a SCEVUnknown that is not invariant in L, no rewrite is done.
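/// For example (illustrative): rewriting {%start,+,%step}<%loop> for L = %loop
/// yields the post-increment value {(%start + %step),+,%step}<%loop>.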
4951class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4952public:
4953 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4954 SCEVPostIncRewriter Rewriter(L, SE);
4955 const SCEV *Result = Rewriter.visit(S);
4956 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4957 ? SE.getCouldNotCompute()
4958 : Result;
4959 }
4960
4961 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4962 if (!SE.isLoopInvariant(Expr, L))
4963 SeenLoopVariantSCEVUnknown = true;
4964 return Expr;
4965 }
4966
4967 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4968 // Only re-write AddRecExprs for this loop.
4969 if (Expr->getLoop() == L)
4970 return Expr->getPostIncExpr(SE);
4971 SeenOtherLoops = true;
4972 return Expr;
4973 }
4974
4975 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4976
4977 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4978
4979private:
4980 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4981 : SCEVRewriteVisitor(SE), L(L) {}
4982
4983 const Loop *L;
4984 bool SeenLoopVariantSCEVUnknown = false;
4985 bool SeenOtherLoops = false;
4986};
4987
4988/// This class evaluates the compare condition by matching it against the
4989/// condition of the loop latch. If there is a match, we assume a true value
4990/// for the condition while building SCEV nodes.
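/// For example (illustrative): if the latch ends in
///   br i1 %c, label %header, label %exit
/// then a (select %c, %a, %b) inside the loop can be folded to %a while
/// building SCEV nodes, since %c must have been true on every iteration that
/// reached the backedge.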
4991class SCEVBackedgeConditionFolder
4992 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4993public:
4994 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4995 ScalarEvolution &SE) {
4996 bool IsPosBECond = false;
4997 Value *BECond = nullptr;
4998 if (BasicBlock *Latch = L->getLoopLatch()) {
4999 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5000 if (BI && BI->isConditional()) {
5001 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5002 "Both outgoing branches should not target same header!");
5003 BECond = BI->getCondition();
5004 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5005 } else {
5006 return S;
5007 }
5008 }
5009 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5010 return Rewriter.visit(S);
5011 }
5012
5013 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5014 const SCEV *Result = Expr;
5015 bool InvariantF = SE.isLoopInvariant(Expr, L);
5016
5017 if (!InvariantF) {
5018 Instruction *I = cast<Instruction>(Expr->getValue());
5019 switch (I->getOpcode()) {
5020 case Instruction::Select: {
5021 SelectInst *SI = cast<SelectInst>(I);
5022 std::optional<const SCEV *> Res =
5023 compareWithBackedgeCondition(SI->getCondition());
5024 if (Res) {
5025 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5026 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5027 }
5028 break;
5029 }
5030 default: {
5031 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5032 if (Res)
5033 Result = *Res;
5034 break;
5035 }
5036 }
5037 }
5038 return Result;
5039 }
5040
5041private:
5042 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5043 bool IsPosBECond, ScalarEvolution &SE)
5044 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5045 IsPositiveBECond(IsPosBECond) {}
5046
5047 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5048
5049 const Loop *L;
5050 /// Loop back condition.
5051 Value *BackedgeCond = nullptr;
5052 /// Set to true if loop back is on positive branch condition.
5053 bool IsPositiveBECond;
5054};
5055
5056std::optional<const SCEV *>
5057SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5058
5059 // If value matches the backedge condition for loop latch,
5060 // then return a constant evolution node based on loopback
5061 // branch taken.
5062 if (BackedgeCond == IC)
5063 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5064 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5065 return std::nullopt;
5066}
5067
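/// Rewrite each AddRec of loop L back by one iteration, i.e. turn {A,+,B}<L>
/// into {A-B,+,B}<L>; loop-variant unknowns, non-affine AddRecs, or AddRecs of
/// other loops invalidate the rewrite (summary sketch of the class below).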
5068class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5069public:
5070 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5071 ScalarEvolution &SE) {
5072 SCEVShiftRewriter Rewriter(L, SE);
5073 const SCEV *Result = Rewriter.visit(S);
5074 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5075 }
5076
5077 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5078 // Only allow AddRecExprs for this loop.
5079 if (!SE.isLoopInvariant(Expr, L))
5080 Valid = false;
5081 return Expr;
5082 }
5083
5084 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5085 if (Expr->getLoop() == L && Expr->isAffine())
5086 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5087 Valid = false;
5088 return Expr;
5089 }
5090
5091 bool isValid() { return Valid; }
5092
5093private:
5094 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5095 : SCEVRewriteVisitor(SE), L(L) {}
5096
5097 const Loop *L;
5098 bool Valid = true;
5099};
5100
5101} // end anonymous namespace
5102
5103SCEV::NoWrapFlags
5104ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5105 if (!AR->isAffine())
5106 return SCEV::FlagAnyWrap;
5107
5108 using OBO = OverflowingBinaryOperator;
5109
5110 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5111
5112 if (!AR->hasNoSelfWrap()) {
5113 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5114 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5115 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5116 const APInt &BECountAP = BECountMax->getAPInt();
5117 unsigned NoOverflowBitWidth =
5118 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5119 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5120 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5121 }
5122 }
5123
5124 if (!AR->hasNoSignedWrap()) {
5125 ConstantRange AddRecRange = getSignedRange(AR);
5126 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5127
5128 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5129 Instruction::Add, IncRange, OBO::NoSignedWrap);
5130 if (NSWRegion.contains(AddRecRange))
5131 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5132 }
5133
5134 if (!AR->hasNoUnsignedWrap()) {
5135 ConstantRange AddRecRange = getUnsignedRange(AR);
5136 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5137
5138 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5139 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5140 if (NUWRegion.contains(AddRecRange))
5141 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5142 }
5143
5144 return Result;
5145}
5146
5147SCEV::NoWrapFlags
5148ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5149 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5150
5151 if (AR->hasNoSignedWrap())
5152 return Result;
5153
5154 if (!AR->isAffine())
5155 return Result;
5156
5157 // This function can be expensive, only try to prove NSW once per AddRec.
5158 if (!SignedWrapViaInductionTried.insert(AR).second)
5159 return Result;
5160
5161 const SCEV *Step = AR->getStepRecurrence(*this);
5162 const Loop *L = AR->getLoop();
5163
5164 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5165 // Note that this serves two purposes: It filters out loops that are
5166 // simply not analyzable, and it covers the case where this code is
5167 // being called from within backedge-taken count analysis, such that
5168 // attempting to ask for the backedge-taken count would likely result
5169 // in infinite recursion. In the latter case, the analysis code will
5170 // cope with a conservative value, and it will take care to purge
5171 // that value once it has finished.
5172 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5173
5174 // Normally, in the cases we can prove no-overflow via a
5175 // backedge guarding condition, we can also compute a backedge
5176 // taken count for the loop. The exceptions are assumptions and
5177 // guards present in the loop -- SCEV is not great at exploiting
5178 // these to compute max backedge taken counts, but can still use
5179 // these to prove lack of overflow. Use this fact to avoid
5180 // doing extra work that may not pay off.
5181
5182 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5183 AC.assumptions().empty())
5184 return Result;
5185
5186 // If the backedge is guarded by a comparison with the pre-inc value the
5187 // addrec is safe. Also, if the entry is guarded by a comparison with the
5188 // start value and the backedge is guarded by a comparison with the post-inc
5189 // value, the addrec is safe.
5190 ICmpInst::Predicate Pred;
5191 const SCEV *OverflowLimit =
5192 getSignedOverflowLimitForStep(Step, &Pred, this);
5193 if (OverflowLimit &&
5194 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5195 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5196 Result = setFlags(Result, SCEV::FlagNSW);
5197 }
5198 return Result;
5199}
5200SCEV::NoWrapFlags
5201ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5202 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5203
5204 if (AR->hasNoUnsignedWrap())
5205 return Result;
5206
5207 if (!AR->isAffine())
5208 return Result;
5209
5210 // This function can be expensive, only try to prove NUW once per AddRec.
5211 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5212 return Result;
5213
5214 const SCEV *Step = AR->getStepRecurrence(*this);
5215 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5216 const Loop *L = AR->getLoop();
5217
5218 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5219 // Note that this serves two purposes: It filters out loops that are
5220 // simply not analyzable, and it covers the case where this code is
5221 // being called from within backedge-taken count analysis, such that
5222 // attempting to ask for the backedge-taken count would likely result
5223 // in infinite recursion. In the latter case, the analysis code will
5224 // cope with a conservative value, and it will take care to purge
5225 // that value once it has finished.
5226 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5227
5228 // Normally, in the cases we can prove no-overflow via a
5229 // backedge guarding condition, we can also compute a backedge
5230 // taken count for the loop. The exceptions are assumptions and
5231 // guards present in the loop -- SCEV is not great at exploiting
5232 // these to compute max backedge taken counts, but can still use
5233 // these to prove lack of overflow. Use this fact to avoid
5234 // doing extra work that may not pay off.
5235
5236 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5237 AC.assumptions().empty())
5238 return Result;
5239
5240 // If the backedge is guarded by a comparison with the pre-inc value the
5241 // addrec is safe. Also, if the entry is guarded by a comparison with the
5242 // start value and the backedge is guarded by a comparison with the post-inc
5243 // value, the addrec is safe.
5244 if (isKnownPositive(Step)) {
5245 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5246 getUnsignedRangeMax(Step));
5247 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5248 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5249 Result = setFlags(Result, SCEV::FlagNUW);
5250 }
5251 }
5252
5253 return Result;
5254}
5255
5256namespace {
5257
5258/// Represents an abstract binary operation. This may exist as a
5259/// normal instruction or constant expression, or may have been
5260/// derived from an expression tree.
5261struct BinaryOp {
5262 unsigned Opcode;
5263 Value *LHS;
5264 Value *RHS;
5265 bool IsNSW = false;
5266 bool IsNUW = false;
5267
5268 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5269 /// constant expression.
5270 Operator *Op = nullptr;
5271
5272 explicit BinaryOp(Operator *Op)
5273 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5274 Op(Op) {
5275 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5276 IsNSW = OBO->hasNoSignedWrap();
5277 IsNUW = OBO->hasNoUnsignedWrap();
5278 }
5279 }
5280
5281 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5282 bool IsNUW = false)
5283 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5284};
5285
5286} // end anonymous namespace
5287
5288/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
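/// For example (illustrative): "or disjoint i32 %a, %b" maps to
/// BinaryOp(Add, %a, %b, /*IsNSW=*/true, /*IsNUW=*/true), and
/// "lshr i32 %x, 4" maps to BinaryOp(UDiv, %x, 16).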
5289static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5290 AssumptionCache &AC,
5291 const DominatorTree &DT,
5292 const Instruction *CxtI) {
5293 auto *Op = dyn_cast<Operator>(V);
5294 if (!Op)
5295 return std::nullopt;
5296
5297 // Implementation detail: all the cleverness here should happen without
5298 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5299 // SCEV expressions when possible, and we should not break that.
5300
5301 switch (Op->getOpcode()) {
5302 case Instruction::Add:
5303 case Instruction::Sub:
5304 case Instruction::Mul:
5305 case Instruction::UDiv:
5306 case Instruction::URem:
5307 case Instruction::And:
5308 case Instruction::AShr:
5309 case Instruction::Shl:
5310 return BinaryOp(Op);
5311
5312 case Instruction::Or: {
5313 // Convert or disjoint into add nuw nsw.
5314 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5315 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5316 /*IsNSW=*/true, /*IsNUW=*/true);
5317 return BinaryOp(Op);
5318 }
5319
5320 case Instruction::Xor:
5321 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5322 // If the RHS of the xor is a signmask, then this is just an add.
5323 // Instcombine turns add of signmask into xor as a strength reduction step.
5324 if (RHSC->getValue().isSignMask())
5325 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5326 // Binary `xor` is a bit-wise `add`.
5327 if (V->getType()->isIntegerTy(1))
5328 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5329 return BinaryOp(Op);
5330
5331 case Instruction::LShr:
5332 // Turn logical shift right of a constant into an unsigned divide.
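 // e.g. (lshr i32 %x, 4) is modeled as (%x /u 16), assuming the shift amount
 // is a constant that is smaller than the bit width.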
5333 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5334 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5335
5336 // If the shift count is not less than the bitwidth, the result of
5337 // the shift is undefined. Don't try to analyze it, because the
5338 // resolution chosen here may differ from the resolution chosen in
5339 // other parts of the compiler.
5340 if (SA->getValue().ult(BitWidth)) {
5341 Constant *X =
5342 ConstantInt::get(SA->getContext(),
5343 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5344 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5345 }
5346 }
5347 return BinaryOp(Op);
5348
5349 case Instruction::ExtractValue: {
5350 auto *EVI = cast<ExtractValueInst>(Op);
5351 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5352 break;
5353
5354 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5355 if (!WO)
5356 break;
5357
5358 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5359 bool Signed = WO->isSigned();
5360 // TODO: Should add nuw/nsw flags for mul as well.
5361 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5362 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5363
5364 // Now that we know that all uses of the arithmetic-result component of
5365 // CI are guarded by the overflow check, we can go ahead and pretend
5366 // that the arithmetic is non-overflowing.
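 // Illustrative example: for
 //   %s = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b)
 // whose arithmetic result is only used under a check that the overflow bit
 // is false, the sum can be modeled as an add with IsNSW set (IsNUW for the
 // unsigned variants).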
5367 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5368 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5369 }
5370
5371 default:
5372 break;
5373 }
5374
5375 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5376 // semantics as a Sub, return a binary sub expression.
5377 if (auto *II = dyn_cast<IntrinsicInst>(V))
5378 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5379 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5380
5381 return std::nullopt;
5382}
5383
5384/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5385/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5386/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5387/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5388/// follows one of the following patterns:
5389/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5390/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5391/// If the SCEV expression of \p Op conforms with one of the expected patterns
5392/// we return the type of the truncation operation, and indicate whether the
5393/// truncated type should be treated as signed/unsigned by setting
5394/// \p Signed to true/false, respectively.
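/// For example (illustrative): with %SymbolicPHI of type i64 and
/// Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64), the function
/// returns the i32 truncation type and sets \p Signed to true.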
5395static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5396 bool &Signed, ScalarEvolution &SE) {
5397 // The case where Op == SymbolicPHI (that is, with no type conversions on
5398 // the way) is handled by the regular add recurrence creating logic and
5399 // would have already been triggered in createAddRecFromPHI. Reaching it here
5400 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5401 // because one of the other operands of the SCEVAddExpr updating this PHI is
5402 // not invariant).
5403 //
5404 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5405 // this case predicates that allow us to prove that Op == SymbolicPHI will
5406 // be added.
5407 if (Op == SymbolicPHI)
5408 return nullptr;
5409
5410 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5411 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5412 if (SourceBits != NewBits)
5413 return nullptr;
5414
5415 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5416 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5417 if (!SExt && !ZExt)
5418 return nullptr;
5419 const SCEVTruncateExpr *Trunc =
5420 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5421 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5422 if (!Trunc)
5423 return nullptr;
5424 const SCEV *X = Trunc->getOperand();
5425 if (X != SymbolicPHI)
5426 return nullptr;
5427 Signed = SExt != nullptr;
5428 return Trunc->getType();
5429}
5430
5431static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5432 if (!PN->getType()->isIntegerTy())
5433 return nullptr;
5434 const Loop *L = LI.getLoopFor(PN->getParent());
5435 if (!L || L->getHeader() != PN->getParent())
5436 return nullptr;
5437 return L;
5438}
5439
5440// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5441// computation that updates the phi follows the following pattern:
5442// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5443 // which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5444// If so, try to see if it can be rewritten as an AddRecExpr under some
5445// Predicates. If successful, return them as a pair. Also cache the results
5446// of the analysis.
5447//
5448// Example usage scenario:
5449// Say the Rewriter is called for the following SCEV:
5450// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5451// where:
5452// %X = phi i64 (%Start, %BEValue)
5453// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5454// and call this function with %SymbolicPHI = %X.
5455//
5456// The analysis will find that the value coming around the backedge has
5457// the following SCEV:
5458// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5459// Upon concluding that this matches the desired pattern, the function
5460// will return the pair {NewAddRec, SmallPredsVec} where:
5461// NewAddRec = {%Start,+,%Step}
5462// SmallPredsVec = {P1, P2, P3} as follows:
5463// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5464// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5465// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5466// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5467// under the predicates {P1,P2,P3}.
5468// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5469// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5470//
5471// TODO's:
5472//
5473// 1) Extend the Induction descriptor to also support inductions that involve
5474// casts: When needed (namely, when we are called in the context of the
5475// vectorizer induction analysis), a Set of cast instructions will be
5476// populated by this method, and provided back to isInductionPHI. This is
5477// needed to allow the vectorizer to properly record them to be ignored by
5478// the cost model and to avoid vectorizing them (otherwise these casts,
5479// which are redundant under the runtime overflow checks, will be
5480// vectorized, which can be costly).
5481//
5482// 2) Support additional induction/PHISCEV patterns: We also want to support
5483// inductions where the sext-trunc / zext-trunc operations (partly) occur
5484// after the induction update operation (the induction increment):
5485//
5486// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5487// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5488//
5489// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5490// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5491//
5492// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5493std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5494ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5495 SmallVector<const SCEVPredicate *, 3> Predicates;
5496
5497 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5498 // return an AddRec expression under some predicate.
5499
5500 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5501 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5502 assert(L && "Expecting an integer loop header phi");
5503
5504 // The loop may have multiple entrances or multiple exits; we can analyze
5505 // this phi as an addrec if it has a unique entry value and a unique
5506 // backedge value.
5507 Value *BEValueV = nullptr, *StartValueV = nullptr;
5508 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5509 Value *V = PN->getIncomingValue(i);
5510 if (L->contains(PN->getIncomingBlock(i))) {
5511 if (!BEValueV) {
5512 BEValueV = V;
5513 } else if (BEValueV != V) {
5514 BEValueV = nullptr;
5515 break;
5516 }
5517 } else if (!StartValueV) {
5518 StartValueV = V;
5519 } else if (StartValueV != V) {
5520 StartValueV = nullptr;
5521 break;
5522 }
5523 }
5524 if (!BEValueV || !StartValueV)
5525 return std::nullopt;
5526
5527 const SCEV *BEValue = getSCEV(BEValueV);
5528
5529 // If the value coming around the backedge is an add with the symbolic
5530 // value we just inserted, possibly with casts that we can ignore under
5531 // an appropriate runtime guard, then we found a simple induction variable!
5532 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5533 if (!Add)
5534 return std::nullopt;
5535
5536 // If there is a single occurrence of the symbolic value, possibly
5537 // casted, replace it with a recurrence.
5538 unsigned FoundIndex = Add->getNumOperands();
5539 Type *TruncTy = nullptr;
5540 bool Signed;
5541 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5542 if ((TruncTy =
5543 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5544 if (FoundIndex == e) {
5545 FoundIndex = i;
5546 break;
5547 }
5548
5549 if (FoundIndex == Add->getNumOperands())
5550 return std::nullopt;
5551
5552 // Create an add with everything but the specified operand.
5553 SmallVector<const SCEV *, 8> Ops;
5554 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5555 if (i != FoundIndex)
5556 Ops.push_back(Add->getOperand(i));
5557 const SCEV *Accum = getAddExpr(Ops);
5558
5559 // The runtime checks will not be valid if the step amount is
5560 // varying inside the loop.
5561 if (!isLoopInvariant(Accum, L))
5562 return std::nullopt;
5563
5564 // *** Part2: Create the predicates
5565
5566 // Analysis was successful: we have a phi-with-cast pattern for which we
5567 // can return an AddRec expression under the following predicates:
5568 //
5569 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5570 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5571 // P2: An Equal predicate that guarantees that
5572 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5573 // P3: An Equal predicate that guarantees that
5574 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5575 //
5576 // As we next prove, the above predicates guarantee that:
5577 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5578 //
5579 //
5580 // More formally, we want to prove that:
5581 // Expr(i+1) = Start + (i+1) * Accum
5582 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5583 //
5584 // Given that:
5585 // 1) Expr(0) = Start
5586 // 2) Expr(1) = Start + Accum
5587 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5588 // 3) Induction hypothesis (step i):
5589 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5590 //
5591 // Proof:
5592 // Expr(i+1) =
5593 // = Start + (i+1)*Accum
5594 // = (Start + i*Accum) + Accum
5595 // = Expr(i) + Accum
5596 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5597 // :: from step i
5598 //
5599 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5600 //
5601 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5602 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5603 // + Accum :: from P3
5604 //
5605 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5606 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5607 //
5608 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5609 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5610 //
5611 // By induction, the same applies to all iterations 1<=i<n:
5612 //
5613
5614 // Create a truncated addrec for which we will add a no overflow check (P1).
5615 const SCEV *StartVal = getSCEV(StartValueV);
5616 const SCEV *PHISCEV =
5617 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5618 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5619
5620 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5621 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5622 // will be constant.
5623 //
5624 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5625 // add P1.
5626 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5627 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5628 Signed ? SCEVWrapPredicate::IncrementNSSW
5629 : SCEVWrapPredicate::IncrementNUSW;
5630 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5631 Predicates.push_back(AddRecPred);
5632 }
5633
5634 // Create the Equal Predicates P2,P3:
5635
5636 // It is possible that the predicates P2 and/or P3 are computable at
5637 // compile time due to StartVal and/or Accum being constants.
5638 // If either one is, then we can check that now and escape if either P2
5639 // or P3 is false.
5640
5641 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5642 // for each of StartVal and Accum
5643 auto getExtendedExpr = [&](const SCEV *Expr,
5644 bool CreateSignExtend) -> const SCEV * {
5645 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5646 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5647 const SCEV *ExtendedExpr =
5648 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5649 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5650 return ExtendedExpr;
5651 };
5652
5653 // Given:
5654 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5655 // = getExtendedExpr(Expr)
5656 // Determine whether the predicate P: Expr == ExtendedExpr
5657 // is known to be false at compile time
5658 auto PredIsKnownFalse = [&](const SCEV *Expr,
5659 const SCEV *ExtendedExpr) -> bool {
5660 return Expr != ExtendedExpr &&
5661 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5662 };
5663
5664 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5665 if (PredIsKnownFalse(StartVal, StartExtended)) {
5666 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5667 return std::nullopt;
5668 }
5669
5670 // The Step is always Signed (because the overflow checks are either
5671 // NSSW or NUSW)
5672 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5673 if (PredIsKnownFalse(Accum, AccumExtended)) {
5674 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5675 return std::nullopt;
5676 }
5677
5678 auto AppendPredicate = [&](const SCEV *Expr,
5679 const SCEV *ExtendedExpr) -> void {
5680 if (Expr != ExtendedExpr &&
5681 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5682 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5683 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5684 Predicates.push_back(Pred);
5685 }
5686 };
5687
5688 AppendPredicate(StartVal, StartExtended);
5689 AppendPredicate(Accum, AccumExtended);
5690
5691 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5692 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5693 // into NewAR if it will also add the runtime overflow checks specified in
5694 // Predicates.
5695 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5696
5697 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5698 std::make_pair(NewAR, Predicates);
5699 // Remember the result of the analysis for this SCEV at this location.
5700 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5701 return PredRewrite;
5702}
5703
5704std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5705ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5706 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5707 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5708 if (!L)
5709 return std::nullopt;
5710
5711 // Check to see if we already analyzed this PHI.
5712 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5713 if (I != PredicatedSCEVRewrites.end()) {
5714 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5715 I->second;
5716 // Analysis was done before and failed to create an AddRec:
5717 if (Rewrite.first == SymbolicPHI)
5718 return std::nullopt;
5719 // Analysis was done before and succeeded to create an AddRec under
5720 // a predicate:
5721 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5722 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5723 return Rewrite;
5724 }
5725
5726 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5727 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5728
5729 // Record in the cache that the analysis failed
5730 if (!Rewrite) {
5731 SmallVector<const SCEVPredicate *, 3> Predicates;
5732 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5733 return std::nullopt;
5734 }
5735
5736 return Rewrite;
5737}
5738
5739// FIXME: This utility is currently required because the Rewriter currently
5740// does not rewrite this expression:
5741// {0, +, (sext ix (trunc iy to ix) to iy)}
5742// into {0, +, %step},
5743// even when the following Equal predicate exists:
5744// "%step == (sext ix (trunc iy to ix) to iy)".
5745bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5746 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5747 if (AR1 == AR2)
5748 return true;
5749
5750 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5751 if (Expr1 != Expr2 &&
5752 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5753 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5754 return false;
5755 return true;
5756 };
5757
5758 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5759 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5760 return false;
5761 return true;
5762}
5763
5764/// A helper function for createAddRecFromPHI to handle simple cases.
5765///
5766/// This function tries to find an AddRec expression for the simplest (yet most
5767/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5768/// If it fails, createAddRecFromPHI will use a more general, but slow,
5769/// technique for finding the AddRec expression.
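/// For example (illustrative):
///   %iv      = phi i64 [ %start, %preheader ], [ %iv.next, %latch ]
///   %iv.next = add nsw i64 %iv, %step     ; %step loop-invariant
/// is recognized here, roughly, as {%start,+,%step}<nsw><%loop>.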
5770const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5771 Value *BEValueV,
5772 Value *StartValueV) {
5773 const Loop *L = LI.getLoopFor(PN->getParent());
5774 assert(L && L->getHeader() == PN->getParent());
5775 assert(BEValueV && StartValueV);
5776
5777 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5778 if (!BO)
5779 return nullptr;
5780
5781 if (BO->Opcode != Instruction::Add)
5782 return nullptr;
5783
5784 const SCEV *Accum = nullptr;
5785 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5786 Accum = getSCEV(BO->RHS);
5787 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5788 Accum = getSCEV(BO->LHS);
5789
5790 if (!Accum)
5791 return nullptr;
5792
5793 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5794 if (BO->IsNUW)
5795 Flags = setFlags(Flags, SCEV::FlagNUW);
5796 if (BO->IsNSW)
5797 Flags = setFlags(Flags, SCEV::FlagNSW);
5798
5799 const SCEV *StartVal = getSCEV(StartValueV);
5800 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5801 insertValueToMap(PN, PHISCEV);
5802
5803 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5804 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5805 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5806 proveNoWrapViaConstantRanges(AR)));
5807 }
5808
5809 // We can add Flags to the post-inc expression only if we
5810 // know that it is *undefined behavior* for BEValueV to
5811 // overflow.
5812 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5813 assert(isLoopInvariant(Accum, L) &&
5814 "Accum is defined outside L, but is not invariant?");
5815 if (isAddRecNeverPoison(BEInst, L))
5816 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5817 }
5818
5819 return PHISCEV;
5820}
5821
5822const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5823 const Loop *L = LI.getLoopFor(PN->getParent());
5824 if (!L || L->getHeader() != PN->getParent())
5825 return nullptr;
5826
5827 // The loop may have multiple entrances or multiple exits; we can analyze
5828 // this phi as an addrec if it has a unique entry value and a unique
5829 // backedge value.
5830 Value *BEValueV = nullptr, *StartValueV = nullptr;
5831 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5832 Value *V = PN->getIncomingValue(i);
5833 if (L->contains(PN->getIncomingBlock(i))) {
5834 if (!BEValueV) {
5835 BEValueV = V;
5836 } else if (BEValueV != V) {
5837 BEValueV = nullptr;
5838 break;
5839 }
5840 } else if (!StartValueV) {
5841 StartValueV = V;
5842 } else if (StartValueV != V) {
5843 StartValueV = nullptr;
5844 break;
5845 }
5846 }
5847 if (!BEValueV || !StartValueV)
5848 return nullptr;
5849
5850 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5851 "PHI node already processed?");
5852
5853 // First, try to find an AddRec expression without creating a fictitious
5854 // symbolic value for PN.
5854 // value for PN.
5855 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5856 return S;
5857
5858 // Handle PHI node value symbolically.
5859 const SCEV *SymbolicName = getUnknown(PN);
5860 insertValueToMap(PN, SymbolicName);
5861
5862 // Using this symbolic name for the PHI, analyze the value coming around
5863 // the back-edge.
5864 const SCEV *BEValue = getSCEV(BEValueV);
5865
5866 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5867 // has a special value for the first iteration of the loop.
5868
5869 // If the value coming around the backedge is an add with the symbolic
5870 // value we just inserted, then we found a simple induction variable!
5871 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5872 // If there is a single occurrence of the symbolic value, replace it
5873 // with a recurrence.
5874 unsigned FoundIndex = Add->getNumOperands();
5875 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5876 if (Add->getOperand(i) == SymbolicName)
5877 if (FoundIndex == e) {
5878 FoundIndex = i;
5879 break;
5880 }
5881
5882 if (FoundIndex != Add->getNumOperands()) {
5883 // Create an add with everything but the specified operand.
5884 SmallVector<const SCEV *, 8> Ops;
5885 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5886 if (i != FoundIndex)
5887 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5888 L, *this));
5889 const SCEV *Accum = getAddExpr(Ops);
5890
5891 // This is not a valid addrec if the step amount is varying each
5892 // loop iteration, but is not itself an addrec in this loop.
5893 if (isLoopInvariant(Accum, L) ||
5894 (isa<SCEVAddRecExpr>(Accum) &&
5895 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5896 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5897
5898 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5899 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5900 if (BO->IsNUW)
5901 Flags = setFlags(Flags, SCEV::FlagNUW);
5902 if (BO->IsNSW)
5903 Flags = setFlags(Flags, SCEV::FlagNSW);
5904 }
5905 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5906 if (GEP->getOperand(0) == PN) {
5907 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5908 // If the increment has any nowrap flags, then we know the address
5909 // space cannot be wrapped around.
5910 if (NW != GEPNoWrapFlags::none())
5911 Flags = setFlags(Flags, SCEV::FlagNW);
5912 // If the GEP is nuw or nusw with non-negative offset, we know that
5913 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5914 // offset is treated as signed, while the base is unsigned.
5915 if (NW.hasNoUnsignedWrap() ||
5916 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5917 Flags = setFlags(Flags, SCEV::FlagNUW);
5918 }
5919
5920 // We cannot transfer nuw and nsw flags from subtraction
5921 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5922 // for instance.
5923 }
5924
5925 const SCEV *StartVal = getSCEV(StartValueV);
5926 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5927
5928 // Okay, for the entire analysis of this edge we assumed the PHI
5929 // to be symbolic. We now need to go back and purge all of the
5930 // entries for the scalars that use the symbolic expression.
5931 forgetMemoizedResults(SymbolicName);
5932 insertValueToMap(PN, PHISCEV);
5933
5934 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5935 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5936 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5937 proveNoWrapViaConstantRanges(AR)));
5938 }
5939
5940 // We can add Flags to the post-inc expression only if we
5941 // know that it is *undefined behavior* for BEValueV to
5942 // overflow.
5943 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5944 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5945 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5946
5947 return PHISCEV;
5948 }
5949 }
5950 } else {
5951 // Otherwise, this could be a loop like this:
5952 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5953 // In this case, j = {1,+,1} and BEValue is j.
5954 // Because the other in-value of i (0) fits the evolution of BEValue
5955 // i really is an addrec evolution.
5956 //
5957 // We can generalize this saying that i is the shifted value of BEValue
5958 // by one iteration:
5959 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5960
5961 // Do not allow refinement in rewriting of BEValue.
5962 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5963 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5964 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5965 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5966 const SCEV *StartVal = getSCEV(StartValueV);
5967 if (Start == StartVal) {
5968 // Okay, for the entire analysis of this edge we assumed the PHI
5969 // to be symbolic. We now need to go back and purge all of the
5970 // entries for the scalars that use the symbolic expression.
5971 forgetMemoizedResults(SymbolicName);
5972 insertValueToMap(PN, Shifted);
5973 return Shifted;
5974 }
5975 }
5976 }
5977
5978 // Remove the temporary PHI node SCEV that has been inserted while intending
5979 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5980 // as it would prevent later (possibly simpler) SCEV expressions from being
5981 // added to the ValueExprMap.
5982 eraseValueFromMap(PN);
5983
5984 return nullptr;
5985}
5986
5987// Try to match a control flow sequence that branches out at BI and merges back
5988// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5989// match.
5991 Value *&C, Value *&LHS, Value *&RHS) {
5992 C = BI->getCondition();
5993
5994 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5995 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5996
5997 if (!LeftEdge.isSingleEdge())
5998 return false;
5999
6000 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
6001
6002 Use &LeftUse = Merge->getOperandUse(0);
6003 Use &RightUse = Merge->getOperandUse(1);
6004
6005 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6006 LHS = LeftUse;
6007 RHS = RightUse;
6008 return true;
6009 }
6010
6011 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6012 LHS = RightUse;
6013 RHS = LeftUse;
6014 return true;
6015 }
6016
6017 return false;
6018}
6019
6020const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6021 auto IsReachable =
6022 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6023 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6024 // Try to match
6025 //
6026 // br %cond, label %left, label %right
6027 // left:
6028 // br label %merge
6029 // right:
6030 // br label %merge
6031 // merge:
6032 // V = phi [ %x, %left ], [ %y, %right ]
6033 //
6034 // as "select %cond, %x, %y"
6035
6036 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6037 assert(IDom && "At least the entry block should dominate PN");
6038
6039 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6040 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6041
6042 if (BI && BI->isConditional() &&
6043 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6046 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6047 }
6048
6049 return nullptr;
6050}
6051
6052/// Returns SCEV for the first operand of a phi if all phi operands have
6053/// identical opcodes and operands
6054/// eg.
6055/// a: %add = %a + %b
6056/// br %c
6057/// b: %add1 = %a + %b
6058/// br %c
6059/// c: %phi = phi [%add, a], [%add1, b]
6060/// scev(%phi) => scev(%add)
6061const SCEV *
6062ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6063 BinaryOperator *CommonInst = nullptr;
6064 // Check if instructions are identical.
6065 for (Value *Incoming : PN->incoming_values()) {
6066 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6067 if (!IncomingInst)
6068 return nullptr;
6069 if (CommonInst) {
6070 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6071 return nullptr; // Not identical, give up
6072 } else {
6073 // Remember binary operator
6074 CommonInst = IncomingInst;
6075 }
6076 }
6077 if (!CommonInst)
6078 return nullptr;
6079
6080 // Check if SCEV exprs for instructions are identical.
6081 const SCEV *CommonSCEV = getSCEV(CommonInst);
6082 bool SCEVExprsIdentical =
6084 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6085 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6086}
6087
6088const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6089 if (const SCEV *S = createAddRecFromPHI(PN))
6090 return S;
6091
6092 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6093 // phi node for X.
6094 if (Value *V = simplifyInstruction(
6095 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6096 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6097 return getSCEV(V);
6098
6099 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6100 return S;
6101
6102 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6103 return S;
6104
6105 // If it's not a loop phi, we can't handle it yet.
6106 return getUnknown(PN);
6107}
6108
6109bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6110 SCEVTypes RootKind) {
6111 struct FindClosure {
6112 const SCEV *OperandToFind;
6113 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6114 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6115
6116 bool Found = false;
6117
6118 bool canRecurseInto(SCEVTypes Kind) const {
6119 // We can only recurse into the SCEV expression of the same effective type
6120 // as the type of our root SCEV expression, and into zero-extensions.
6121 return RootKind == Kind || NonSequentialRootKind == Kind ||
6122 scZeroExtend == Kind;
6123 };
6124
6125 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6126 : OperandToFind(OperandToFind), RootKind(RootKind),
6127 NonSequentialRootKind(
6129 RootKind)) {}
6130
6131 bool follow(const SCEV *S) {
6132 Found = S == OperandToFind;
6133
6134 return !isDone() && canRecurseInto(S->getSCEVType());
6135 }
6136
6137 bool isDone() const { return Found; }
6138 };
6139
6140 FindClosure FC(OperandToFind, RootKind);
6141 visitAll(Root, FC);
6142 return FC.Found;
6143}
6144
6145std::optional<const SCEV *>
6146ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6147 ICmpInst *Cond,
6148 Value *TrueVal,
6149 Value *FalseVal) {
6150 // Try to match some simple smax or umax patterns.
6151 auto *ICI = Cond;
6152
6153 Value *LHS = ICI->getOperand(0);
6154 Value *RHS = ICI->getOperand(1);
6155
6156 switch (ICI->getPredicate()) {
6157 case ICmpInst::ICMP_SLT:
6158 case ICmpInst::ICMP_SLE:
6159 case ICmpInst::ICMP_ULT:
6160 case ICmpInst::ICMP_ULE:
6161 std::swap(LHS, RHS);
6162 [[fallthrough]];
6163 case ICmpInst::ICMP_SGT:
6164 case ICmpInst::ICMP_SGE:
6165 case ICmpInst::ICMP_UGT:
6166 case ICmpInst::ICMP_UGE:
6167 // a > b ? a+x : b+x -> max(a, b)+x
6168 // a > b ? b+x : a+x -> min(a, b)+x
6170 bool Signed = ICI->isSigned();
6171 const SCEV *LA = getSCEV(TrueVal);
6172 const SCEV *RA = getSCEV(FalseVal);
6173 const SCEV *LS = getSCEV(LHS);
6174 const SCEV *RS = getSCEV(RHS);
6175 if (LA->getType()->isPointerTy()) {
6176 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6177 // Need to make sure we can't produce weird expressions involving
6178 // negated pointers.
6179 if (LA == LS && RA == RS)
6180 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6181 if (LA == RS && RA == LS)
6182 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6183 }
6184 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6185 if (Op->getType()->isPointerTy()) {
6188 return Op;
6189 }
6190 if (Signed)
6191 Op = getNoopOrSignExtend(Op, Ty);
6192 else
6193 Op = getNoopOrZeroExtend(Op, Ty);
6194 return Op;
6195 };
6196 LS = CoerceOperand(LS);
6197 RS = CoerceOperand(RS);
6199 break;
6200 const SCEV *LDiff = getMinusSCEV(LA, LS);
6201 const SCEV *RDiff = getMinusSCEV(RA, RS);
6202 if (LDiff == RDiff)
6203 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6204 LDiff);
6205 LDiff = getMinusSCEV(LA, RS);
6206 RDiff = getMinusSCEV(RA, LS);
6207 if (LDiff == RDiff)
6208 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6209 LDiff);
6210 }
6211 break;
6212 case ICmpInst::ICMP_NE:
6213 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6214 std::swap(TrueVal, FalseVal);
6215 [[fallthrough]];
6216 case ICmpInst::ICMP_EQ:
6217 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6220 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6221 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6222 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6223 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6224 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6225 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6226 return getAddExpr(getUMaxExpr(X, C), Y);
6227 }
6228 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6229 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6230 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6231 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6233 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6234 const SCEV *X = getSCEV(LHS);
6235 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6236 X = ZExt->getOperand();
6237 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6238 const SCEV *FalseValExpr = getSCEV(FalseVal);
6239 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6240 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6241 /*Sequential=*/true);
6242 }
6243 }
6244 break;
6245 default:
6246 break;
6247 }
6248
6249 return std::nullopt;
6250}
6251
6252static std::optional<const SCEV *>
6254 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6255 assert(CondExpr->getType()->isIntegerTy(1) &&
6256 TrueExpr->getType() == FalseExpr->getType() &&
6257 TrueExpr->getType()->isIntegerTy(1) &&
6258 "Unexpected operands of a select.");
6259
6260 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6261 // --> C + (umin_seq cond, x - C)
6262 //
6263 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6264 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6265 // --> C + (umin_seq ~cond, x - C)
6266
6267 // FIXME: while we can't legally model the case where both of the hands
6268 // are fully variable, we only require that the *difference* is constant.
6269 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6270 return std::nullopt;
6271
6272 const SCEV *X, *C;
6273 if (isa<SCEVConstant>(TrueExpr)) {
6274 CondExpr = SE->getNotSCEV(CondExpr);
6275 X = FalseExpr;
6276 C = TrueExpr;
6277 } else {
6278 X = TrueExpr;
6279 C = FalseExpr;
6280 }
6281 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6282 /*Sequential=*/true));
6283}
6284
6285static std::optional<const SCEV *>
6287 Value *FalseVal) {
6288 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6289 return std::nullopt;
6290
6291 const auto *SECond = SE->getSCEV(Cond);
6292 const auto *SETrue = SE->getSCEV(TrueVal);
6293 const auto *SEFalse = SE->getSCEV(FalseVal);
6294 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6295}
6296
6297const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6298 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6299 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6300 assert(TrueVal->getType() == FalseVal->getType() &&
6301 V->getType() == TrueVal->getType() &&
6302 "Types of select hands and of the result must match.");
6303
6304 // For now, only deal with i1-typed `select`s.
6305 if (!V->getType()->isIntegerTy(1))
6306 return getUnknown(V);
6307
6308 if (std::optional<const SCEV *> S =
6309 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6310 return *S;
6311
6312 return getUnknown(V);
6313}
6314
6315const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6316 Value *TrueVal,
6317 Value *FalseVal) {
6318 // Handle "constant" branch or select. This can occur for instance when a
6319 // loop pass transforms an inner loop and moves on to process the outer loop.
6320 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6321 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6322
6323 if (auto *I = dyn_cast<Instruction>(V)) {
6324 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6325 if (std::optional<const SCEV *> S =
6326 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6327 TrueVal, FalseVal))
6328 return *S;
6329 }
6330 }
6331
6332 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6333}
6334
6335/// Expand GEP instructions into add and multiply operations. This allows them
6336/// to be analyzed by regular SCEV code.
6337const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6338 assert(GEP->getSourceElementType()->isSized() &&
6339 "GEP source element type must be sized");
6340
6342 for (Value *Index : GEP->indices())
6343 IndexExprs.push_back(getSCEV(Index));
6344 return getGEPExpr(GEP, IndexExprs);
6345}
6346
6347APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6348 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6349 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6350 return TrailingZeros >= BitWidth
6352 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6353 };
6354 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6355 // The result is GCD of all operands results.
6356 APInt Res = getConstantMultiple(N->getOperand(0));
6357 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6359 Res, getConstantMultiple(N->getOperand(I)));
6360 return Res;
6361 };
6362
6363 switch (S->getSCEVType()) {
6364 case scConstant:
6365 return cast<SCEVConstant>(S)->getAPInt();
6366 case scPtrToInt:
6367 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6368 case scUDivExpr:
6369 case scVScale:
6370 return APInt(BitWidth, 1);
6371 case scTruncate: {
6372 // Only multiples that are a power of 2 will hold after truncation.
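 // e.g. if the operand is known to be a multiple of 24, only the factor 8
 // (its largest power-of-two divisor) is guaranteed to survive truncation,
 // which is what the trailing-zero count captures.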
6373 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6374 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6375 return GetShiftedByZeros(TZ);
6376 }
6377 case scZeroExtend: {
6378 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6379 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6380 }
6381 case scSignExtend: {
6382 // Only multiples that are a power of 2 will hold after sext.
6383 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6384 uint32_t TZ = getMinTrailingZeros(E->getOperand());
6385 return GetShiftedByZeros(TZ);
6386 }
6387 case scMulExpr: {
6388 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6389 if (M->hasNoUnsignedWrap()) {
6390 // The result is the product of all operand results.
6391 APInt Res = getConstantMultiple(M->getOperand(0));
6392 for (const SCEV *Operand : M->operands().drop_front())
6393 Res = Res * getConstantMultiple(Operand);
6394 return Res;
6395 }
6396
6397 // If there are no wrap guarantees, find the trailing zeros, which is the
6398 // sum of trailing zeros for all its operands.
6399 uint32_t TZ = 0;
6400 for (const SCEV *Operand : M->operands())
6401 TZ += getMinTrailingZeros(Operand);
6402 return GetShiftedByZeros(TZ);
6403 }
6404 case scAddExpr:
6405 case scAddRecExpr: {
6406 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6407 if (N->hasNoUnsignedWrap())
6408 return GetGCDMultiple(N);
6409 // Find the trailing zeros, which is the minimum over all operands.
6410 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6411 for (const SCEV *Operand : N->operands().drop_front())
6412 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6413 return GetShiftedByZeros(TZ);
6414 }
6415 case scUMaxExpr:
6416 case scSMaxExpr:
6417 case scUMinExpr:
6418 case scSMinExpr:
6419 case scSequentialUMinExpr:
6420 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6421 case scUnknown: {
6422 // ask ValueTracking for known bits
6423 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6424 unsigned Known =
6425 computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
6426 .countMinTrailingZeros();
6427 return GetShiftedByZeros(Known);
6428 }
6429 case scCouldNotCompute:
6430 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6431 }
6432 llvm_unreachable("Unknown SCEV kind!");
6433}
6434
6435 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6436 auto I = ConstantMultipleCache.find(S);
6437 if (I != ConstantMultipleCache.end())
6438 return I->second;
6439
6440 APInt Result = getConstantMultipleImpl(S);
6441 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6442 assert(InsertPair.second && "Should insert a new key");
6443 return InsertPair.first->second;
6444}
6445
6446 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6447 APInt Multiple = getConstantMultiple(S);
6448 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6449}
6450
6451 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6452 return std::min(getConstantMultiple(S).countTrailingZeros(),
6453 (unsigned)getTypeSizeInBits(S->getType()));
6454}
6455
6456/// Helper method to assign a range to V from metadata present in the IR.
6457static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6458 if (Instruction *I = dyn_cast<Instruction>(V)) {
6459 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6460 return getConstantRangeFromMetadata(*MD);
6461 if (const auto *CB = dyn_cast<CallBase>(V))
6462 if (std::optional<ConstantRange> Range = CB->getRange())
6463 return Range;
6464 }
6465 if (auto *A = dyn_cast<Argument>(V))
6466 if (std::optional<ConstantRange> Range = A->getRange())
6467 return Range;
6468
6469 return std::nullopt;
6470}
6471
6472 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6473 SCEV::NoWrapFlags Flags) {
6474 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6475 AddRec->setNoWrapFlags(Flags);
6476 UnsignedRanges.erase(AddRec);
6477 SignedRanges.erase(AddRec);
6478 ConstantMultipleCache.erase(AddRec);
6479 }
6480}
6481
6482ConstantRange ScalarEvolution::
6483getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6484 const DataLayout &DL = getDataLayout();
6485
6486 unsigned BitWidth = getTypeSizeInBits(U->getType());
6487 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6488
6489 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6490 // use information about the trip count to improve our available range. Note
6491 // that the trip count independent cases are already handled by known bits.
6492 // WARNING: The definition of recurrence used here is subtly different than
6493 // the one used by AddRec (and thus most of this file). Step is allowed to
6494 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6495 // and other addrecs in the same loop (for non-affine addrecs). The code
6496 // below intentionally handles the case where step is not loop invariant.
6497 auto *P = dyn_cast<PHINode>(U->getValue());
6498 if (!P)
6499 return FullSet;
6500
6501 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6502 // even the values that are not available in these blocks may come from them,
6503 // and this leads to false-positive recurrence test.
6504 for (auto *Pred : predecessors(P->getParent()))
6505 if (!DT.isReachableFromEntry(Pred))
6506 return FullSet;
6507
6508 BinaryOperator *BO;
6509 Value *Start, *Step;
6510 if (!matchSimpleRecurrence(P, BO, Start, Step))
6511 return FullSet;
6512
6513 // If we found a recurrence in reachable code, we must be in a loop. Note
6514 // that BO might be in some subloop of L, and that's completely okay.
6515 auto *L = LI.getLoopFor(P->getParent());
6516 assert(L && L->getHeader() == P->getParent());
6517 if (!L->contains(BO->getParent()))
6518 // NOTE: This bailout should be an assert instead. However, asserting
6519 // the condition here exposes a case where LoopFusion is querying SCEV
6520 // with malformed loop information in the midst of the transform.
6521 // There doesn't appear to be an obvious fix, so for the moment bail out
6522 // until the caller issue can be fixed. PR49566 tracks the bug.
6523 return FullSet;
6524
6525 // TODO: Extend to other opcodes such as mul, and div
6526 switch (BO->getOpcode()) {
6527 default:
6528 return FullSet;
6529 case Instruction::AShr:
6530 case Instruction::LShr:
6531 case Instruction::Shl:
6532 break;
6533 };
6534
6535 if (BO->getOperand(0) != P)
6536 // TODO: Handle the power function forms some day.
6537 return FullSet;
6538
6539 unsigned TC = getSmallConstantMaxTripCount(L);
6540 if (!TC || TC >= BitWidth)
6541 return FullSet;
6542
6543 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6544 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6545 assert(KnownStart.getBitWidth() == BitWidth &&
6546 KnownStep.getBitWidth() == BitWidth);
6547
6548 // Compute total shift amount, being careful of overflow and bitwidths.
6549 auto MaxShiftAmt = KnownStep.getMaxValue();
6550 APInt TCAP(BitWidth, TC-1);
6551 bool Overflow = false;
6552 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6553 if (Overflow)
6554 return FullSet;
6555
6556 switch (BO->getOpcode()) {
6557 default:
6558 llvm_unreachable("filtered out above");
6559 case Instruction::AShr: {
6560 // For each ashr, three cases:
6561 // shift = 0 => unchanged value
6562 // saturation => 0 or -1
6563 // other => a value closer to zero (of the same sign)
6564 // Thus, the end value is closer to zero than the start.
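// Editorial example (illustrative, not from the upstream source): suppose
// the start value is known to be exactly 64, the per-iteration shift amount
// is known to be at most 1, and the maximum trip count is 3, so TotalShift
// is at most 1 * (3 - 1) = 2. Every value the recurrence produces then lies
// in [64 >> 2, 64] = [16, 64], which is the non-empty range returned below.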
6565 auto KnownEnd = KnownBits::ashr(KnownStart,
6566 KnownBits::makeConstant(TotalShift));
6567 if (KnownStart.isNonNegative())
6568 // Analogous to lshr (simply not yet canonicalized)
6569 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6570 KnownStart.getMaxValue() + 1);
6571 if (KnownStart.isNegative())
6572 // End >=u Start && End <=s Start
6573 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6574 KnownEnd.getMaxValue() + 1);
6575 break;
6576 }
6577 case Instruction::LShr: {
6578 // For each lshr, three cases:
6579 // shift = 0 => unchanged value
6580 // saturation => 0
6581 // other => a smaller positive number
6582 // Thus, the low end of the unsigned range is the last value produced.
6583 auto KnownEnd = KnownBits::lshr(KnownStart,
6584 KnownBits::makeConstant(TotalShift));
6585 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6586 KnownStart.getMaxValue() + 1);
6587 }
6588 case Instruction::Shl: {
6589 // Iff no bits are shifted out, value increases on every shift.
6590 auto KnownEnd = KnownBits::shl(KnownStart,
6591 KnownBits::makeConstant(TotalShift));
6592 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6593 return ConstantRange(KnownStart.getMinValue(),
6594 KnownEnd.getMaxValue() + 1);
6595 break;
6596 }
6597 };
6598 return FullSet;
6599}
6600
6601const ConstantRange &
6602ScalarEvolution::getRangeRefIter(const SCEV *S,
6603 ScalarEvolution::RangeSignHint SignHint) {
6604 DenseMap<const SCEV *, ConstantRange> &Cache =
6605 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6606 : SignedRanges;
6607 SmallVector<const SCEV *> WorkList;
6608 SmallPtrSet<const SCEV *, 8> Seen;
6609
6610 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6611 // SCEVUnknown PHI node.
6612 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6613 if (!Seen.insert(Expr).second)
6614 return;
6615 if (Cache.contains(Expr))
6616 return;
6617 switch (Expr->getSCEVType()) {
6618 case scUnknown:
6619 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6620 break;
6621 [[fallthrough]];
6622 case scConstant:
6623 case scVScale:
6624 case scTruncate:
6625 case scZeroExtend:
6626 case scSignExtend:
6627 case scPtrToInt:
6628 case scAddExpr:
6629 case scMulExpr:
6630 case scUDivExpr:
6631 case scAddRecExpr:
6632 case scUMaxExpr:
6633 case scSMaxExpr:
6634 case scUMinExpr:
6635 case scSMinExpr:
6636 case scSequentialUMinExpr:
6637 WorkList.push_back(Expr);
6638 break;
6639 case scCouldNotCompute:
6640 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6641 }
6642 };
6643 AddToWorklist(S);
6644
6645 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6646 for (unsigned I = 0; I != WorkList.size(); ++I) {
6647 const SCEV *P = WorkList[I];
6648 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6649 // If it is not a `SCEVUnknown`, just recurse into operands.
6650 if (!UnknownS) {
6651 for (const SCEV *Op : P->operands())
6652 AddToWorklist(Op);
6653 continue;
6654 }
6655 // `SCEVUnknown`'s require special treatment.
6656 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6657 if (!PendingPhiRangesIter.insert(P).second)
6658 continue;
6659 for (auto &Op : reverse(P->operands()))
6660 AddToWorklist(getSCEV(Op));
6661 }
6662 }
6663
6664 if (!WorkList.empty()) {
6665 // Use getRangeRef to compute ranges for items in the worklist in reverse
6666 // order. This will force ranges for earlier operands to be computed before
6667 // their users in most cases.
6668 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6669 getRangeRef(P, SignHint);
6670
6671 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6672 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6673 PendingPhiRangesIter.erase(P);
6674 }
6675 }
6676
6677 return getRangeRef(S, SignHint, 0);
6678}
6679
6680/// Determine the range for a particular SCEV. If SignHint is
6681/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6682/// with a "cleaner" unsigned (resp. signed) representation.
6683const ConstantRange &ScalarEvolution::getRangeRef(
6684 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6685 DenseMap<const SCEV *, ConstantRange> &Cache =
6686 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6687 : SignedRanges;
6688 ConstantRange::PreferredRangeType RangeType =
6689 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6690 : ConstantRange::Signed;
6691
6692 // See if we've computed this range already.
6693 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6694 if (I != Cache.end())
6695 return I->second;
6696
6697 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6698 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6699
6700 // Switch to iteratively computing the range for S, if it is part of a deeply
6701 // nested expression.
6702 if (Depth > RangeIterThreshold)
6703 return getRangeRefIter(S, SignHint);
6704
6705 unsigned BitWidth = getTypeSizeInBits(S->getType());
6706 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6707 using OBO = OverflowingBinaryOperator;
6708
6709 // If the value has known zeros, the maximum value will have those known zeros
6710 // as well.
6711 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6712 APInt Multiple = getNonZeroConstantMultiple(S);
6713 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6714 if (!Remainder.isZero())
6715 ConservativeResult =
6716 ConstantRange(APInt::getMinValue(BitWidth),
6717 APInt::getMaxValue(BitWidth) - Remainder + 1);
6718 }
6719 else {
6720 uint32_t TZ = getMinTrailingZeros(S);
6721 if (TZ != 0) {
6722 ConservativeResult = ConstantRange(
6723 APInt::getSignedMinValue(BitWidth),
6724 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6725 }
6726 }
6727
6728 switch (S->getSCEVType()) {
6729 case scConstant:
6730 llvm_unreachable("Already handled above.");
6731 case scVScale:
6732 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6733 case scTruncate: {
6734 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6735 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6736 return setRange(
6737 Trunc, SignHint,
6738 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6739 }
6740 case scZeroExtend: {
6741 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6742 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6743 return setRange(
6744 ZExt, SignHint,
6745 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6746 }
6747 case scSignExtend: {
6748 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6749 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6750 return setRange(
6751 SExt, SignHint,
6752 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6753 }
6754 case scPtrToInt: {
6755 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6756 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6757 return setRange(PtrToInt, SignHint, X);
6758 }
6759 case scAddExpr: {
6760 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6761 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6762 unsigned WrapType = OBO::AnyWrap;
6763 if (Add->hasNoSignedWrap())
6764 WrapType |= OBO::NoSignedWrap;
6765 if (Add->hasNoUnsignedWrap())
6766 WrapType |= OBO::NoUnsignedWrap;
6767 for (const SCEV *Op : drop_begin(Add->operands()))
6768 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6769 RangeType);
6770 return setRange(Add, SignHint,
6771 ConservativeResult.intersectWith(X, RangeType));
6772 }
6773 case scMulExpr: {
6774 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6775 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6776 for (const SCEV *Op : drop_begin(Mul->operands()))
6777 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6778 return setRange(Mul, SignHint,
6779 ConservativeResult.intersectWith(X, RangeType));
6780 }
6781 case scUDivExpr: {
6782 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6783 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6784 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6785 return setRange(UDiv, SignHint,
6786 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6787 }
6788 case scAddRecExpr: {
6789 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6790 // If there's no unsigned wrap, the value will never be less than its
6791 // initial value.
6792 if (AddRec->hasNoUnsignedWrap()) {
6793 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6794 if (!UnsignedMinValue.isZero())
6795 ConservativeResult = ConservativeResult.intersectWith(
6796 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6797 }
6798
6799 // If there's no signed wrap, and all the operands except initial value have
6800 // the same sign or zero, the value won't ever be:
6801 // 1: smaller than initial value if operands are non negative,
6802 // 2: bigger than initial value if operands are non positive.
6803 // For both cases, value can not cross signed min/max boundary.
6804 if (AddRec->hasNoSignedWrap()) {
6805 bool AllNonNeg = true;
6806 bool AllNonPos = true;
6807 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6808 if (!isKnownNonNegative(AddRec->getOperand(i)))
6809 AllNonNeg = false;
6810 if (!isKnownNonPositive(AddRec->getOperand(i)))
6811 AllNonPos = false;
6812 }
6813 if (AllNonNeg)
6814 ConservativeResult = ConservativeResult.intersectWith(
6815 ConstantRange(getSignedRangeMin(AddRec->getStart()),
6816 APInt::getSignedMinValue(BitWidth)),
6817 RangeType);
6818 else if (AllNonPos)
6819 ConservativeResult = ConservativeResult.intersectWith(
6820 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6821 getSignedRangeMax(AddRec->getStart()) +
6822 1),
6823 RangeType);
6824 }
6825
6826 // TODO: non-affine addrec
6827 if (AddRec->isAffine()) {
6828 const SCEV *MaxBEScev =
6829 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6830 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6831 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6832
6833 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6834 // MaxBECount's active bits are all <= AddRec's bit width.
6835 if (MaxBECount.getBitWidth() > BitWidth &&
6836 MaxBECount.getActiveBits() <= BitWidth)
6837 MaxBECount = MaxBECount.trunc(BitWidth);
6838 else if (MaxBECount.getBitWidth() < BitWidth)
6839 MaxBECount = MaxBECount.zext(BitWidth);
6840
6841 if (MaxBECount.getBitWidth() == BitWidth) {
6842 auto RangeFromAffine = getRangeForAffineAR(
6843 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6844 ConservativeResult =
6845 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6846
6847 auto RangeFromFactoring = getRangeViaFactoring(
6848 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6849 ConservativeResult =
6850 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6851 }
6852 }
6853
6854 // Now try symbolic BE count and more powerful methods.
6855 if (UseExpensiveRangeSharpening) {
6856 const SCEV *SymbolicMaxBECount =
6857 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6858 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6859 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6860 AddRec->hasNoSelfWrap()) {
6861 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6862 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6863 ConservativeResult =
6864 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6865 }
6866 }
6867 }
6868
6869 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6870 }
6871 case scUMaxExpr:
6872 case scSMaxExpr:
6873 case scUMinExpr:
6874 case scSMinExpr:
6875 case scSequentialUMinExpr: {
6876 Intrinsic::ID ID;
6877 switch (S->getSCEVType()) {
6878 case scUMaxExpr:
6879 ID = Intrinsic::umax;
6880 break;
6881 case scSMaxExpr:
6882 ID = Intrinsic::smax;
6883 break;
6884 case scUMinExpr:
6885 case scSequentialUMinExpr:
6886 ID = Intrinsic::umin;
6887 break;
6888 case scSMinExpr:
6889 ID = Intrinsic::smin;
6890 break;
6891 default:
6892 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6893 }
6894
6895 const auto *NAry = cast<SCEVNAryExpr>(S);
6896 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6897 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6898 X = X.intrinsic(
6899 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6900 return setRange(S, SignHint,
6901 ConservativeResult.intersectWith(X, RangeType));
6902 }
6903 case scUnknown: {
6904 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6905 Value *V = U->getValue();
6906
6907 // Check if the IR explicitly contains !range metadata.
6908 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6909 if (MDRange)
6910 ConservativeResult =
6911 ConservativeResult.intersectWith(*MDRange, RangeType);
6912
6913 // Use facts about recurrences in the underlying IR. Note that add
6914 // recurrences are AddRecExprs and thus don't hit this path. This
6915 // primarily handles shift recurrences.
6916 auto CR = getRangeForUnknownRecurrence(U);
6917 ConservativeResult = ConservativeResult.intersectWith(CR);
6918
6919 // See if ValueTracking can give us a useful range.
6920 const DataLayout &DL = getDataLayout();
6921 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6922 if (Known.getBitWidth() != BitWidth)
6923 Known = Known.zextOrTrunc(BitWidth);
6924
6925 // ValueTracking may be able to compute a tighter result for the number of
6926 // sign bits than for the value of those sign bits.
6927 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6928 if (U->getType()->isPointerTy()) {
6929 // If the pointer size is larger than the index size type, this can cause
6930 // NS to be larger than BitWidth. So compensate for this.
6931 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6932 int ptrIdxDiff = ptrSize - BitWidth;
6933 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6934 NS -= ptrIdxDiff;
6935 }
6936
6937 if (NS > 1) {
6938 // If we know any of the sign bits, we know all of the sign bits.
6939 if (!Known.Zero.getHiBits(NS).isZero())
6940 Known.Zero.setHighBits(NS);
6941 if (!Known.One.getHiBits(NS).isZero())
6942 Known.One.setHighBits(NS);
6943 }
6944
6945 if (Known.getMinValue() != Known.getMaxValue() + 1)
6946 ConservativeResult = ConservativeResult.intersectWith(
6947 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6948 RangeType);
6949 if (NS > 1)
6950 ConservativeResult = ConservativeResult.intersectWith(
6951 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6952 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6953 RangeType);
6954
6955 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6956 // Strengthen the range if the underlying IR value is a
6957 // global/alloca/heap allocation using the size of the object.
6958 bool CanBeNull, CanBeFreed;
6959 uint64_t DerefBytes =
6960 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6961 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6962 // The highest address the object can start is DerefBytes bytes before
6963 // the end (unsigned max value). If this value is not a multiple of the
6964 // alignment, the last possible start value is the next lowest multiple
6965 // of the alignment. Note: The computations below cannot overflow,
6966 // because if they would there's no possible start address for the
6967 // object.
6968 APInt MaxVal =
6969 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6970 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6971 uint64_t Rem = MaxVal.urem(Align);
6972 MaxVal -= APInt(BitWidth, Rem);
6973 APInt MinVal = APInt::getZero(BitWidth);
6974 if (llvm::isKnownNonZero(V, DL))
6975 MinVal = Align;
6976 ConservativeResult = ConservativeResult.intersectWith(
6977 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6978 }
6979 }
6980
6981 // A range of Phi is a subset of union of all ranges of its input.
6982 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6983 // Make sure that we do not run over cycled Phis.
6984 if (PendingPhiRanges.insert(Phi).second) {
6985 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6986
6987 for (const auto &Op : Phi->operands()) {
6988 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6989 RangeFromOps = RangeFromOps.unionWith(OpRange);
6990 // No point to continue if we already have a full set.
6991 if (RangeFromOps.isFullSet())
6992 break;
6993 }
6994 ConservativeResult =
6995 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6996 bool Erased = PendingPhiRanges.erase(Phi);
6997 assert(Erased && "Failed to erase Phi properly?");
6998 (void)Erased;
6999 }
7000 }
7001
7002 // vscale can't be equal to zero
7003 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7004 if (II->getIntrinsicID() == Intrinsic::vscale) {
7005 ConstantRange Disallowed = APInt::getZero(BitWidth);
7006 ConservativeResult = ConservativeResult.difference(Disallowed);
7007 }
7008
7009 return setRange(U, SignHint, std::move(ConservativeResult));
7010 }
7011 case scCouldNotCompute:
7012 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7013 }
7014
7015 return setRange(S, SignHint, std::move(ConservativeResult));
7016}
7017
7018// Given a StartRange, Step and MaxBECount for an expression compute a range of
7019// values that the expression can take. Initially, the expression has a value
7020// from StartRange and then is changed by Step up to MaxBECount times. Signed
7021// argument defines if we treat Step as signed or unsigned.
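// Editorial example (illustrative, not from the upstream source): with an
// unsigned Step of 3, a StartRange of [0, 10) and MaxBECount of 5, the value
// can grow by at most Offset = 3 * 5 = 15, so the helper below returns
// [0, 9 + 15 + 1) = [0, 25). If Offset were large enough to wrap the moved
// boundary back into the start range, the full set would be returned instead.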
7022 static ConstantRange getRangeForAffineARHelper(APInt Step,
7023 const ConstantRange &StartRange,
7024 const APInt &MaxBECount,
7025 bool Signed) {
7026 unsigned BitWidth = Step.getBitWidth();
7027 assert(BitWidth == StartRange.getBitWidth() &&
7028 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7029 // If either Step or MaxBECount is 0, then the expression won't change, and we
7030 // just need to return the initial range.
7031 if (Step == 0 || MaxBECount == 0)
7032 return StartRange;
7033
7034 // If we don't know anything about the initial value (i.e. StartRange is
7035 // FullRange), then we don't know anything about the final range either.
7036 // Return FullRange.
7037 if (StartRange.isFullSet())
7038 return ConstantRange::getFull(BitWidth);
7039
7040 // If Step is signed and negative, then we use its absolute value, but we also
7041 // note that we're moving in the opposite direction.
7042 bool Descending = Signed && Step.isNegative();
7043
7044 if (Signed)
7045 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7046 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7047 // These equations hold true due to the well-defined wrap-around behavior of
7048 // APInt.
7049 Step = Step.abs();
7050
7051 // Check if Offset is more than full span of BitWidth. If it is, the
7052 // expression is guaranteed to overflow.
7053 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7054 return ConstantRange::getFull(BitWidth);
7055
7056 // Offset is by how much the expression can change. Checks above guarantee no
7057 // overflow here.
7058 APInt Offset = Step * MaxBECount;
7059
7060 // Minimum value of the final range will match the minimal value of StartRange
7061 // if the expression is increasing and will be decreased by Offset otherwise.
7062 // Maximum value of the final range will match the maximal value of StartRange
7063 // if the expression is decreasing and will be increased by Offset otherwise.
7064 APInt StartLower = StartRange.getLower();
7065 APInt StartUpper = StartRange.getUpper() - 1;
7066 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7067 : (StartUpper + std::move(Offset));
7068
7069 // It's possible that the new minimum/maximum value will fall into the initial
7070 // range (due to wrap around). This means that the expression can take any
7071 // value in this bitwidth, and we have to return full range.
7072 if (StartRange.contains(MovedBoundary))
7073 return ConstantRange::getFull(BitWidth);
7074
7075 APInt NewLower =
7076 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7077 APInt NewUpper =
7078 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7079 NewUpper += 1;
7080
7081 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7082 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7083}
7084
7085ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7086 const SCEV *Step,
7087 const APInt &MaxBECount) {
7088 assert(getTypeSizeInBits(Start->getType()) ==
7089 getTypeSizeInBits(Step->getType()) &&
7090 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7091 "mismatched bit widths");
7092
7093 // First, consider step signed.
7094 ConstantRange StartSRange = getSignedRange(Start);
7095 ConstantRange StepSRange = getSignedRange(Step);
7096
7097 // If Step can be both positive and negative, we need to find ranges for the
7098 // maximum absolute step values in both directions and union them.
7099 ConstantRange SR = getRangeForAffineARHelper(
7100 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7101 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7102 StartSRange, MaxBECount,
7103 /* Signed = */ true));
7104
7105 // Next, consider step unsigned.
7106 ConstantRange UR = getRangeForAffineARHelper(
7107 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7108 /* Signed = */ false);
7109
7110 // Finally, intersect signed and unsigned ranges.
7111 return SR.intersectWith(UR, ConstantRange::Smallest);
7112}
7113
7114ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7115 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7116 ScalarEvolution::RangeSignHint SignHint) {
7117 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7118 assert(AddRec->hasNoSelfWrap() &&
7119 "This only works for non-self-wrapping AddRecs!");
7120 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7121 const SCEV *Step = AddRec->getStepRecurrence(*this);
7122 // Only deal with constant step to save compile time.
7123 if (!isa<SCEVConstant>(Step))
7124 return ConstantRange::getFull(BitWidth);
7125 // Let's make sure that we can prove that we do not self-wrap during
7126 // MaxBECount iterations. We need this because MaxBECount is a maximum
7127 // iteration count estimate, and we might infer nw from some exit for which we
7128 // do not know max exit count (or any other side reasoning).
7129 // TODO: Turn into assert at some point.
7130 if (getTypeSizeInBits(MaxBECount->getType()) >
7131 getTypeSizeInBits(AddRec->getType()))
7132 return ConstantRange::getFull(BitWidth);
7133 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7134 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7135 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7136 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7137 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7138 MaxItersWithoutWrap))
7139 return ConstantRange::getFull(BitWidth);
7140
7141 ICmpInst::Predicate LEPred =
7142 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7143 ICmpInst::Predicate GEPred =
7144 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7145 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7146
7147 // We know that there is no self-wrap. Let's take Start and End values and
7148 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7149 // the iteration. They either lie inside the range [Min(Start, End),
7150 // Max(Start, End)] or outside it:
7151 //
7152 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7153 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7154 //
7155 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7156 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7157 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7158 // Start <= End and step is positive, or Start >= End and step is negative.
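// Editorial example (illustrative, not from the upstream source): for
// {10,+,2}<nw> with a symbolic max backedge-taken count of at most 20,
// Start = 10 and End = 10 + 2 * 20 = 50. If the positive step and
// Start <= End can be proven, the union of the start and end ranges,
// here [10, 51), covers every intermediate value (Case 1 above).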
7159 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7160 ConstantRange StartRange = getRangeRef(Start, SignHint);
7161 ConstantRange EndRange = getRangeRef(End, SignHint);
7162 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7163 // If they already cover full iteration space, we will know nothing useful
7164 // even if we prove what we want to prove.
7165 if (RangeBetween.isFullSet())
7166 return RangeBetween;
7167 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7168 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7169 : RangeBetween.isWrappedSet();
7170 if (IsWrappedSet)
7171 return ConstantRange::getFull(BitWidth);
7172
7173 if (isKnownPositive(Step) &&
7174 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7175 return RangeBetween;
7176 if (isKnownNegative(Step) &&
7177 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7178 return RangeBetween;
7179 return ConstantRange::getFull(BitWidth);
7180}
7181
7182ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7183 const SCEV *Step,
7184 const APInt &MaxBECount) {
7185 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7186 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
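// Editorial example (illustrative, not from the upstream source): for
// {c ? 0 : 100,+,c ? 1 : -1} with MaxBECount = 10, the two constituent
// recurrences are {0,+,1} and {100,+,-1}, whose affine ranges are [0, 11)
// and [90, 101); the result computed below is their union, [0, 101).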
7187
7188 unsigned BitWidth = MaxBECount.getBitWidth();
7189 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7190 getTypeSizeInBits(Step->getType()) == BitWidth &&
7191 "mismatched bit widths");
7192
7193 struct SelectPattern {
7194 Value *Condition = nullptr;
7195 APInt TrueValue;
7196 APInt FalseValue;
7197
7198 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7199 const SCEV *S) {
7200 std::optional<unsigned> CastOp;
7201 APInt Offset(BitWidth, 0);
7202
7204 "Should be!");
7205
7206 // Peel off a constant offset. In the future we could consider being
7207 // smarter here and handle {Start+Step,+,Step} too.
7208 const APInt *Off;
7209 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7210 Offset = *Off;
7211
7212 // Peel off a cast operation
7213 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7214 CastOp = SCast->getSCEVType();
7215 S = SCast->getOperand();
7216 }
7217
7218 using namespace llvm::PatternMatch;
7219
7220 auto *SU = dyn_cast<SCEVUnknown>(S);
7221 const APInt *TrueVal, *FalseVal;
7222 if (!SU ||
7223 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7224 m_APInt(FalseVal)))) {
7225 Condition = nullptr;
7226 return;
7227 }
7228
7229 TrueValue = *TrueVal;
7230 FalseValue = *FalseVal;
7231
7232 // Re-apply the cast we peeled off earlier
7233 if (CastOp)
7234 switch (*CastOp) {
7235 default:
7236 llvm_unreachable("Unknown SCEV cast type!");
7237
7238 case scTruncate:
7239 TrueValue = TrueValue.trunc(BitWidth);
7240 FalseValue = FalseValue.trunc(BitWidth);
7241 break;
7242 case scZeroExtend:
7243 TrueValue = TrueValue.zext(BitWidth);
7244 FalseValue = FalseValue.zext(BitWidth);
7245 break;
7246 case scSignExtend:
7247 TrueValue = TrueValue.sext(BitWidth);
7248 FalseValue = FalseValue.sext(BitWidth);
7249 break;
7250 }
7251
7252 // Re-apply the constant offset we peeled off earlier
7253 TrueValue += Offset;
7254 FalseValue += Offset;
7255 }
7256
7257 bool isRecognized() { return Condition != nullptr; }
7258 };
7259
7260 SelectPattern StartPattern(*this, BitWidth, Start);
7261 if (!StartPattern.isRecognized())
7262 return ConstantRange::getFull(BitWidth);
7263
7264 SelectPattern StepPattern(*this, BitWidth, Step);
7265 if (!StepPattern.isRecognized())
7266 return ConstantRange::getFull(BitWidth);
7267
7268 if (StartPattern.Condition != StepPattern.Condition) {
7269 // We don't handle this case today; but we could, by considering four
7270 // possibilities below instead of two. I'm not sure if there are cases where
7271 // that will help over what getRange already does, though.
7272 return ConstantRange::getFull(BitWidth);
7273 }
7274
7275 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7276 // construct arbitrary general SCEV expressions here. This function is called
7277 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7278 // say) can end up caching a suboptimal value.
7279
7280 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7281 // C2352 and C2512 (otherwise it isn't needed).
7282
7283 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7284 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7285 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7286 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7287
7288 ConstantRange TrueRange =
7289 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7290 ConstantRange FalseRange =
7291 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7292
7293 return TrueRange.unionWith(FalseRange);
7294}
7295
7296SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7297 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7298 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7299
7300 // Return early if there are no flags to propagate to the SCEV.
7301 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7302 if (BinOp->hasNoUnsignedWrap())
7303 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7304 if (BinOp->hasNoSignedWrap())
7305 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7306 if (Flags == SCEV::FlagAnyWrap)
7307 return SCEV::FlagAnyWrap;
7308
7309 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7310}
7311
7312const Instruction *
7313ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7314 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7315 return &*AddRec->getLoop()->getHeader()->begin();
7316 if (auto *U = dyn_cast<SCEVUnknown>(S))
7317 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7318 return I;
7319 return nullptr;
7320}
7321
7322const Instruction *
7323ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7324 bool &Precise) {
7325 Precise = true;
7326 // Do a bounded search of the def relation of the requested SCEVs.
7327 SmallPtrSet<const SCEV *, 16> Visited;
7328 SmallVector<const SCEV *> Worklist;
7329 auto pushOp = [&](const SCEV *S) {
7330 if (!Visited.insert(S).second)
7331 return;
7332 // Threshold of 30 here is arbitrary.
7333 if (Visited.size() > 30) {
7334 Precise = false;
7335 return;
7336 }
7337 Worklist.push_back(S);
7338 };
7339
7340 for (const auto *S : Ops)
7341 pushOp(S);
7342
7343 const Instruction *Bound = nullptr;
7344 while (!Worklist.empty()) {
7345 auto *S = Worklist.pop_back_val();
7346 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7347 if (!Bound || DT.dominates(Bound, DefI))
7348 Bound = DefI;
7349 } else {
7350 for (const auto *Op : S->operands())
7351 pushOp(Op);
7352 }
7353 }
7354 return Bound ? Bound : &*F.getEntryBlock().begin();
7355}
7356
7357const Instruction *
7358ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7359 bool Discard;
7360 return getDefiningScopeBound(Ops, Discard);
7361}
7362
7363bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7364 const Instruction *B) {
7365 if (A->getParent() == B->getParent() &&
7366 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7367 B->getIterator()))
7368 return true;
7369
7370 auto *BLoop = LI.getLoopFor(B->getParent());
7371 if (BLoop && BLoop->getHeader() == B->getParent() &&
7372 BLoop->getLoopPreheader() == A->getParent() &&
7373 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7374 A->getParent()->end()) &&
7375 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7376 B->getIterator()))
7377 return true;
7378 return false;
7379}
7380
7381bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7382 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7383 visitAll(Op, PC);
7384 return PC.MaybePoison.empty();
7385}
7386
7387bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7388 return !SCEVExprContains(Op, [this](const SCEV *S) {
7389 const SCEV *Op1;
7390 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7391 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7392 // is a non-zero constant, we have to assume the UDiv may be UB.
7393 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7394 });
7395}
7396
7397bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7398 // Only proceed if we can prove that I does not yield poison.
7399 if (!programUndefinedIfPoison(I))
7400 return false;
7401
7402 // At this point we know that if I is executed, then it does not wrap
7403 // according to at least one of NSW or NUW. If I is not executed, then we do
7404 // not know if the calculation that I represents would wrap. Multiple
7405 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7406 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7407 // derived from other instructions that map to the same SCEV. We cannot make
7408 // that guarantee for cases where I is not executed. So we need to find an
7409 // upper bound on the defining scope for the SCEV, and prove that I is
7410 // executed every time we enter that scope. When the bounding scope is a
7411 // loop (the common case), this is equivalent to proving I executes on every
7412 // iteration of that loop.
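// Editorial note (illustrative, not from the upstream source): if
// "%add = add nsw %iv, 1" sits in a block executed on every iteration of the
// defining scope computed below, the nsw flag can safely be carried over to
// the SCEV it maps to; if the add only executed on some iterations, another
// instruction mapping to the same SCEV might wrap, so the flag could not be
// reused.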
7413 SmallVector<const SCEV *, 4> SCEVOps;
7414 for (const Use &Op : I->operands()) {
7415 // I could be an extractvalue from a call to an overflow intrinsic.
7416 // TODO: We can do better here in some cases.
7417 if (isSCEVable(Op->getType()))
7418 SCEVOps.push_back(getSCEV(Op));
7419 }
7420 auto *DefI = getDefiningScopeBound(SCEVOps);
7421 return isGuaranteedToTransferExecutionTo(DefI, I);
7422}
7423
7424bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7425 // If we know that \c I can never be poison period, then that's enough.
7426 if (isSCEVExprNeverPoison(I))
7427 return true;
7428
7429 // If the loop only has one exit, then we know that, if the loop is entered,
7430 // any instruction dominating that exit will be executed. If any such
7431 // instruction would result in UB, the addrec cannot be poison.
7432 //
7433 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7434 // also handles uses outside the loop header (they just need to dominate the
7435 // single exit).
7436
7437 auto *ExitingBB = L->getExitingBlock();
7438 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7439 return false;
7440
7441 SmallPtrSet<const Value *, 16> KnownPoison;
7442 SmallVector<const Instruction *, 8> Worklist;
7443
7444 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7445 // things that are known to be poison under that assumption go on the
7446 // Worklist.
7447 KnownPoison.insert(I);
7448 Worklist.push_back(I);
7449
7450 while (!Worklist.empty()) {
7451 const Instruction *Poison = Worklist.pop_back_val();
7452
7453 for (const Use &U : Poison->uses()) {
7454 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7455 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7456 DT.dominates(PoisonUser->getParent(), ExitingBB))
7457 return true;
7458
7459 if (propagatesPoison(U) && L->contains(PoisonUser))
7460 if (KnownPoison.insert(PoisonUser).second)
7461 Worklist.push_back(PoisonUser);
7462 }
7463 }
7464
7465 return false;
7466}
7467
7468ScalarEvolution::LoopProperties
7469ScalarEvolution::getLoopProperties(const Loop *L) {
7470 using LoopProperties = ScalarEvolution::LoopProperties;
7471
7472 auto Itr = LoopPropertiesCache.find(L);
7473 if (Itr == LoopPropertiesCache.end()) {
7474 auto HasSideEffects = [](Instruction *I) {
7475 if (auto *SI = dyn_cast<StoreInst>(I))
7476 return !SI->isSimple();
7477
7478 if (I->mayThrow())
7479 return true;
7480
7481 // Non-volatile memset / memcpy do not count as side-effect for forward
7482 // progress.
7483 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7484 return false;
7485
7486 return I->mayWriteToMemory();
7487 };
7488
7489 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7490 /*HasNoSideEffects*/ true};
7491
7492 for (auto *BB : L->getBlocks())
7493 for (auto &I : *BB) {
7494 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7495 LP.HasNoAbnormalExits = false;
7496 if (HasSideEffects(&I))
7497 LP.HasNoSideEffects = false;
7498 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7499 break; // We're already as pessimistic as we can get.
7500 }
7501
7502 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7503 assert(InsertPair.second && "We just checked!");
7504 Itr = InsertPair.first;
7505 }
7506
7507 return Itr->second;
7508}
7509
7510 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7511 // A mustprogress loop without side effects must be finite.
7512 // TODO: The check used here is very conservative. It's only *specific*
7513 // side effects which are well defined in infinite loops.
7514 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7515}
7516
7517const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7518 // Worklist item with a Value and a bool indicating whether all operands have
7519 // been visited already.
7520 using PointerTy = PointerIntPair<Value *, 1, bool>;
7521 SmallVector<PointerTy> Stack;
7522
7523 Stack.emplace_back(V, true);
7524 Stack.emplace_back(V, false);
7525 while (!Stack.empty()) {
7526 auto E = Stack.pop_back_val();
7527 Value *CurV = E.getPointer();
7528
7529 if (getExistingSCEV(CurV))
7530 continue;
7531
7532 SmallVector<Value *> Ops;
7533 const SCEV *CreatedSCEV = nullptr;
7534 // If all operands have been visited already, create the SCEV.
7535 if (E.getInt()) {
7536 CreatedSCEV = createSCEV(CurV);
7537 } else {
7538 // Otherwise get the operands we need to create SCEV's for before creating
7539 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7540 // just use it.
7541 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7542 }
7543
7544 if (CreatedSCEV) {
7545 insertValueToMap(CurV, CreatedSCEV);
7546 } else {
7547 // Queue CurV for SCEV creation, followed by its operands which need to
7548 // be constructed first.
7549 Stack.emplace_back(CurV, true);
7550 for (Value *Op : Ops)
7551 Stack.emplace_back(Op, false);
7552 }
7553 }
7554
7555 return getExistingSCEV(V);
7556}
7557
7558const SCEV *
7559ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7560 if (!isSCEVable(V->getType()))
7561 return getUnknown(V);
7562
7563 if (Instruction *I = dyn_cast<Instruction>(V)) {
7564 // Don't attempt to analyze instructions in blocks that aren't
7565 // reachable. Such instructions don't matter, and they aren't required
7566 // to obey basic rules for definitions dominating uses which this
7567 // analysis depends on.
7568 if (!DT.isReachableFromEntry(I->getParent()))
7569 return getUnknown(PoisonValue::get(V->getType()));
7570 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7571 return getConstant(CI);
7572 else if (isa<GlobalAlias>(V))
7573 return getUnknown(V);
7574 else if (!isa<ConstantExpr>(V))
7575 return getUnknown(V);
7576
7577 Operator *U = cast<Operator>(V);
7578 if (auto BO =
7579 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7580 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7581 switch (BO->Opcode) {
7582 case Instruction::Add:
7583 case Instruction::Mul: {
7584 // For additions and multiplications, traverse add/mul chains for which we
7585 // can potentially create a single SCEV, to reduce the number of
7586 // get{Add,Mul}Expr calls.
7587 do {
7588 if (BO->Op) {
7589 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7590 Ops.push_back(BO->Op);
7591 break;
7592 }
7593 }
7594 Ops.push_back(BO->RHS);
7595 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7596 dyn_cast<Instruction>(V));
7597 if (!NewBO ||
7598 (BO->Opcode == Instruction::Add &&
7599 (NewBO->Opcode != Instruction::Add &&
7600 NewBO->Opcode != Instruction::Sub)) ||
7601 (BO->Opcode == Instruction::Mul &&
7602 NewBO->Opcode != Instruction::Mul)) {
7603 Ops.push_back(BO->LHS);
7604 break;
7605 }
7606 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7607 // requires a SCEV for the LHS.
7608 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7609 auto *I = dyn_cast<Instruction>(BO->Op);
7610 if (I && programUndefinedIfPoison(I)) {
7611 Ops.push_back(BO->LHS);
7612 break;
7613 }
7614 }
7615 BO = NewBO;
7616 } while (true);
7617 return nullptr;
7618 }
7619 case Instruction::Sub:
7620 case Instruction::UDiv:
7621 case Instruction::URem:
7622 break;
7623 case Instruction::AShr:
7624 case Instruction::Shl:
7625 case Instruction::Xor:
7626 if (!IsConstArg)
7627 return nullptr;
7628 break;
7629 case Instruction::And:
7630 case Instruction::Or:
7631 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7632 return nullptr;
7633 break;
7634 case Instruction::LShr:
7635 return getUnknown(V);
7636 default:
7637 llvm_unreachable("Unhandled binop");
7638 break;
7639 }
7640
7641 Ops.push_back(BO->LHS);
7642 Ops.push_back(BO->RHS);
7643 return nullptr;
7644 }
7645
7646 switch (U->getOpcode()) {
7647 case Instruction::Trunc:
7648 case Instruction::ZExt:
7649 case Instruction::SExt:
7650 case Instruction::PtrToInt:
7651 Ops.push_back(U->getOperand(0));
7652 return nullptr;
7653
7654 case Instruction::BitCast:
7655 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7656 Ops.push_back(U->getOperand(0));
7657 return nullptr;
7658 }
7659 return getUnknown(V);
7660
7661 case Instruction::SDiv:
7662 case Instruction::SRem:
7663 Ops.push_back(U->getOperand(0));
7664 Ops.push_back(U->getOperand(1));
7665 return nullptr;
7666
7667 case Instruction::GetElementPtr:
7668 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7669 "GEP source element type must be sized");
7670 llvm::append_range(Ops, U->operands());
7671 return nullptr;
7672
7673 case Instruction::IntToPtr:
7674 return getUnknown(V);
7675
7676 case Instruction::PHI:
7677 // Keep constructing SCEVs for phis recursively for now.
7678 return nullptr;
7679
7680 case Instruction::Select: {
7681 // Check if U is a select that can be simplified to a SCEVUnknown.
7682 auto CanSimplifyToUnknown = [this, U]() {
7683 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7684 return false;
7685
7686 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7687 if (!ICI)
7688 return false;
7689 Value *LHS = ICI->getOperand(0);
7690 Value *RHS = ICI->getOperand(1);
7691 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7692 ICI->getPredicate() == CmpInst::ICMP_NE) {
7693 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7694 return true;
7695 } else if (getTypeSizeInBits(LHS->getType()) >
7696 getTypeSizeInBits(U->getType()))
7697 return true;
7698 return false;
7699 };
7700 if (CanSimplifyToUnknown())
7701 return getUnknown(U);
7702
7703 llvm::append_range(Ops, U->operands());
7704 return nullptr;
7705 break;
7706 }
7707 case Instruction::Call:
7708 case Instruction::Invoke:
7709 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7710 Ops.push_back(RV);
7711 return nullptr;
7712 }
7713
7714 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7715 switch (II->getIntrinsicID()) {
7716 case Intrinsic::abs:
7717 Ops.push_back(II->getArgOperand(0));
7718 return nullptr;
7719 case Intrinsic::umax:
7720 case Intrinsic::umin:
7721 case Intrinsic::smax:
7722 case Intrinsic::smin:
7723 case Intrinsic::usub_sat:
7724 case Intrinsic::uadd_sat:
7725 Ops.push_back(II->getArgOperand(0));
7726 Ops.push_back(II->getArgOperand(1));
7727 return nullptr;
7728 case Intrinsic::start_loop_iterations:
7729 case Intrinsic::annotation:
7730 case Intrinsic::ptr_annotation:
7731 Ops.push_back(II->getArgOperand(0));
7732 return nullptr;
7733 default:
7734 break;
7735 }
7736 }
7737 break;
7738 }
7739
7740 return nullptr;
7741}
7742
7743const SCEV *ScalarEvolution::createSCEV(Value *V) {
7744 if (!isSCEVable(V->getType()))
7745 return getUnknown(V);
7746
7747 if (Instruction *I = dyn_cast<Instruction>(V)) {
7748 // Don't attempt to analyze instructions in blocks that aren't
7749 // reachable. Such instructions don't matter, and they aren't required
7750 // to obey basic rules for definitions dominating uses which this
7751 // analysis depends on.
7752 if (!DT.isReachableFromEntry(I->getParent()))
7753 return getUnknown(PoisonValue::get(V->getType()));
7754 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7755 return getConstant(CI);
7756 else if (isa<GlobalAlias>(V))
7757 return getUnknown(V);
7758 else if (!isa<ConstantExpr>(V))
7759 return getUnknown(V);
7760
7761 const SCEV *LHS;
7762 const SCEV *RHS;
7763
7764 Operator *U = cast<Operator>(V);
7765 if (auto BO =
7766 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7767 switch (BO->Opcode) {
7768 case Instruction::Add: {
7769 // The simple thing to do would be to just call getSCEV on both operands
7770 // and call getAddExpr with the result. However if we're looking at a
7771 // bunch of things all added together, this can be quite inefficient,
7772 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7773 // Instead, gather up all the operands and make a single getAddExpr call.
7774 // LLVM IR canonical form means we need only traverse the left operands.
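// Editorial note (illustrative, not from the upstream source): for IR like
// %t0 = add %a, %b; %t1 = add %t0, %c; %t2 = add %t1, %d, the loop below
// walks the left-hand chain and collects {%a, %b, %c, %d} so that a single
// getAddExpr call builds the SCEV for %t2.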
7775 SmallVector<const SCEV *, 4> AddOps;
7776 do {
7777 if (BO->Op) {
7778 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7779 AddOps.push_back(OpSCEV);
7780 break;
7781 }
7782
7783 // If a NUW or NSW flag can be applied to the SCEV for this
7784 // addition, then compute the SCEV for this addition by itself
7785 // with a separate call to getAddExpr. We need to do that
7786 // instead of pushing the operands of the addition onto AddOps,
7787 // since the flags are only known to apply to this particular
7788 // addition - they may not apply to other additions that can be
7789 // formed with operands from AddOps.
7790 const SCEV *RHS = getSCEV(BO->RHS);
7791 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7792 if (Flags != SCEV::FlagAnyWrap) {
7793 const SCEV *LHS = getSCEV(BO->LHS);
7794 if (BO->Opcode == Instruction::Sub)
7795 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7796 else
7797 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7798 break;
7799 }
7800 }
7801
7802 if (BO->Opcode == Instruction::Sub)
7803 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7804 else
7805 AddOps.push_back(getSCEV(BO->RHS));
7806
7807 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7808 dyn_cast<Instruction>(V));
7809 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7810 NewBO->Opcode != Instruction::Sub)) {
7811 AddOps.push_back(getSCEV(BO->LHS));
7812 break;
7813 }
7814 BO = NewBO;
7815 } while (true);
7816
7817 return getAddExpr(AddOps);
7818 }
7819
7820 case Instruction::Mul: {
7821 SmallVector<const SCEV *, 4> MulOps;
7822 do {
7823 if (BO->Op) {
7824 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7825 MulOps.push_back(OpSCEV);
7826 break;
7827 }
7828
7829 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7830 if (Flags != SCEV::FlagAnyWrap) {
7831 LHS = getSCEV(BO->LHS);
7832 RHS = getSCEV(BO->RHS);
7833 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7834 break;
7835 }
7836 }
7837
7838 MulOps.push_back(getSCEV(BO->RHS));
7839 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7840 dyn_cast<Instruction>(V));
7841 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7842 MulOps.push_back(getSCEV(BO->LHS));
7843 break;
7844 }
7845 BO = NewBO;
7846 } while (true);
7847
7848 return getMulExpr(MulOps);
7849 }
7850 case Instruction::UDiv:
7851 LHS = getSCEV(BO->LHS);
7852 RHS = getSCEV(BO->RHS);
7853 return getUDivExpr(LHS, RHS);
7854 case Instruction::URem:
7855 LHS = getSCEV(BO->LHS);
7856 RHS = getSCEV(BO->RHS);
7857 return getURemExpr(LHS, RHS);
7858 case Instruction::Sub: {
7859 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7860 if (BO->Op)
7861 Flags = getNoWrapFlagsFromUB(BO->Op);
7862 LHS = getSCEV(BO->LHS);
7863 RHS = getSCEV(BO->RHS);
7864 return getMinusSCEV(LHS, RHS, Flags);
7865 }
7866 case Instruction::And:
7867 // For an expression like x&255 that merely masks off the high bits,
7868 // use zext(trunc(x)) as the SCEV expression.
7869 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7870 if (CI->isZero())
7871 return getSCEV(BO->RHS);
7872 if (CI->isMinusOne())
7873 return getSCEV(BO->LHS);
7874 const APInt &A = CI->getValue();
7875
7876 // Instcombine's ShrinkDemandedConstant may strip bits out of
7877 // constants, obscuring what would otherwise be a low-bits mask.
7878 // Use computeKnownBits to compute what ShrinkDemandedConstant
7879 // knew about to reconstruct a low-bits mask value.
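// Editorial example (illustrative, not from the upstream source, and
// assuming the known-bits test on %x succeeds): for %x & 224 (0xE0) at i32,
// LZ = 24 and TZ = 5, so EffectiveMask is 0xE0 and the expression is modeled
// as (zext(trunc(%x /u 32) to i3) to i32) * 32, matching the zext(trunc(x))
// form described above.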
7880 unsigned LZ = A.countl_zero();
7881 unsigned TZ = A.countr_zero();
7882 unsigned BitWidth = A.getBitWidth();
7883 KnownBits Known(BitWidth);
7884 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7885
7886 APInt EffectiveMask =
7887 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7888 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7889 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7890 const SCEV *LHS = getSCEV(BO->LHS);
7891 const SCEV *ShiftedLHS = nullptr;
7892 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7893 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7894 // For an expression like (x * 8) & 8, simplify the multiply.
7895 unsigned MulZeros = OpC->getAPInt().countr_zero();
7896 unsigned GCD = std::min(MulZeros, TZ);
7897 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7898 SmallVector<const SCEV *, 4> MulOps;
7899 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7900 append_range(MulOps, LHSMul->operands().drop_front());
7901 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7902 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7903 }
7904 }
7905 if (!ShiftedLHS)
7906 ShiftedLHS = getUDivExpr(LHS, MulCount);
7907 return getMulExpr(
7908 getZeroExtendExpr(
7909 getTruncateExpr(ShiftedLHS,
7910 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7911 BO->LHS->getType()),
7912 MulCount);
7913 }
7914 }
7915 // Binary `and` is a bit-wise `umin`.
7916 if (BO->LHS->getType()->isIntegerTy(1)) {
7917 LHS = getSCEV(BO->LHS);
7918 RHS = getSCEV(BO->RHS);
7919 return getUMinExpr(LHS, RHS);
7920 }
7921 break;
7922
7923 case Instruction::Or:
7924 // Binary `or` is a bit-wise `umax`.
7925 if (BO->LHS->getType()->isIntegerTy(1)) {
7926 LHS = getSCEV(BO->LHS);
7927 RHS = getSCEV(BO->RHS);
7928 return getUMaxExpr(LHS, RHS);
7929 }
7930 break;
7931
7932 case Instruction::Xor:
7933 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7934 // If the RHS of xor is -1, then this is a not operation.
7935 if (CI->isMinusOne())
7936 return getNotSCEV(getSCEV(BO->LHS));
7937
7938 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7939 // This is a variant of the check for xor with -1, and it handles
7940 // the case where instcombine has trimmed non-demanded bits out
7941 // of an xor with -1.
7942 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7943 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7944 if (LBO->getOpcode() == Instruction::And &&
7945 LCI->getValue() == CI->getValue())
7946 if (const SCEVZeroExtendExpr *Z =
7947 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7948 Type *UTy = BO->LHS->getType();
7949 const SCEV *Z0 = Z->getOperand();
7950 Type *Z0Ty = Z0->getType();
7951 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7952
7953 // If C is a low-bits mask, the zero extend is serving to
7954 // mask off the high bits. Complement the operand and
7955 // re-apply the zext.
7956 if (CI->getValue().isMask(Z0TySize))
7957 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7958
7959 // If C is a single bit, it may be in the sign-bit position
7960 // before the zero-extend. In this case, represent the xor
7961 // using an add, which is equivalent, and re-apply the zext.
7962 APInt Trunc = CI->getValue().trunc(Z0TySize);
7963 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7964 Trunc.isSignMask())
7965 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7966 UTy);
7967 }
7968 }
7969 break;
7970
7971 case Instruction::Shl:
7972 // Turn shift left of a constant amount into a multiply.
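      // For example, "shl i8 %x, 3" becomes "%x * 8". Note that nsw on its
      // own cannot survive a shift by bitwidth - 1: "shl nsw i8 -1, 7" yields
      // -128 without signed wrap, but the corresponding multiply by the i8
      // constant 0x80 (i.e. -128) would overflow, which is why the code below
      // only keeps nsw for smaller shift amounts (or when nuw is also
      // present).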
7973 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7974 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7975
7976 // If the shift count is not less than the bitwidth, the result of
7977 // the shift is undefined. Don't try to analyze it, because the
7978 // resolution chosen here may differ from the resolution chosen in
7979 // other parts of the compiler.
7980 if (SA->getValue().uge(BitWidth))
7981 break;
7982
7983 // We can safely preserve the nuw flag in all cases. It's also safe to
7984 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7985 // requires special handling. It can be preserved as long as we're not
7986 // left shifting by bitwidth - 1.
7987 auto Flags = SCEV::FlagAnyWrap;
7988 if (BO->Op) {
7989 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7990 if ((MulFlags & SCEV::FlagNSW) &&
7991            ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7992          Flags = setFlags(Flags, SCEV::FlagNSW);
7993        if (MulFlags & SCEV::FlagNUW)
7994          Flags = setFlags(Flags, SCEV::FlagNUW);
7995      }
7996
7997 ConstantInt *X = ConstantInt::get(
7998 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7999 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8000 }
8001 break;
8002
8003 case Instruction::AShr:
8004 // AShr X, C, where C is a constant.
8005 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8006 if (!CI)
8007 break;
8008
8009 Type *OuterTy = BO->LHS->getType();
8010 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8011 // If the shift count is not less than the bitwidth, the result of
8012 // the shift is undefined. Don't try to analyze it, because the
8013 // resolution chosen here may differ from the resolution chosen in
8014 // other parts of the compiler.
8015 if (CI->getValue().uge(BitWidth))
8016 break;
8017
8018 if (CI->isZero())
8019 return getSCEV(BO->LHS); // shift by zero --> noop
8020
8021 uint64_t AShrAmt = CI->getZExtValue();
8022 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8023
8024 Operator *L = dyn_cast<Operator>(BO->LHS);
8025 const SCEV *AddTruncateExpr = nullptr;
8026 ConstantInt *ShlAmtCI = nullptr;
8027 const SCEV *AddConstant = nullptr;
8028
8029 if (L && L->getOpcode() == Instruction::Add) {
8030 // X = Shl A, n
8031 // Y = Add X, c
8032 // Z = AShr Y, m
8033 // n, c and m are constants.
8034
8035 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8036 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8037 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8038 if (AddOperandCI) {
8039 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8040 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8041          // Since we truncate to TruncTy, the AddConstant should be of the
8042          // same type, so create a new constant with the same type as TruncTy.
8043          // Also, the add constant should be shifted right by the AShr amount.
8044 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8045 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8046          // We model the expression as sext(add(trunc(A), c << n)). Since the
8047          // sext(trunc) part is already handled below, we create an
8048          // AddExpr(TruncExp) which will be used later.
8049 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8050 }
8051 }
8052 } else if (L && L->getOpcode() == Instruction::Shl) {
8053 // X = Shl A, n
8054 // Y = AShr X, m
8055 // Both n and m are constant.
8056
8057 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8058 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8059 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8060 }
8061
8062 if (AddTruncateExpr && ShlAmtCI) {
8063      // We can merge the two given cases into a single SCEV statement;
8064      // in case n = m, the mul expression will be 2^0, so it gets resolved to
8065      // a simpler case. The following code handles the two cases:
8066 //
8067 // 1) For a two-shift sext-inreg, i.e. n = m,
8068 // use sext(trunc(x)) as the SCEV expression.
8069 //
8070 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8071 // expression. We already checked that ShlAmt < BitWidth, so
8072 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8073      //    ShlAmt - AShrAmt < BitWidth - AShrAmt.
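      // For instance, with i32 operands, n = m = 24 yields
      // (sext i8 (trunc i32 A to i8) to i32), while n = 24, m = 16 yields
      // (sext i16 (256 * (trunc i32 A to i16)) to i32); any add constant c is
      // folded in as (c ashr m) truncated to the same narrow type.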
8074 const APInt &ShlAmt = ShlAmtCI->getValue();
8075 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8076 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8077 ShlAmtCI->getZExtValue() - AShrAmt);
8078 const SCEV *CompositeExpr =
8079 getMulExpr(AddTruncateExpr, getConstant(Mul));
8080 if (L->getOpcode() != Instruction::Shl)
8081 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8082
8083 return getSignExtendExpr(CompositeExpr, OuterTy);
8084 }
8085 }
8086 break;
8087 }
8088 }
8089
8090 switch (U->getOpcode()) {
8091 case Instruction::Trunc:
8092 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8093
8094 case Instruction::ZExt:
8095 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8096
8097 case Instruction::SExt:
8098    if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8099                                dyn_cast<Instruction>(V))) {
8100 // The NSW flag of a subtract does not always survive the conversion to
8101 // A + (-1)*B. By pushing sign extension onto its operands we are much
8102 // more likely to preserve NSW and allow later AddRec optimisations.
8103 //
8104 // NOTE: This is effectively duplicating this logic from getSignExtend:
8105 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8106 // but by that point the NSW information has potentially been lost.
8107 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8108 Type *Ty = U->getType();
8109 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8110 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8111 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8112 }
8113 }
8114 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8115
8116 case Instruction::BitCast:
8117 // BitCasts are no-op casts so we just eliminate the cast.
8118 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8119 return getSCEV(U->getOperand(0));
8120 break;
8121
8122 case Instruction::PtrToInt: {
8123 // Pointer to integer cast is straight-forward, so do model it.
8124 const SCEV *Op = getSCEV(U->getOperand(0));
8125 Type *DstIntTy = U->getType();
8126 // But only if effective SCEV (integer) type is wide enough to represent
8127 // all possible pointer values.
8128 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8129 if (isa<SCEVCouldNotCompute>(IntOp))
8130 return getUnknown(V);
8131 return IntOp;
8132 }
8133 case Instruction::IntToPtr:
8134 // Just don't deal with inttoptr casts.
8135 return getUnknown(V);
8136
8137 case Instruction::SDiv:
8138 // If both operands are non-negative, this is just an udiv.
8139 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8140 isKnownNonNegative(getSCEV(U->getOperand(1))))
8141 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8142 break;
8143
8144 case Instruction::SRem:
8145 // If both operands are non-negative, this is just an urem.
8146 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8147 isKnownNonNegative(getSCEV(U->getOperand(1))))
8148 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8149 break;
8150
8151 case Instruction::GetElementPtr:
8152 return createNodeForGEP(cast<GEPOperator>(U));
8153
8154 case Instruction::PHI:
8155 return createNodeForPHI(cast<PHINode>(U));
8156
8157 case Instruction::Select:
8158 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8159 U->getOperand(2));
8160
8161 case Instruction::Call:
8162 case Instruction::Invoke:
8163 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8164 return getSCEV(RV);
8165
8166 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8167 switch (II->getIntrinsicID()) {
8168 case Intrinsic::abs:
8169 return getAbsExpr(
8170 getSCEV(II->getArgOperand(0)),
8171 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8172 case Intrinsic::umax:
8173 LHS = getSCEV(II->getArgOperand(0));
8174 RHS = getSCEV(II->getArgOperand(1));
8175 return getUMaxExpr(LHS, RHS);
8176 case Intrinsic::umin:
8177 LHS = getSCEV(II->getArgOperand(0));
8178 RHS = getSCEV(II->getArgOperand(1));
8179 return getUMinExpr(LHS, RHS);
8180 case Intrinsic::smax:
8181 LHS = getSCEV(II->getArgOperand(0));
8182 RHS = getSCEV(II->getArgOperand(1));
8183 return getSMaxExpr(LHS, RHS);
8184 case Intrinsic::smin:
8185 LHS = getSCEV(II->getArgOperand(0));
8186 RHS = getSCEV(II->getArgOperand(1));
8187 return getSMinExpr(LHS, RHS);
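      // The saturating intrinsics below are modeled with umin:
      // usub_sat(X, Y) == X - umin(X, Y) and uadd_sat(X, Y) == umin(X, ~Y) + Y.
      // For example, in i8, usub_sat(3, 5) = 3 - umin(3, 5) = 0 and
      // uadd_sat(250, 10) = umin(250, 245) + 10 = 255.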
8188 case Intrinsic::usub_sat: {
8189 const SCEV *X = getSCEV(II->getArgOperand(0));
8190 const SCEV *Y = getSCEV(II->getArgOperand(1));
8191 const SCEV *ClampedY = getUMinExpr(X, Y);
8192 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8193 }
8194 case Intrinsic::uadd_sat: {
8195 const SCEV *X = getSCEV(II->getArgOperand(0));
8196 const SCEV *Y = getSCEV(II->getArgOperand(1));
8197 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8198 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8199 }
8200 case Intrinsic::start_loop_iterations:
8201 case Intrinsic::annotation:
8202 case Intrinsic::ptr_annotation:
8203      // A start_loop_iterations, llvm.annotation or llvm.ptr.annotation call is
8204      // just equivalent to its first operand for SCEV purposes.
8205 return getSCEV(II->getArgOperand(0));
8206 case Intrinsic::vscale:
8207 return getVScale(II->getType());
8208 default:
8209 break;
8210 }
8211 }
8212 break;
8213 }
8214
8215 return getUnknown(V);
8216}
8217
8218//===----------------------------------------------------------------------===//
8219// Iteration Count Computation Code
8220//
8221
8222const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8223 if (isa<SCEVCouldNotCompute>(ExitCount))
8224 return getCouldNotCompute();
8225
8226 auto *ExitCountType = ExitCount->getType();
8227 assert(ExitCountType->isIntegerTy());
8228 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8229 1 + ExitCountType->getScalarSizeInBits());
8230 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8231}
8232
8233const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8234 Type *EvalTy,
8235 const Loop *L) {
8236 if (isa<SCEVCouldNotCompute>(ExitCount))
8237 return getCouldNotCompute();
8238
8239 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8240 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8241
8242 auto CanAddOneWithoutOverflow = [&]() {
8243 ConstantRange ExitCountRange =
8244 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8245 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8246 return true;
8247
8248 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8249 getMinusOne(ExitCount->getType()));
8250 };
8251
8252 // If we need to zero extend the backedge count, check if we can add one to
8253 // it prior to zero extending without overflow. Provided this is safe, it
8254 // allows better simplification of the +1.
8255 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8256 return getZeroExtendExpr(
8257 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8258
8259 // Get the total trip count from the count by adding 1. This may wrap.
8260 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8261}
8262
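// For example, a constant exit count of 41 corresponds to a trip count of 42,
// and an i8 exit count of 255 needs the i9 evaluation type chosen above to
// hold the trip count 256. Here, an exit count of 0xffffffff has exactly 32
// active bits, so it passes the guard below but wraps to 0 when incremented;
// 0 is how callers spell "unknown or too large".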
8263static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8264 if (!ExitCount)
8265 return 0;
8266
8267 ConstantInt *ExitConst = ExitCount->getValue();
8268
8269 // Guard against huge trip counts.
8270 if (ExitConst->getValue().getActiveBits() > 32)
8271 return 0;
8272
8273 // In case of integer overflow, this returns 0, which is correct.
8274 return ((unsigned)ExitConst->getZExtValue()) + 1;
8275}
8276
8278 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8279 return getConstantTripCount(ExitCount);
8280}
8281
8282unsigned
8284 const BasicBlock *ExitingBlock) {
8285 assert(ExitingBlock && "Must pass a non-null exiting block!");
8286 assert(L->isLoopExiting(ExitingBlock) &&
8287 "Exiting block must actually branch out of the loop!");
8288 const SCEVConstant *ExitCount =
8289 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8290 return getConstantTripCount(ExitCount);
8291}
8292
8294 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8295
8296 const auto *MaxExitCount =
8297 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8299 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8300}
8301
8303 SmallVector<BasicBlock *, 8> ExitingBlocks;
8304 L->getExitingBlocks(ExitingBlocks);
8305
8306 std::optional<unsigned> Res;
8307 for (auto *ExitingBB : ExitingBlocks) {
8308 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8309 if (!Res)
8310 Res = Multiple;
8311 Res = std::gcd(*Res, Multiple);
8312 }
8313 return Res.value_or(1);
8314}
8315
8317 const SCEV *ExitCount) {
8318 if (isa<SCEVCouldNotCompute>(ExitCount))
8319 return 1;
8320
8321 // Get the trip count
8322 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8323
8324 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8325 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8326 // the greatest power of 2 divisor less than 2^32.
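  // For example, a constant multiple of 12 is returned as 12, while a
  // multiple of 2^40 is reported as 2^31, the largest power of two below
  // 2^32 that still divides the trip count.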
8327 return Multiple.getActiveBits() > 32
8328 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8329 : (unsigned)Multiple.getZExtValue();
8330}
8331
8332/// Returns the largest constant divisor of the trip count of this loop as a
8333/// normal unsigned value, if possible. This means that the actual trip count is
8334/// always a multiple of the returned value (don't forget the trip count could
8335/// very well be zero as well!).
8336///
8337/// Returns 1 if the trip count is unknown or not guaranteed to be a
8338/// multiple of a constant (which is also the case if the trip count is simply
8339/// constant; use getSmallConstantTripCount for that case). It will also return
8340/// 1 if the trip count is very large (>= 2^32).
8341///
8342/// As explained in the comments for getSmallConstantTripCount, this assumes
8343/// that control exits the loop via ExitingBlock.
8344unsigned
8346 const BasicBlock *ExitingBlock) {
8347 assert(ExitingBlock && "Must pass a non-null exiting block!");
8348 assert(L->isLoopExiting(ExitingBlock) &&
8349 "Exiting block must actually branch out of the loop!");
8350 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8351 return getSmallConstantTripMultiple(L, ExitCount);
8352}
8353
8355 const BasicBlock *ExitingBlock,
8356 ExitCountKind Kind) {
8357 switch (Kind) {
8358 case Exact:
8359 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8360 case SymbolicMaximum:
8361 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8362 case ConstantMaximum:
8363 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8364 };
8365 llvm_unreachable("Invalid ExitCountKind!");
8366}
8367
8369 const Loop *L, const BasicBlock *ExitingBlock,
8371 switch (Kind) {
8372 case Exact:
8373 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8374 Predicates);
8375 case SymbolicMaximum:
8376 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8377 Predicates);
8378 case ConstantMaximum:
8379 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8380 Predicates);
8381 };
8382 llvm_unreachable("Invalid ExitCountKind!");
8383}
8384
8387 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8388}
8389
8391 ExitCountKind Kind) {
8392 switch (Kind) {
8393 case Exact:
8394 return getBackedgeTakenInfo(L).getExact(L, this);
8395 case ConstantMaximum:
8396 return getBackedgeTakenInfo(L).getConstantMax(this);
8397 case SymbolicMaximum:
8398 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8399 };
8400 llvm_unreachable("Invalid ExitCountKind!");
8401}
8402
8405 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8406}
8407
8410 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8411}
8412
8414 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8415}
8416
8417/// Push PHI nodes in the header of the given loop onto the given Worklist.
8418static void PushLoopPHIs(const Loop *L,
8421 BasicBlock *Header = L->getHeader();
8422
8423 // Push all Loop-header PHIs onto the Worklist stack.
8424 for (PHINode &PN : Header->phis())
8425 if (Visited.insert(&PN).second)
8426 Worklist.push_back(&PN);
8427}
8428
8429ScalarEvolution::BackedgeTakenInfo &
8430ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8431 auto &BTI = getBackedgeTakenInfo(L);
8432 if (BTI.hasFullInfo())
8433 return BTI;
8434
8435 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8436
8437 if (!Pair.second)
8438 return Pair.first->second;
8439
8440 BackedgeTakenInfo Result =
8441 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8442
8443 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8444}
8445
8446ScalarEvolution::BackedgeTakenInfo &
8447ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8448 // Initially insert an invalid entry for this loop. If the insertion
8449 // succeeds, proceed to actually compute a backedge-taken count and
8450 // update the value. The temporary CouldNotCompute value tells SCEV
8451 // code elsewhere that it shouldn't attempt to request a new
8452 // backedge-taken count, which could result in infinite recursion.
8453 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8454 BackedgeTakenCounts.try_emplace(L);
8455 if (!Pair.second)
8456 return Pair.first->second;
8457
8458 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8459 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8460 // must be cleared in this scope.
8461 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8462
8463 // Now that we know more about the trip count for this loop, forget any
8464 // existing SCEV values for PHI nodes in this loop since they are only
8465 // conservative estimates made without the benefit of trip count
8466 // information. This invalidation is not necessary for correctness, and is
8467 // only done to produce more precise results.
8468 if (Result.hasAnyInfo()) {
8469 // Invalidate any expression using an addrec in this loop.
8471 auto LoopUsersIt = LoopUsers.find(L);
8472 if (LoopUsersIt != LoopUsers.end())
8473 append_range(ToForget, LoopUsersIt->second);
8474 forgetMemoizedResults(ToForget);
8475
8476 // Invalidate constant-evolved loop header phis.
8477 for (PHINode &PN : L->getHeader()->phis())
8478 ConstantEvolutionLoopExitValue.erase(&PN);
8479 }
8480
8481 // Re-lookup the insert position, since the call to
8482 // computeBackedgeTakenCount above could result in a
8483  // recursive call to getBackedgeTakenInfo (on a different
8484 // loop), which would invalidate the iterator computed
8485 // earlier.
8486 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8487}
8488
8490 // This method is intended to forget all info about loops. It should
8491 // invalidate caches as if the following happened:
8492 // - The trip counts of all loops have changed arbitrarily
8493 // - Every llvm::Value has been updated in place to produce a different
8494 // result.
8495 BackedgeTakenCounts.clear();
8496 PredicatedBackedgeTakenCounts.clear();
8497 BECountUsers.clear();
8498 LoopPropertiesCache.clear();
8499 ConstantEvolutionLoopExitValue.clear();
8500 ValueExprMap.clear();
8501 ValuesAtScopes.clear();
8502 ValuesAtScopesUsers.clear();
8503 LoopDispositions.clear();
8504 BlockDispositions.clear();
8505 UnsignedRanges.clear();
8506 SignedRanges.clear();
8507 ExprValueMap.clear();
8508 HasRecMap.clear();
8509 ConstantMultipleCache.clear();
8510 PredicatedSCEVRewrites.clear();
8511 FoldCache.clear();
8512 FoldCacheUser.clear();
8513}
8514void ScalarEvolution::visitAndClearUsers(
8518 while (!Worklist.empty()) {
8519 Instruction *I = Worklist.pop_back_val();
8520 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8521 continue;
8522
8524 ValueExprMap.find_as(static_cast<Value *>(I));
8525 if (It != ValueExprMap.end()) {
8526 eraseValueFromMap(It->first);
8527 ToForget.push_back(It->second);
8528 if (PHINode *PN = dyn_cast<PHINode>(I))
8529 ConstantEvolutionLoopExitValue.erase(PN);
8530 }
8531
8532 PushDefUseChildren(I, Worklist, Visited);
8533 }
8534}
8535
8537 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8541
8542 // Iterate over all the loops and sub-loops to drop SCEV information.
8543 while (!LoopWorklist.empty()) {
8544 auto *CurrL = LoopWorklist.pop_back_val();
8545
8546 // Drop any stored trip count value.
8547 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8548 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8549
8550 // Drop information about predicated SCEV rewrites for this loop.
8551 for (auto I = PredicatedSCEVRewrites.begin();
8552 I != PredicatedSCEVRewrites.end();) {
8553 std::pair<const SCEV *, const Loop *> Entry = I->first;
8554 if (Entry.second == CurrL)
8555 PredicatedSCEVRewrites.erase(I++);
8556 else
8557 ++I;
8558 }
8559
8560 auto LoopUsersItr = LoopUsers.find(CurrL);
8561 if (LoopUsersItr != LoopUsers.end())
8562 llvm::append_range(ToForget, LoopUsersItr->second);
8563
8564 // Drop information about expressions based on loop-header PHIs.
8565 PushLoopPHIs(CurrL, Worklist, Visited);
8566 visitAndClearUsers(Worklist, Visited, ToForget);
8567
8568 LoopPropertiesCache.erase(CurrL);
8569 // Forget all contained loops too, to avoid dangling entries in the
8570 // ValuesAtScopes map.
8571 LoopWorklist.append(CurrL->begin(), CurrL->end());
8572 }
8573 forgetMemoizedResults(ToForget);
8574}
8575
8577 forgetLoop(L->getOutermostLoop());
8578}
8579
8582 if (!I) return;
8583
8584 // Drop information about expressions based on loop-header PHIs.
8588 Worklist.push_back(I);
8589 Visited.insert(I);
8590 visitAndClearUsers(Worklist, Visited, ToForget);
8591
8592 forgetMemoizedResults(ToForget);
8593}
8594
8596 if (!isSCEVable(V->getType()))
8597 return;
8598
8599 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8600 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8601 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8602 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8603 if (const SCEV *S = getExistingSCEV(V)) {
8604 struct InvalidationRootCollector {
8605 Loop *L;
8607
8608 InvalidationRootCollector(Loop *L) : L(L) {}
8609
8610 bool follow(const SCEV *S) {
8611 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8612 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8613 if (L->contains(I))
8614 Roots.push_back(S);
8615 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8616 if (L->contains(AddRec->getLoop()))
8617 Roots.push_back(S);
8618 }
8619 return true;
8620 }
8621 bool isDone() const { return false; }
8622 };
8623
8624 InvalidationRootCollector C(L);
8625 visitAll(S, C);
8626 forgetMemoizedResults(C.Roots);
8627 }
8628
8629 // Also perform the normal invalidation.
8630 forgetValue(V);
8631}
8632
8633void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8634
8636 // Unless a specific value is passed to invalidation, completely clear both
8637 // caches.
8638 if (!V) {
8639 BlockDispositions.clear();
8640 LoopDispositions.clear();
8641 return;
8642 }
8643
8644 if (!isSCEVable(V->getType()))
8645 return;
8646
8647 const SCEV *S = getExistingSCEV(V);
8648 if (!S)
8649 return;
8650
8651 // Invalidate the block and loop dispositions cached for S. Dispositions of
8652 // S's users may change if S's disposition changes (i.e. a user may change to
8653 // loop-invariant, if S changes to loop invariant), so also invalidate
8654 // dispositions of S's users recursively.
8655 SmallVector<const SCEV *, 8> Worklist = {S};
8657 while (!Worklist.empty()) {
8658 const SCEV *Curr = Worklist.pop_back_val();
8659 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8660 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8661 if (!LoopDispoRemoved && !BlockDispoRemoved)
8662 continue;
8663 auto Users = SCEVUsers.find(Curr);
8664 if (Users != SCEVUsers.end())
8665 for (const auto *User : Users->second)
8666 if (Seen.insert(User).second)
8667 Worklist.push_back(User);
8668 }
8669}
8670
8671/// Get the exact loop backedge taken count considering all loop exits. A
8672/// computable result can only be returned for loops with all exiting blocks
8673/// dominating the latch. howFarToZero assumes that the limit of each loop test
8674/// is never skipped. This is a valid assumption as long as the loop exits via
8675/// that test. For precise results, it is the caller's responsibility to specify
8676/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8677const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8678 const Loop *L, ScalarEvolution *SE,
8680 // If any exits were not computable, the loop is not computable.
8681 if (!isComplete() || ExitNotTaken.empty())
8682 return SE->getCouldNotCompute();
8683
8684 const BasicBlock *Latch = L->getLoopLatch();
8685 // All exiting blocks we have collected must dominate the only backedge.
8686 if (!Latch)
8687 return SE->getCouldNotCompute();
8688
8689  // All exiting blocks we have gathered dominate the loop's latch, so the
8690  // exact trip count is simply the minimum of all these calculated exit counts.
8692 for (const auto &ENT : ExitNotTaken) {
8693 const SCEV *BECount = ENT.ExactNotTaken;
8694 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8695 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8696 "We should only have known counts for exiting blocks that dominate "
8697 "latch!");
8698
8699 Ops.push_back(BECount);
8700
8701 if (Preds)
8702 append_range(*Preds, ENT.Predicates);
8703
8704 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8705 "Predicate should be always true!");
8706 }
8707
8708 // If an earlier exit exits on the first iteration (exit count zero), then
8709  // a later poison exit count should not propagate into the result. These are
8710 // exactly the semantics provided by umin_seq.
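  // For example, if one exit is taken on the very first iteration while a
  // later exit's count is poison on that path, umin_seq yields 0 rather than
  // poison.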
8711 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8712}
8713
8714const ScalarEvolution::ExitNotTakenInfo *
8715ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8716 const BasicBlock *ExitingBlock,
8717 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8718 for (const auto &ENT : ExitNotTaken)
8719 if (ENT.ExitingBlock == ExitingBlock) {
8720 if (ENT.hasAlwaysTruePredicate())
8721 return &ENT;
8722 else if (Predicates) {
8723 append_range(*Predicates, ENT.Predicates);
8724 return &ENT;
8725 }
8726 }
8727
8728 return nullptr;
8729}
8730
8731/// getConstantMax - Get the constant max backedge taken count for the loop.
8732const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8733 ScalarEvolution *SE,
8734 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8735 if (!getConstantMax())
8736 return SE->getCouldNotCompute();
8737
8738 for (const auto &ENT : ExitNotTaken)
8739 if (!ENT.hasAlwaysTruePredicate()) {
8740 if (!Predicates)
8741 return SE->getCouldNotCompute();
8742 append_range(*Predicates, ENT.Predicates);
8743 }
8744
8745 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8746 isa<SCEVConstant>(getConstantMax())) &&
8747 "No point in having a non-constant max backedge taken count!");
8748 return getConstantMax();
8749}
8750
8751const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8752 const Loop *L, ScalarEvolution *SE,
8753 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8754 if (!SymbolicMax) {
8755 // Form an expression for the maximum exit count possible for this loop. We
8756 // merge the max and exact information to approximate a version of
8757 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8758 // constants.
8760
8761 for (const auto &ENT : ExitNotTaken) {
8762 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8763 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8764 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8765 "We should only have known counts for exiting blocks that "
8766 "dominate latch!");
8767 ExitCounts.push_back(ExitCount);
8768 if (Predicates)
8769 append_range(*Predicates, ENT.Predicates);
8770
8771 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8772 "Predicate should be always true!");
8773 }
8774 }
8775 if (ExitCounts.empty())
8776 SymbolicMax = SE->getCouldNotCompute();
8777 else
8778 SymbolicMax =
8779 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8780 }
8781 return SymbolicMax;
8782}
8783
8784bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8785 ScalarEvolution *SE) const {
8786 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8787 return !ENT.hasAlwaysTruePredicate();
8788 };
8789 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8790}
8791
8794
8796 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8797 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8801 // If we prove the max count is zero, so is the symbolic bound. This happens
8802 // in practice due to differences in a) how context sensitive we've chosen
8803 // to be and b) how we reason about bounds implied by UB.
8804 if (ConstantMaxNotTaken->isZero()) {
8805 this->ExactNotTaken = E = ConstantMaxNotTaken;
8806 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8807 }
8808
8811 "Exact is not allowed to be less precise than Constant Max");
8814 "Exact is not allowed to be less precise than Symbolic Max");
8817 "Symbolic Max is not allowed to be less precise than Constant Max");
8820 "No point in having a non-constant max backedge taken count!");
8822 for (const auto PredList : PredLists)
8823 for (const auto *P : PredList) {
8824 if (SeenPreds.contains(P))
8825 continue;
8826 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8827 SeenPreds.insert(P);
8828 Predicates.push_back(P);
8829 }
8830 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8831 "Backedge count should be int");
8833 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8834 "Max backedge count should be int");
8835}
8836
8844
8845/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8846/// computable exit into a persistent ExitNotTakenInfo array.
8847ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8849 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8850 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8851 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8852
8853 ExitNotTaken.reserve(ExitCounts.size());
8854 std::transform(ExitCounts.begin(), ExitCounts.end(),
8855 std::back_inserter(ExitNotTaken),
8856 [&](const EdgeExitInfo &EEI) {
8857 BasicBlock *ExitBB = EEI.first;
8858 const ExitLimit &EL = EEI.second;
8859 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8860 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8861 EL.Predicates);
8862 });
8863 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8864 isa<SCEVConstant>(ConstantMax)) &&
8865 "No point in having a non-constant max backedge taken count!");
8866}
8867
8868/// Compute the number of times the backedge of the specified loop will execute.
8869ScalarEvolution::BackedgeTakenInfo
8870ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8871 bool AllowPredicates) {
8872 SmallVector<BasicBlock *, 8> ExitingBlocks;
8873 L->getExitingBlocks(ExitingBlocks);
8874
8875 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8876
8878 bool CouldComputeBECount = true;
8879 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8880 const SCEV *MustExitMaxBECount = nullptr;
8881 const SCEV *MayExitMaxBECount = nullptr;
8882 bool MustExitMaxOrZero = false;
8883 bool IsOnlyExit = ExitingBlocks.size() == 1;
8884
8885 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8886 // and compute maxBECount.
8887 // Do a union of all the predicates here.
8888 for (BasicBlock *ExitBB : ExitingBlocks) {
8889    // We canonicalize untaken exits to br (constant); ignore them so that
8890    // proving an exit untaken doesn't negatively impact our ability to reason
8891    // about the loop as a whole.
8892 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8893 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8894 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8895 if (ExitIfTrue == CI->isZero())
8896 continue;
8897 }
8898
8899 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8900
8901 assert((AllowPredicates || EL.Predicates.empty()) &&
8902 "Predicated exit limit when predicates are not allowed!");
8903
8904 // 1. For each exit that can be computed, add an entry to ExitCounts.
8905 // CouldComputeBECount is true only if all exits can be computed.
8906 if (EL.ExactNotTaken != getCouldNotCompute())
8907 ++NumExitCountsComputed;
8908 else
8909 // We couldn't compute an exact value for this exit, so
8910 // we won't be able to compute an exact value for the loop.
8911 CouldComputeBECount = false;
8912 // Remember exit count if either exact or symbolic is known. Because
8913 // Exact always implies symbolic, only check symbolic.
8914 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8915 ExitCounts.emplace_back(ExitBB, EL);
8916 else {
8917 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8918 "Exact is known but symbolic isn't?");
8919 ++NumExitCountsNotComputed;
8920 }
8921
8922 // 2. Derive the loop's MaxBECount from each exit's max number of
8923 // non-exiting iterations. Partition the loop exits into two kinds:
8924 // LoopMustExits and LoopMayExits.
8925 //
8926 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8927 // is a LoopMayExit. If any computable LoopMustExit is found, then
8928 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8929 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8930 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8931 // any
8932 // computable EL.ConstantMaxNotTaken.
8933 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8934 DT.dominates(ExitBB, Latch)) {
8935 if (!MustExitMaxBECount) {
8936 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8937 MustExitMaxOrZero = EL.MaxOrZero;
8938 } else {
8939 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8940 EL.ConstantMaxNotTaken);
8941 }
8942 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8943 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8944 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8945 else {
8946 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8947 EL.ConstantMaxNotTaken);
8948 }
8949 }
8950 }
8951 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8952 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8953 // The loop backedge will be taken the maximum or zero times if there's
8954 // a single exit that must be taken the maximum or zero times.
8955 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8956
8957 // Remember which SCEVs are used in exit limits for invalidation purposes.
8958 // We only care about non-constant SCEVs here, so we can ignore
8959 // EL.ConstantMaxNotTaken
8960 // and MaxBECount, which must be SCEVConstant.
8961 for (const auto &Pair : ExitCounts) {
8962 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8963 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8964 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8965 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8966 {L, AllowPredicates});
8967 }
8968 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8969 MaxBECount, MaxOrZero);
8970}
8971
8972ScalarEvolution::ExitLimit
8973ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8974 bool IsOnlyExit, bool AllowPredicates) {
8975 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8976 // If our exiting block does not dominate the latch, then its connection with
8977  // the loop's exit limit may be far from trivial.
8978 const BasicBlock *Latch = L->getLoopLatch();
8979 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8980 return getCouldNotCompute();
8981
8982 Instruction *Term = ExitingBlock->getTerminator();
8983 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8984 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8985 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8986 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8987 "It should have one successor in loop and one exit block!");
8988 // Proceed to the next level to examine the exit condition expression.
8989 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8990 /*ControlsOnlyExit=*/IsOnlyExit,
8991 AllowPredicates);
8992 }
8993
8994 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8995 // For switch, make sure that there is a single exit from the loop.
8996 BasicBlock *Exit = nullptr;
8997 for (auto *SBB : successors(ExitingBlock))
8998 if (!L->contains(SBB)) {
8999 if (Exit) // Multiple exit successors.
9000 return getCouldNotCompute();
9001 Exit = SBB;
9002 }
9003 assert(Exit && "Exiting block must have at least one exit");
9004 return computeExitLimitFromSingleExitSwitch(
9005 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9006 }
9007
9008 return getCouldNotCompute();
9009}
9010
9012 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9013 bool AllowPredicates) {
9014 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9015 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9016 ControlsOnlyExit, AllowPredicates);
9017}
9018
9019std::optional<ScalarEvolution::ExitLimit>
9020ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9021 bool ExitIfTrue, bool ControlsOnlyExit,
9022 bool AllowPredicates) {
9023 (void)this->L;
9024 (void)this->ExitIfTrue;
9025 (void)this->AllowPredicates;
9026
9027 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9028 this->AllowPredicates == AllowPredicates &&
9029 "Variance in assumed invariant key components!");
9030 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9031 if (Itr == TripCountMap.end())
9032 return std::nullopt;
9033 return Itr->second;
9034}
9035
9036void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9037 bool ExitIfTrue,
9038 bool ControlsOnlyExit,
9039 bool AllowPredicates,
9040 const ExitLimit &EL) {
9041 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9042 this->AllowPredicates == AllowPredicates &&
9043 "Variance in assumed invariant key components!");
9044
9045 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9046 assert(InsertResult.second && "Expected successful insertion!");
9047 (void)InsertResult;
9048 (void)ExitIfTrue;
9049}
9050
9051ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9052 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9053 bool ControlsOnlyExit, bool AllowPredicates) {
9054
9055 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9056 AllowPredicates))
9057 return *MaybeEL;
9058
9059 ExitLimit EL = computeExitLimitFromCondImpl(
9060 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9061 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9062 return EL;
9063}
9064
9065ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9066 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9067 bool ControlsOnlyExit, bool AllowPredicates) {
9068 // Handle BinOp conditions (And, Or).
9069 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9070 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9071 return *LimitFromBinOp;
9072
9073 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9074 // Proceed to the next level to examine the icmp.
9075 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9076 ExitLimit EL =
9077 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9078 if (EL.hasFullInfo() || !AllowPredicates)
9079 return EL;
9080
9081 // Try again, but use SCEV predicates this time.
9082 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9083 ControlsOnlyExit,
9084 /*AllowPredicates=*/true);
9085 }
9086
9087 // Check for a constant condition. These are normally stripped out by
9088 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9089 // preserve the CFG and is temporarily leaving constant conditions
9090 // in place.
9091 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9092 if (ExitIfTrue == !CI->getZExtValue())
9093 // The backedge is always taken.
9094 return getCouldNotCompute();
9095 // The backedge is never taken.
9096 return getZero(CI->getType());
9097 }
9098
9099 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9100 // with a constant step, we can form an equivalent icmp predicate and figure
9101 // out how many iterations will be taken before we exit.
9102 const WithOverflowInst *WO;
9103 const APInt *C;
9104 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9105 match(WO->getRHS(), m_APInt(C))) {
9106 ConstantRange NWR =
9108 WO->getNoWrapKind());
9109 CmpInst::Predicate Pred;
9110 APInt NewRHSC, Offset;
9111 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9112 if (!ExitIfTrue)
9113 Pred = ICmpInst::getInversePredicate(Pred);
9114 auto *LHS = getSCEV(WO->getLHS());
9115    if (Offset != 0)
9116      LHS = getAddExpr(LHS, getConstant(Offset));
9117 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9118 ControlsOnlyExit, AllowPredicates);
9119 if (EL.hasAnyInfo())
9120 return EL;
9121 }
9122
9123 // If it's not an integer or pointer comparison then compute it the hard way.
9124 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9125}
9126
9127std::optional<ScalarEvolution::ExitLimit>
9128ScalarEvolution::computeExitLimitFromCondFromBinOp(
9129 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9130 bool ControlsOnlyExit, bool AllowPredicates) {
9131 // Check if the controlling expression for this loop is an And or Or.
9132 Value *Op0, *Op1;
9133 bool IsAnd = false;
9134 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9135 IsAnd = true;
9136 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9137 IsAnd = false;
9138 else
9139 return std::nullopt;
9140
9141 // EitherMayExit is true in these two cases:
9142 // br (and Op0 Op1), loop, exit
9143 // br (or Op0 Op1), exit, loop
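  // For example, for "br i1 (and (icmp ult %i, %n), (icmp ult %i, %m)),
  // label %loop, label %exit" the loop exits as soon as either compare fails,
  // so the exact not-taken count computed below is the umin (sequential for
  // the select form of and/or) of the two per-condition exit counts.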
9144 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9145 ExitLimit EL0 = computeExitLimitFromCondCached(
9146 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9147 AllowPredicates);
9148 ExitLimit EL1 = computeExitLimitFromCondCached(
9149 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9150 AllowPredicates);
9151
9152 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9153 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9154 if (isa<ConstantInt>(Op1))
9155 return Op1 == NeutralElement ? EL0 : EL1;
9156 if (isa<ConstantInt>(Op0))
9157 return Op0 == NeutralElement ? EL1 : EL0;
9158
9159 const SCEV *BECount = getCouldNotCompute();
9160 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9161 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9162 if (EitherMayExit) {
9163 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9164    // Both conditions must be the same for the loop to continue executing.
9165 // Choose the less conservative count.
9166 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9167 EL1.ExactNotTaken != getCouldNotCompute()) {
9168 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9169 UseSequentialUMin);
9170 }
9171 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9172 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9173 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9174 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9175 else
9176 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9177 EL1.ConstantMaxNotTaken);
9178 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9179 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9180 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9181 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9182 else
9183 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9184 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9185 } else {
9186    // Both conditions must be the same at the same time for the loop to exit.
9187 // For now, be conservative.
9188 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9189 BECount = EL0.ExactNotTaken;
9190 }
9191
9192 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9193 // to be more aggressive when computing BECount than when computing
9194 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9195 // and
9196 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9197 // EL1.ConstantMaxNotTaken to not.
9198 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9199 !isa<SCEVCouldNotCompute>(BECount))
9200 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9201 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9202 SymbolicMaxBECount =
9203 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9204 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9205 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9206}
9207
9208ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9209 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9210 bool AllowPredicates) {
9211 // If the condition was exit on true, convert the condition to exit on false
9212 CmpPredicate Pred;
9213 if (!ExitIfTrue)
9214 Pred = ExitCond->getCmpPredicate();
9215 else
9216 Pred = ExitCond->getInverseCmpPredicate();
9217 const ICmpInst::Predicate OriginalPred = Pred;
9218
9219 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9220 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9221
9222 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9223 AllowPredicates);
9224 if (EL.hasAnyInfo())
9225 return EL;
9226
9227 auto *ExhaustiveCount =
9228 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9229
9230 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9231 return ExhaustiveCount;
9232
9233 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9234 ExitCond->getOperand(1), L, OriginalPred);
9235}
9236ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9237 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9238 bool ControlsOnlyExit, bool AllowPredicates) {
9239
9240 // Try to evaluate any dependencies out of the loop.
9241 LHS = getSCEVAtScope(LHS, L);
9242 RHS = getSCEVAtScope(RHS, L);
9243
9244 // At this point, we would like to compute how many iterations of the
9245 // loop the predicate will return true for these inputs.
9246 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9247 // If there is a loop-invariant, force it into the RHS.
9248 std::swap(LHS, RHS);
9250 }
9251
9252 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9254 // Simplify the operands before analyzing them.
9255 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9256
9257 // If we have a comparison of a chrec against a constant, try to use value
9258 // ranges to answer this query.
9259 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9260 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9261 if (AddRec->getLoop() == L) {
9262 // Form the constant range.
9263 ConstantRange CompRange =
9264 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9265
9266 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9267 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9268 }
9269
9270 // If this loop must exit based on this condition (or execute undefined
9271 // behaviour), see if we can improve wrap flags. This is essentially
9272 // a must execute style proof.
9273 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9274 // If we can prove the test sequence produced must repeat the same values
9275 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9276 // because if it did, we'd have an infinite (undefined) loop.
9277 // TODO: We can peel off any functions which are invertible *in L*. Loop
9278 // invariant terms are effectively constants for our purposes here.
9279 auto *InnerLHS = LHS;
9280 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9281 InnerLHS = ZExt->getOperand();
9282 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9283 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9284 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9285 /*OrNegative=*/true)) {
9286 auto Flags = AR->getNoWrapFlags();
9287 Flags = setFlags(Flags, SCEV::FlagNW);
9290 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9291 }
9292
9293 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9294 // From no-self-wrap, this follows trivially from the fact that every
9295    // (un)signed-wrapped, but not self-wrapped value must be less than the
9296    // last value before the (un)signed wrap. Since we know that last value
9297    // didn't cause an exit, neither will any smaller one.
9298 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9299 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9300 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9301 AR && AR->getLoop() == L && AR->isAffine() &&
9302 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9303 isKnownPositive(AR->getStepRecurrence(*this))) {
9304 auto Flags = AR->getNoWrapFlags();
9305 Flags = setFlags(Flags, WrapType);
9308 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9309 }
9310 }
9311 }
9312
9313 switch (Pred) {
9314 case ICmpInst::ICMP_NE: { // while (X != Y)
9315 // Convert to: while (X-Y != 0)
9316    if (LHS->getType()->isPointerTy()) {
9317      LHS = getLosslessPtrToIntExpr(LHS);
9318      if (isa<SCEVCouldNotCompute>(LHS))
9319        return LHS;
9320    }
9321    if (RHS->getType()->isPointerTy()) {
9322      RHS = getLosslessPtrToIntExpr(RHS);
9323      if (isa<SCEVCouldNotCompute>(RHS))
9324        return RHS;
9325    }
9326 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9327 AllowPredicates);
9328 if (EL.hasAnyInfo())
9329 return EL;
9330 break;
9331 }
9332 case ICmpInst::ICMP_EQ: { // while (X == Y)
9333 // Convert to: while (X-Y == 0)
9334    if (LHS->getType()->isPointerTy()) {
9335      LHS = getLosslessPtrToIntExpr(LHS);
9336      if (isa<SCEVCouldNotCompute>(LHS))
9337        return LHS;
9338    }
9339    if (RHS->getType()->isPointerTy()) {
9340      RHS = getLosslessPtrToIntExpr(RHS);
9341      if (isa<SCEVCouldNotCompute>(RHS))
9342        return RHS;
9343    }
9344 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9345 if (EL.hasAnyInfo()) return EL;
9346 break;
9347 }
9348 case ICmpInst::ICMP_SLE:
9349 case ICmpInst::ICMP_ULE:
9350 // Since the loop is finite, an invariant RHS cannot include the boundary
9351 // value, otherwise it would loop forever.
9352 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9353 !isLoopInvariant(RHS, L)) {
9354 // Otherwise, perform the addition in a wider type, to avoid overflow.
9355 // If the LHS is an addrec with the appropriate nowrap flag, the
9356 // extension will be sunk into it and the exit count can be analyzed.
9357 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9358 if (!OldType)
9359 break;
9360 // Prefer doubling the bitwidth over adding a single bit to make it more
9361 // likely that we use a legal type.
9362 auto *NewType =
9363 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9364 if (ICmpInst::isSigned(Pred)) {
9365 LHS = getSignExtendExpr(LHS, NewType);
9366 RHS = getSignExtendExpr(RHS, NewType);
9367 } else {
9368 LHS = getZeroExtendExpr(LHS, NewType);
9369 RHS = getZeroExtendExpr(RHS, NewType);
9370 }
9371 }
9373 [[fallthrough]];
9374 case ICmpInst::ICMP_SLT:
9375 case ICmpInst::ICMP_ULT: { // while (X < Y)
9376 bool IsSigned = ICmpInst::isSigned(Pred);
9377 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9378 AllowPredicates);
9379 if (EL.hasAnyInfo())
9380 return EL;
9381 break;
9382 }
9383 case ICmpInst::ICMP_SGE:
9384 case ICmpInst::ICMP_UGE:
9385 // Since the loop is finite, an invariant RHS cannot include the boundary
9386 // value, otherwise it would loop forever.
9387 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9388 !isLoopInvariant(RHS, L))
9389 break;
9391 [[fallthrough]];
9392 case ICmpInst::ICMP_SGT:
9393 case ICmpInst::ICMP_UGT: { // while (X > Y)
9394 bool IsSigned = ICmpInst::isSigned(Pred);
9395 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9396 AllowPredicates);
9397 if (EL.hasAnyInfo())
9398 return EL;
9399 break;
9400 }
9401 default:
9402 break;
9403 }
9404
9405 return getCouldNotCompute();
9406}
9407
9408ScalarEvolution::ExitLimit
9409ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9410 SwitchInst *Switch,
9411 BasicBlock *ExitingBlock,
9412 bool ControlsOnlyExit) {
9413 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9414
9415 // Give up if the exit is the default dest of a switch.
9416 if (Switch->getDefaultDest() == ExitingBlock)
9417 return getCouldNotCompute();
9418
9419 assert(L->contains(Switch->getDefaultDest()) &&
9420 "Default case must not exit the loop!");
9421 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9422 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9423
9424 // while (X != Y) --> while (X-Y != 0)
9425 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9426 if (EL.hasAnyInfo())
9427 return EL;
9428
9429 return getCouldNotCompute();
9430}
9431
9432static ConstantInt *
9434 ScalarEvolution &SE) {
9435 const SCEV *InVal = SE.getConstant(C);
9436 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9438 "Evaluation of SCEV at constant didn't fold correctly?");
9439 return cast<SCEVConstant>(Val)->getValue();
9440}
9441
9442ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9443 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9444 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9445 if (!RHS)
9446 return getCouldNotCompute();
9447
9448 const BasicBlock *Latch = L->getLoopLatch();
9449 if (!Latch)
9450 return getCouldNotCompute();
9451
9452 const BasicBlock *Predecessor = L->getLoopPredecessor();
9453 if (!Predecessor)
9454 return getCouldNotCompute();
9455
9456 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9457  // Return LHS in OutLHS and the shift opcode in OutOpCode.
9458 auto MatchPositiveShift =
9459 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9460
9461 using namespace PatternMatch;
9462
9463 ConstantInt *ShiftAmt;
9464 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9465 OutOpCode = Instruction::LShr;
9466 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9467 OutOpCode = Instruction::AShr;
9468 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9469 OutOpCode = Instruction::Shl;
9470 else
9471 return false;
9472
9473 return ShiftAmt->getValue().isStrictlyPositive();
9474 };
9475
9476 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9477 //
9478 // loop:
9479 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9480 // %iv.shifted = lshr i32 %iv, <positive constant>
9481 //
9482 // Return true on a successful match. Return the corresponding PHI node (%iv
9483 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9484 auto MatchShiftRecurrence =
9485 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9486 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9487
9488      {
9489        Instruction::BinaryOps OpC;
9490        Value *V;
9491
9492 // If we encounter a shift instruction, "peel off" the shift operation,
9493 // and remember that we did so. Later when we inspect %iv's backedge
9494 // value, we will make sure that the backedge value uses the same
9495 // operation.
9496 //
9497 // Note: the peeled shift operation does not have to be the same
9498 // instruction as the one feeding into the PHI's backedge value. We only
9499 // really care about it being the same *kind* of shift instruction --
9500 // that's all that is required for our later inferences to hold.
9501 if (MatchPositiveShift(LHS, V, OpC)) {
9502 PostShiftOpCode = OpC;
9503 LHS = V;
9504 }
9505 }
9506
9507 PNOut = dyn_cast<PHINode>(LHS);
9508 if (!PNOut || PNOut->getParent() != L->getHeader())
9509 return false;
9510
9511 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9512 Value *OpLHS;
9513
9514 return
9515 // The backedge value for the PHI node must be a shift by a positive
9516 // amount
9517 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9518
9519 // of the PHI node itself
9520 OpLHS == PNOut &&
9521
9522 // and the kind of shift should match the kind of shift we peeled
9523 // off, if any.
9524 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9525 };
9526
9527 PHINode *PN;
9528 Instruction::BinaryOps OpCode;
9529 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9530 return getCouldNotCompute();
9531
9532 const DataLayout &DL = getDataLayout();
9533
9534 // The key rationale for this optimization is that for some kinds of shift
9535 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9536 // within a finite number of iterations. If the condition guarding the
9537 // backedge (in the sense that the backedge is taken if the condition is true)
9538 // is false for the value the shift recurrence stabilizes to, then we know
9539 // that the backedge is taken only a finite number of times.
9540
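// For illustration (values chosen here for exposition only): in a loop
// "while ((x = x >> 1) != 0)" over i32, the lshr recurrence stabilizes to 0
// within 32 iterations, and the guard "0 != 0" is false at that stable value,
// so the backedge can be taken at most 32 times.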
9541 ConstantInt *StableValue = nullptr;
9542 switch (OpCode) {
9543 default:
9544 llvm_unreachable("Impossible case!");
9545
9546 case Instruction::AShr: {
9547 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9548 // bitwidth(K) iterations.
9549 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9550 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9551 Predecessor->getTerminator(), &DT);
9552 auto *Ty = cast<IntegerType>(RHS->getType());
9553 if (Known.isNonNegative())
9554 StableValue = ConstantInt::get(Ty, 0);
9555 else if (Known.isNegative())
9556 StableValue = ConstantInt::get(Ty, -1, true);
9557 else
9558 return getCouldNotCompute();
9559
9560 break;
9561 }
9562 case Instruction::LShr:
9563 case Instruction::Shl:
9564 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9565 // stabilize to 0 in at most bitwidth(K) iterations.
9566 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9567 break;
9568 }
9569
9570 auto *Result =
9571 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9572 assert(Result->getType()->isIntegerTy(1) &&
9573 "Otherwise cannot be an operand to a branch instruction");
9574
9575 if (Result->isZeroValue()) {
9576 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9577 const SCEV *UpperBound =
9578 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9579 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9580 }
9581
9582 return getCouldNotCompute();
9583}
9584
9585/// Return true if we can constant fold an instruction of the specified type,
9586/// assuming that all operands were constants.
9587static bool CanConstantFold(const Instruction *I) {
9588 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
9589 isa<CastInst>(I) || isa<InsertValueInst>(I) || isa<ExtractValueInst>(I) ||
9590 isa<LoadInst>(I) || isa<GetElementPtrInst>(I))
9591 return true;
9592
9593 if (const CallInst *CI = dyn_cast<CallInst>(I))
9594 if (const Function *F = CI->getCalledFunction())
9595 return canConstantFoldCallTo(CI, F);
9596 return false;
9597}
9598
9599/// Determine whether this instruction can constant evolve within this loop
9600/// assuming its operands can all constant evolve.
9601static bool canConstantEvolve(Instruction *I, const Loop *L) {
9602 // An instruction outside of the loop can't be derived from a loop PHI.
9603 if (!L->contains(I)) return false;
9604
9605 if (isa<PHINode>(I)) {
9606 // We don't currently keep track of the control flow needed to evaluate
9607 // PHIs, so we cannot handle PHIs inside of loops.
9608 return L->getHeader() == I->getParent();
9609 }
9610
9611 // If we won't be able to constant fold this expression even if the operands
9612 // are constants, bail early.
9613 return CanConstantFold(I);
9614}
9615
9616/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9617/// recursing through each instruction operand until reaching a loop header phi.
9618static PHINode *
9619getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9620 DenseMap<Instruction *, PHINode *> &PHIMap,
9621 unsigned Depth) {
9622 if (Depth > MaxConstantEvolvingDepth)
9623 return nullptr;
9624
9625 // Otherwise, we can evaluate this instruction if all of its operands are
9626 // constant or derived from a PHI node themselves.
9627 PHINode *PHI = nullptr;
9628 for (Value *Op : UseInst->operands()) {
9629 if (isa<Constant>(Op)) continue;
9630
9631 Instruction *OpInst = dyn_cast<Instruction>(Op);
9632 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9633
9634 PHINode *P = dyn_cast<PHINode>(OpInst);
9635 if (!P)
9636 // If this operand is already visited, reuse the prior result.
9637 // We may have P != PHI if this is the deepest point at which the
9638 // inconsistent paths meet.
9639 P = PHIMap.lookup(OpInst);
9640 if (!P) {
9641 // Recurse and memoize the results, whether a phi is found or not.
9642 // This recursive call invalidates pointers into PHIMap.
9643 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9644 PHIMap[OpInst] = P;
9645 }
9646 if (!P)
9647 return nullptr; // Not evolving from PHI
9648 if (PHI && PHI != P)
9649 return nullptr; // Evolving from multiple different PHIs.
9650 PHI = P;
9651 }
9652 // This is an expression evolving from a constant PHI!
9653 return PHI;
9654}
9655
9656/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9657/// in the loop that V is derived from. We allow arbitrary operations along the
9658/// way, but the operands of an operation must either be constants or a value
9659/// derived from a constant PHI. If this expression does not fit with these
9660/// constraints, return null.
9661static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9662 Instruction *I = dyn_cast<Instruction>(V);
9663 if (!I || !canConstantEvolve(I, L)) return nullptr;
9664
9665 if (PHINode *PN = dyn_cast<PHINode>(I))
9666 return PN;
9667
9668 // Record non-constant instructions contained by the loop.
9669 DenseMap<Instruction *, PHINode *> PHIMap;
9670 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9671}
9672
9673/// EvaluateExpression - Given an expression that passes the
9674/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9675/// in the loop has the value PHIVal. If we can't fold this expression for some
9676/// reason, return null.
9677static Constant *EvaluateExpression(Value *V, const Loop *L,
9678 DenseMap<Instruction *, Constant *> &Vals,
9679 const DataLayout &DL,
9680 const TargetLibraryInfo *TLI) {
9681 // Convenient constant check, but redundant for recursive calls.
9682 if (Constant *C = dyn_cast<Constant>(V)) return C;
9683 Instruction *I = dyn_cast<Instruction>(V);
9684 if (!I) return nullptr;
9685
9686 if (Constant *C = Vals.lookup(I)) return C;
9687
9688 // An instruction inside the loop depends on a value outside the loop that we
9689 // weren't given a mapping for, or a value such as a call inside the loop.
9690 if (!canConstantEvolve(I, L)) return nullptr;
9691
9692 // An unmapped PHI can be due to a branch or another loop inside this loop,
9693 // or due to this not being the initial iteration through a loop where we
9694 // couldn't compute the evolution of this particular PHI last time.
9695 if (isa<PHINode>(I)) return nullptr;
9696
9697 std::vector<Constant*> Operands(I->getNumOperands());
9698
9699 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9700 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9701 if (!Operand) {
9702 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9703 if (!Operands[i]) return nullptr;
9704 continue;
9705 }
9706 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9707 Vals[Operand] = C;
9708 if (!C) return nullptr;
9709 Operands[i] = C;
9710 }
9711
9712 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9713 /*AllowNonDeterministic=*/false);
9714}
9715
9716
9717// If every incoming value to PN except the one for BB is a specific Constant,
9718// return that, else return nullptr.
9719static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9720 Constant *IncomingVal = nullptr;
9721
9722 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9723 if (PN->getIncomingBlock(i) == BB)
9724 continue;
9725
9726 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9727 if (!CurrentVal)
9728 return nullptr;
9729
9730 if (IncomingVal != CurrentVal) {
9731 if (IncomingVal)
9732 return nullptr;
9733 IncomingVal = CurrentVal;
9734 }
9735 }
9736
9737 return IncomingVal;
9738}
9739
9740/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9741/// in the header of its containing loop, we know the loop executes a
9742/// constant number of times, and the PHI node is just a recurrence
9743/// involving constants, fold it.
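/// For illustration (hypothetical IR): given "%i = phi [3, %entry], [%i.next, %latch]"
/// with "%i.next = mul i32 %i, 2" and a known backedge-taken count of 4, the
/// symbolic execution below folds the exit value to 3 * 2^4 = 48.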
9744Constant *
9745ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9746 const APInt &BEs,
9747 const Loop *L) {
9748 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9749 if (!Inserted)
9750 return I->second;
9751
9752 if (BEs.ugt(MaxBruteForceIterations))
9753 return nullptr; // Not going to evaluate it.
9754
9755 Constant *&RetVal = I->second;
9756
9757 DenseMap<Instruction *, Constant *> CurrentIterVals;
9758 BasicBlock *Header = L->getHeader();
9759 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9760
9761 BasicBlock *Latch = L->getLoopLatch();
9762 if (!Latch)
9763 return nullptr;
9764
9765 for (PHINode &PHI : Header->phis()) {
9766 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9767 CurrentIterVals[&PHI] = StartCST;
9768 }
9769 if (!CurrentIterVals.count(PN))
9770 return RetVal = nullptr;
9771
9772 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9773
9774 // Execute the loop symbolically to determine the exit value.
9775 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9776 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9777
9778 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9779 unsigned IterationNum = 0;
9780 const DataLayout &DL = getDataLayout();
9781 for (; ; ++IterationNum) {
9782 if (IterationNum == NumIterations)
9783 return RetVal = CurrentIterVals[PN]; // Got exit value!
9784
9785 // Compute the value of the PHIs for the next iteration.
9786 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9787 DenseMap<Instruction *, Constant *> NextIterVals;
9788 Constant *NextPHI =
9789 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9790 if (!NextPHI)
9791 return nullptr; // Couldn't evaluate!
9792 NextIterVals[PN] = NextPHI;
9793
9794 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9795
9796 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9797 // cease to be able to evaluate one of them or if they stop evolving,
9798 // because that doesn't necessarily prevent us from computing PN.
9799 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9800 for (const auto &I : CurrentIterVals) {
9801 PHINode *PHI = dyn_cast<PHINode>(I.first);
9802 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9803 PHIsToCompute.emplace_back(PHI, I.second);
9804 }
9805 // We use two distinct loops because EvaluateExpression may invalidate any
9806 // iterators into CurrentIterVals.
9807 for (const auto &I : PHIsToCompute) {
9808 PHINode *PHI = I.first;
9809 Constant *&NextPHI = NextIterVals[PHI];
9810 if (!NextPHI) { // Not already computed.
9811 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9812 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9813 }
9814 if (NextPHI != I.second)
9815 StoppedEvolving = false;
9816 }
9817
9818 // If all entries in CurrentIterVals == NextIterVals then we can stop
9819 // iterating, the loop can't continue to change.
9820 if (StoppedEvolving)
9821 return RetVal = CurrentIterVals[PN];
9822
9823 CurrentIterVals.swap(NextIterVals);
9824 }
9825}
9826
9827const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9828 Value *Cond,
9829 bool ExitWhen) {
9830 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9831 if (!PN) return getCouldNotCompute();
9832
9833 // If the loop is canonicalized, the PHI will have exactly two entries.
9834 // That's the only form we support here.
9835 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9836
9837 DenseMap<Instruction *, Constant *> CurrentIterVals;
9838 BasicBlock *Header = L->getHeader();
9839 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9840
9841 BasicBlock *Latch = L->getLoopLatch();
9842 assert(Latch && "Should follow from NumIncomingValues == 2!");
9843
9844 for (PHINode &PHI : Header->phis()) {
9845 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9846 CurrentIterVals[&PHI] = StartCST;
9847 }
9848 if (!CurrentIterVals.count(PN))
9849 return getCouldNotCompute();
9850
9851 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9852 // the loop symbolically to determine when the condition gets a value of
9853 // "ExitWhen".
9854 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9855 const DataLayout &DL = getDataLayout();
9856 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9857 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9858 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9859
9860 // Couldn't symbolically evaluate.
9861 if (!CondVal) return getCouldNotCompute();
9862
9863 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9864 ++NumBruteForceTripCountsComputed;
9865 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9866 }
9867
9868 // Update all the PHI nodes for the next iteration.
9869 DenseMap<Instruction *, Constant *> NextIterVals;
9870
9871 // Create a list of which PHIs we need to compute. We want to do this before
9872 // calling EvaluateExpression on them because that may invalidate iterators
9873 // into CurrentIterVals.
9874 SmallVector<PHINode *, 8> PHIsToCompute;
9875 for (const auto &I : CurrentIterVals) {
9876 PHINode *PHI = dyn_cast<PHINode>(I.first);
9877 if (!PHI || PHI->getParent() != Header) continue;
9878 PHIsToCompute.push_back(PHI);
9879 }
9880 for (PHINode *PHI : PHIsToCompute) {
9881 Constant *&NextPHI = NextIterVals[PHI];
9882 if (NextPHI) continue; // Already computed!
9883
9884 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9885 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9886 }
9887 CurrentIterVals.swap(NextIterVals);
9888 }
9889
9890 // Too many iterations were needed to evaluate.
9891 return getCouldNotCompute();
9892}
9893
9894const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9895 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9896 ValuesAtScopes[V];
9897 // Check to see if we've folded this expression at this loop before.
9898 for (auto &LS : Values)
9899 if (LS.first == L)
9900 return LS.second ? LS.second : V;
9901
9902 Values.emplace_back(L, nullptr);
9903
9904 // Otherwise compute it.
9905 const SCEV *C = computeSCEVAtScope(V, L);
9906 for (auto &LS : reverse(ValuesAtScopes[V]))
9907 if (LS.first == L) {
9908 LS.second = C;
9909 if (!isa<SCEVConstant>(C))
9910 ValuesAtScopesUsers[C].push_back({L, V});
9911 break;
9912 }
9913 return C;
9914}
9915
9916/// This builds up a Constant using the ConstantExpr interface. That way, we
9917/// will return Constants for objects which aren't represented by a
9918/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9919/// Returns NULL if the SCEV isn't representable as a Constant.
9920static Constant *BuildConstantFromSCEV(const SCEV *V) {
9921 switch (V->getSCEVType()) {
9922 case scCouldNotCompute:
9923 case scAddRecExpr:
9924 case scVScale:
9925 return nullptr;
9926 case scConstant:
9927 return cast<SCEVConstant>(V)->getValue();
9928 case scUnknown:
9929 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9930 case scPtrToInt: {
9931 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9932 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9933 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9934
9935 return nullptr;
9936 }
9937 case scTruncate: {
9938 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9939 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9940 return ConstantExpr::getTrunc(CastOp, ST->getType());
9941 return nullptr;
9942 }
9943 case scAddExpr: {
9944 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9945 Constant *C = nullptr;
9946 for (const SCEV *Op : SA->operands()) {
9947 Constant *OpC = BuildConstantFromSCEV(Op);
9948 if (!OpC)
9949 return nullptr;
9950 if (!C) {
9951 C = OpC;
9952 continue;
9953 }
9954 assert(!C->getType()->isPointerTy() &&
9955 "Can only have one pointer, and it must be last");
9956 if (OpC->getType()->isPointerTy()) {
9957 // The offsets have been converted to bytes. We can add bytes using
9958 // an i8 GEP.
9959 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9960 OpC, C);
9961 } else {
9962 C = ConstantExpr::getAdd(C, OpC);
9963 }
9964 }
9965 return C;
9966 }
9967 case scMulExpr:
9968 case scSignExtend:
9969 case scZeroExtend:
9970 case scUDivExpr:
9971 case scSMaxExpr:
9972 case scUMaxExpr:
9973 case scSMinExpr:
9974 case scUMinExpr:
9975 case scSequentialUMinExpr:
9976 return nullptr;
9977 }
9978 llvm_unreachable("Unknown SCEV kind!");
9979}
9980
9981const SCEV *
9982ScalarEvolution::getWithOperands(const SCEV *S,
9983 SmallVectorImpl<const SCEV *> &NewOps) {
9984 switch (S->getSCEVType()) {
9985 case scTruncate:
9986 case scZeroExtend:
9987 case scSignExtend:
9988 case scPtrToInt:
9989 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9990 case scAddRecExpr: {
9991 auto *AddRec = cast<SCEVAddRecExpr>(S);
9992 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9993 }
9994 case scAddExpr:
9995 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9996 case scMulExpr:
9997 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9998 case scUDivExpr:
9999 return getUDivExpr(NewOps[0], NewOps[1]);
10000 case scUMaxExpr:
10001 case scSMaxExpr:
10002 case scUMinExpr:
10003 case scSMinExpr:
10004 return getMinMaxExpr(S->getSCEVType(), NewOps);
10005 case scSequentialUMinExpr:
10006 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10007 case scConstant:
10008 case scVScale:
10009 case scUnknown:
10010 return S;
10011 case scCouldNotCompute:
10012 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10013 }
10014 llvm_unreachable("Unknown SCEV kind!");
10015}
10016
10017const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10018 switch (V->getSCEVType()) {
10019 case scConstant:
10020 case scVScale:
10021 return V;
10022 case scAddRecExpr: {
10023 // If this is a loop recurrence for a loop that does not contain L, then we
10024 // are dealing with the final value computed by the loop.
10025 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10026 // First, attempt to evaluate each operand.
10027 // Avoid performing the look-up in the common case where the specified
10028 // expression has no loop-variant portions.
10029 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10030 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10031 if (OpAtScope == AddRec->getOperand(i))
10032 continue;
10033
10034 // Okay, at least one of these operands is loop variant but might be
10035 // foldable. Build a new instance of the folded commutative expression.
10036 SmallVector<const SCEV *, 8> NewOps;
10037 NewOps.reserve(AddRec->getNumOperands());
10038 append_range(NewOps, AddRec->operands().take_front(i));
10039 NewOps.push_back(OpAtScope);
10040 for (++i; i != e; ++i)
10041 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10042
10043 const SCEV *FoldedRec = getAddRecExpr(
10044 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10045 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10046 // The addrec may be folded to a nonrecurrence, for example, if the
10047 // induction variable is multiplied by zero after constant folding. Go
10048 // ahead and return the folded value.
10049 if (!AddRec)
10050 return FoldedRec;
10051 break;
10052 }
10053
10054 // If the scope is outside the addrec's loop, evaluate it by using the
10055 // loop exit value of the addrec.
10056 if (!AddRec->getLoop()->contains(L)) {
10057 // To evaluate this recurrence, we need to know how many times the AddRec
10058 // loop iterates. Compute this now.
10059 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10060 if (BackedgeTakenCount == getCouldNotCompute())
10061 return AddRec;
10062
10063 // Then, evaluate the AddRec.
10064 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10065 }
10066
10067 return AddRec;
10068 }
10069 case scTruncate:
10070 case scZeroExtend:
10071 case scSignExtend:
10072 case scPtrToInt:
10073 case scAddExpr:
10074 case scMulExpr:
10075 case scUDivExpr:
10076 case scUMaxExpr:
10077 case scSMaxExpr:
10078 case scUMinExpr:
10079 case scSMinExpr:
10080 case scSequentialUMinExpr: {
10081 ArrayRef<const SCEV *> Ops = V->operands();
10082 // Avoid performing the look-up in the common case where the specified
10083 // expression has no loop-variant portions.
10084 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10085 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10086 if (OpAtScope != Ops[i]) {
10087 // Okay, at least one of these operands is loop variant but might be
10088 // foldable. Build a new instance of the folded commutative expression.
10089 SmallVector<const SCEV *, 8> NewOps;
10090 NewOps.reserve(Ops.size());
10091 append_range(NewOps, Ops.take_front(i));
10092 NewOps.push_back(OpAtScope);
10093
10094 for (++i; i != e; ++i) {
10095 OpAtScope = getSCEVAtScope(Ops[i], L);
10096 NewOps.push_back(OpAtScope);
10097 }
10098
10099 return getWithOperands(V, NewOps);
10100 }
10101 }
10102 // If we got here, all operands are loop invariant.
10103 return V;
10104 }
10105 case scUnknown: {
10106 // If this instruction is evolved from a constant-evolving PHI, compute the
10107 // exit value from the loop without using SCEVs.
10108 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10109 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10110 if (!I)
10111 return V; // This is some other type of SCEVUnknown, just return it.
10112
10113 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10114 const Loop *CurrLoop = this->LI[I->getParent()];
10115 // Looking for loop exit value.
10116 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10117 PN->getParent() == CurrLoop->getHeader()) {
10118 // Okay, there is no closed form solution for the PHI node. Check
10119 // to see if the loop that contains it has a known backedge-taken
10120 // count. If so, we may be able to force computation of the exit
10121 // value.
10122 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10123 // This trivial case can show up in some degenerate cases where
10124 // the incoming IR has not yet been fully simplified.
10125 if (BackedgeTakenCount->isZero()) {
10126 Value *InitValue = nullptr;
10127 bool MultipleInitValues = false;
10128 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10129 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10130 if (!InitValue)
10131 InitValue = PN->getIncomingValue(i);
10132 else if (InitValue != PN->getIncomingValue(i)) {
10133 MultipleInitValues = true;
10134 break;
10135 }
10136 }
10137 }
10138 if (!MultipleInitValues && InitValue)
10139 return getSCEV(InitValue);
10140 }
10141 // Do we have a loop invariant value flowing around the backedge
10142 // for a loop which must execute the backedge?
10143 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10144 isKnownNonZero(BackedgeTakenCount) &&
10145 PN->getNumIncomingValues() == 2) {
10146
10147 unsigned InLoopPred =
10148 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10149 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10150 if (CurrLoop->isLoopInvariant(BackedgeVal))
10151 return getSCEV(BackedgeVal);
10152 }
10153 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10154 // Okay, we know how many times the containing loop executes. If
10155 // this is a constant evolving PHI node, get the final value at
10156 // the specified iteration number.
10157 Constant *RV =
10158 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10159 if (RV)
10160 return getSCEV(RV);
10161 }
10162 }
10163 }
10164
10165 // Okay, this is an expression that we cannot symbolically evaluate
10166 // into a SCEV. Check to see if it's possible to symbolically evaluate
10167 // the arguments into constants, and if so, try to constant propagate the
10168 // result. This is particularly useful for computing loop exit values.
10169 if (!CanConstantFold(I))
10170 return V; // This is some other type of SCEVUnknown, just return it.
10172 SmallVector<Constant *, 4> Operands;
10173 Operands.reserve(I->getNumOperands());
10174 bool MadeImprovement = false;
10175 for (Value *Op : I->operands()) {
10176 if (Constant *C = dyn_cast<Constant>(Op)) {
10177 Operands.push_back(C);
10178 continue;
10179 }
10180
10181 // If any of the operands is non-constant and if they are
10182 // non-integer and non-pointer, don't even try to analyze them
10183 // with scev techniques.
10184 if (!isSCEVable(Op->getType()))
10185 return V;
10186
10187 const SCEV *OrigV = getSCEV(Op);
10188 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10189 MadeImprovement |= OrigV != OpV;
10190
10191 Constant *C = BuildConstantFromSCEV(OpV);
10192 if (!C)
10193 return V;
10194 assert(C->getType() == Op->getType() && "Type mismatch");
10195 Operands.push_back(C);
10196 }
10197
10198 // Check to see if getSCEVAtScope actually made an improvement.
10199 if (!MadeImprovement)
10200 return V; // This is some other type of SCEVUnknown, just return it.
10201
10202 Constant *C = nullptr;
10203 const DataLayout &DL = getDataLayout();
10204 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10205 /*AllowNonDeterministic=*/false);
10206 if (!C)
10207 return V;
10208 return getSCEV(C);
10209 }
10210 case scCouldNotCompute:
10211 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10212 }
10213 llvm_unreachable("Unknown SCEV type!");
10214}
10215
10216const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10217 return getSCEVAtScope(getSCEV(V), L);
10218}
10219
10220const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10221 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10222 return stripInjectiveFunctions(ZExt->getOperand());
10223 if (auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10224 return stripInjectiveFunctions(SExt->getOperand());
10225 return S;
10226}
10227
10228/// Finds the minimum unsigned root of the following equation:
10229///
10230/// A * X = B (mod N)
10231///
10232/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10233/// A and B isn't important.
10234///
10235/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
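///
/// Worked example (illustrative): with BW = 8, A = 4 and B = 8, we get D = 4,
/// B is divisible by D, A/D = 1 is its own inverse modulo N/D = 64, and the
/// minimum root is (I * B mod 256) / D = 8 / 4 = 2; indeed 4 * 2 == 8 (mod 256).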
10236static const SCEV *
10237SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10238 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10239
10240 ScalarEvolution &SE) {
10241 uint32_t BW = A.getBitWidth();
10242 assert(BW == SE.getTypeSizeInBits(B->getType()));
10243 assert(A != 0 && "A must be non-zero.");
10244
10245 // 1. D = gcd(A, N)
10246 //
10247 // The gcd of A and N may have only one prime factor: 2. The number of
10248 // trailing zeros in A is its multiplicity
10249 uint32_t Mult2 = A.countr_zero();
10250 // D = 2^Mult2
10251
10252 // 2. Check if B is divisible by D.
10253 //
10254 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10255 // is not less than multiplicity of this prime factor for D.
10256 if (SE.getMinTrailingZeros(B) < Mult2) {
10257 // Check if we can prove there's no remainder using URem.
10258 const SCEV *URem =
10259 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10260 const SCEV *Zero = SE.getZero(B->getType());
10261 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10262 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10263 if (!Predicates)
10264 return SE.getCouldNotCompute();
10265
10266 // Avoid adding a predicate that is known to be false.
10267 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10268 return SE.getCouldNotCompute();
10269 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10270 }
10271 }
10272
10273 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10274 // modulo (N / D).
10275 //
10276 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10277 // (N / D) in general. The inverse itself always fits into BW bits, though,
10278 // so we immediately truncate it.
10279 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10280 APInt I = AD.multiplicativeInverse().zext(BW);
10281
10282 // 4. Compute the minimum unsigned root of the equation:
10283 // I * (B / D) mod (N / D)
10284 // To simplify the computation, we factor out the divide by D:
10285 // (I * B mod N) / D
10286 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10287 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10288}
10289
10290/// For a given quadratic addrec, generate coefficients of the corresponding
10291/// quadratic equation, multiplied by a common value to ensure that they are
10292/// integers.
10293/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10294/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10295/// were multiplied by, and BitWidth is the bit width of the original addrec
10296/// coefficients.
10297/// This function returns std::nullopt if the addrec coefficients are not
10298/// compile-time constants.
10299static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10300GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10301 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10302 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10303 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10304 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10305 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10306 << *AddRec << '\n');
10307
10308 // We currently can only solve this if the coefficients are constants.
10309 if (!LC || !MC || !NC) {
10310 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10311 return std::nullopt;
10312 }
10313
10314 APInt L = LC->getAPInt();
10315 APInt M = MC->getAPInt();
10316 APInt N = NC->getAPInt();
10317 assert(!N.isZero() && "This is not a quadratic addrec");
10318
10319 unsigned BitWidth = LC->getAPInt().getBitWidth();
10320 unsigned NewWidth = BitWidth + 1;
10321 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10322 << BitWidth << '\n');
10323 // The sign-extension (as opposed to a zero-extension) here matches the
10324 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10325 N = N.sext(NewWidth);
10326 M = M.sext(NewWidth);
10327 L = L.sext(NewWidth);
10328
10329 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10330 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10331 // L+M, L+2M+N, L+3M+3N, ...
10332 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10333 //
10334 // The equation Acc = 0 is then
10335 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10336 // In a quadratic form it becomes:
10337 // N n^2 + (2M-N) n + 2L = 0.
10338
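// Worked example (illustrative): for the addrec {-9,+,1,+,2} we have L = -9,
// M = 1, N = 2, so the equation is 2n^2 + 0n - 18 = 0, whose least
// non-negative root is n = 3; the accumulated value -9 + n + n(n-1) indeed
// reaches 0 there.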
10339 APInt A = N;
10340 APInt B = 2 * M - A;
10341 APInt C = 2 * L;
10342 APInt T = APInt(NewWidth, 2);
10343 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10344 << "x + " << C << ", coeff bw: " << NewWidth
10345 << ", multiplied by " << T << '\n');
10346 return std::make_tuple(A, B, C, T, BitWidth);
10347}
10348
10349/// Helper function to compare optional APInts:
10350/// (a) if X and Y both exist, return min(X, Y),
10351/// (b) if neither X nor Y exist, return std::nullopt,
10352/// (c) if exactly one of X and Y exists, return that value.
10353static std::optional<APInt> MinOptional(std::optional<APInt> X,
10354 std::optional<APInt> Y) {
10355 if (X && Y) {
10356 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10357 APInt XW = X->sext(W);
10358 APInt YW = Y->sext(W);
10359 return XW.slt(YW) ? *X : *Y;
10360 }
10361 if (!X && !Y)
10362 return std::nullopt;
10363 return X ? *X : *Y;
10364}
10365
10366/// Helper function to truncate an optional APInt to a given BitWidth.
10367/// When solving addrec-related equations, it is preferable to return a value
10368/// that has the same bit width as the original addrec's coefficients. If the
10369/// solution fits in the original bit width, truncate it (except for i1).
10370/// Returning a value of a different bit width may inhibit some optimizations.
10371///
10372/// In general, a solution to a quadratic equation generated from an addrec
10373/// may require BW+1 bits, where BW is the bit width of the addrec's
10374/// coefficients. The reason is that the coefficients of the quadratic
10375/// equation are BW+1 bits wide (to avoid truncation when converting from
10376/// the addrec to the equation).
10377static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10378 unsigned BitWidth) {
10379 if (!X)
10380 return std::nullopt;
10381 unsigned W = X->getBitWidth();
10382 if (BitWidth > 1 && BitWidth < W)
10383 return X->trunc(BitWidth);
10384 return X;
10385}
10386
10387/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10388/// iterations. The values L, M, N are assumed to be signed, and they
10389/// should all have the same bit widths.
10390/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10391/// where BW is the bit width of the addrec's coefficients.
10392/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10393/// returned as such, otherwise the bit width of the returned value may
10394/// be greater than BW.
10395///
10396/// This function returns std::nullopt if
10397/// (a) the addrec coefficients are not constant, or
10398/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10399/// like x^2 = 5, no integer solutions exist, in other cases an integer
10400/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10401static std::optional<APInt>
10402SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10403 APInt A, B, C, M;
10404 unsigned BitWidth;
10405 auto T = GetQuadraticEquation(AddRec);
10406 if (!T)
10407 return std::nullopt;
10408
10409 std::tie(A, B, C, M, BitWidth) = *T;
10410 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10411 std::optional<APInt> X =
10412 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10413 if (!X)
10414 return std::nullopt;
10415
10416 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10417 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10418 if (!V->isZero())
10419 return std::nullopt;
10420
10421 return TruncIfPossible(X, BitWidth);
10422}
10423
10424/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10425/// iterations. The values M, N are assumed to be signed, and they
10426/// should all have the same bit widths.
10427/// Find the least n such that c(n) does not belong to the given range,
10428/// while c(n-1) does.
10429///
10430/// This function returns std::nullopt if
10431/// (a) the addrec coefficients are not constant, or
10432/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10433/// bounds of the range.
10434static std::optional<APInt>
10435SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10436 const ConstantRange &Range, ScalarEvolution &SE) {
10437 assert(AddRec->getOperand(0)->isZero() &&
10438 "Starting value of addrec should be 0");
10439 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10440 << Range << ", addrec " << *AddRec << '\n');
10441 // This case is handled in getNumIterationsInRange. Here we can assume that
10442 // we start in the range.
10443 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10444 "Addrec's initial value should be in range");
10445
10446 APInt A, B, C, M;
10447 unsigned BitWidth;
10448 auto T = GetQuadraticEquation(AddRec);
10449 if (!T)
10450 return std::nullopt;
10451
10452 // Be careful about the return value: there can be two reasons for not
10453 // returning an actual number. First, if no solutions to the equations
10454 // were found, and second, if the solutions don't leave the given range.
10455 // The first case means that the actual solution is "unknown", the second
10456 // means that it's known, but not valid. If the solution is unknown, we
10457 // cannot make any conclusions.
10458 // Return a pair: the optional solution and a flag indicating if the
10459 // solution was found.
10460 auto SolveForBoundary =
10461 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10462 // Solve for signed overflow and unsigned overflow, pick the lower
10463 // solution.
10464 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10465 << Bound << " (before multiplying by " << M << ")\n");
10466 Bound *= M; // The quadratic equation multiplier.
10467
10468 std::optional<APInt> SO;
10469 if (BitWidth > 1) {
10470 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10471 "signed overflow\n");
10472 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10473 }
10474 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10475 "unsigned overflow\n");
10476 std::optional<APInt> UO =
10477 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10478
10479 auto LeavesRange = [&] (const APInt &X) {
10480 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10481 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10482 if (Range.contains(V0->getValue()))
10483 return false;
10484 // X should be at least 1, so X-1 is non-negative.
10485 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10486 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10487 if (Range.contains(V1->getValue()))
10488 return true;
10489 return false;
10490 };
10491
10492 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10493 // can be a solution, but the function failed to find it. We cannot treat it
10494 // as "no solution".
10495 if (!SO || !UO)
10496 return {std::nullopt, false};
10497
10498 // Check the smaller value first to see if it leaves the range.
10499 // At this point, both SO and UO must have values.
10500 std::optional<APInt> Min = MinOptional(SO, UO);
10501 if (LeavesRange(*Min))
10502 return { Min, true };
10503 std::optional<APInt> Max = Min == SO ? UO : SO;
10504 if (LeavesRange(*Max))
10505 return { Max, true };
10506
10507 // Solutions were found, but were eliminated, hence the "true".
10508 return {std::nullopt, true};
10509 };
10510
10511 std::tie(A, B, C, M, BitWidth) = *T;
10512 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10513 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10514 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10515 auto SL = SolveForBoundary(Lower);
10516 auto SU = SolveForBoundary(Upper);
10517 // If any of the solutions was unknown, no meaningful conclusions can
10518 // be made.
10519 if (!SL.second || !SU.second)
10520 return std::nullopt;
10521
10522 // Claim: The correct solution is not some value between Min and Max.
10523 //
10524 // Justification: Assuming that Min and Max are different values, one of
10525 // them is when the first signed overflow happens, the other is when the
10526 // first unsigned overflow happens. Crossing the range boundary is only
10527 // possible via an overflow (treating 0 as a special case of it, modeling
10528 // an overflow as crossing k*2^W for some k).
10529 //
10530 // The interesting case here is when Min was eliminated as an invalid
10531 // solution, but Max was not. The argument is that if there was another
10532 // overflow between Min and Max, it would also have been eliminated if
10533 // it was considered.
10534 //
10535 // For a given boundary, it is possible to have two overflows of the same
10536 // type (signed/unsigned) without having the other type in between: this
10537 // can happen when the vertex of the parabola is between the iterations
10538 // corresponding to the overflows. This is only possible when the two
10539 // overflows cross k*2^W for the same k. In such case, if the second one
10540 // left the range (and was the first one to do so), the first overflow
10541 // would have to enter the range, which would mean that either we had left
10542 // the range before or that we started outside of it. Both of these cases
10543 // are contradictions.
10544 //
10545 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10546 // solution is not some value between the Max for this boundary and the
10547 // Min of the other boundary.
10548 //
10549 // Justification: Assume that we had such Max_A and Min_B corresponding
10550 // to range boundaries A and B and such that Max_A < Min_B. If there was
10551 // a solution between Max_A and Min_B, it would have to be caused by an
10552 // overflow corresponding to either A or B. It cannot correspond to B,
10553 // since Min_B is the first occurrence of such an overflow. If it
10554 // corresponded to A, it would have to be either a signed or an unsigned
10555 // overflow that is larger than both eliminated overflows for A. But
10556 // between the eliminated overflows and this overflow, the values would
10557 // cover the entire value space, thus crossing the other boundary, which
10558 // is a contradiction.
10559
10560 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10561}
10562
10563ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10564 const Loop *L,
10565 bool ControlsOnlyExit,
10566 bool AllowPredicates) {
10567
10568 // This is only used for loops with a "x != y" exit test. The exit condition
10569 // is now expressed as a single expression, V = x-y. So the exit test is
10570 // effectively V != 0. We know, and take advantage of, the fact that this
10571 // expression is only used in a comparison-with-zero context.
10572
10573 SmallVector<const SCEVPredicate *, 4> Predicates;
10574 // If the value is a constant
10575 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10576 // If the value is already zero, the branch will execute zero times.
10577 if (C->getValue()->isZero()) return C;
10578 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10579 }
10580
10581 const SCEVAddRecExpr *AddRec =
10582 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10583
10584 if (!AddRec && AllowPredicates)
10585 // Try to make this an AddRec using runtime tests, in the first X
10586 // iterations of this loop, where X is the SCEV expression found by the
10587 // algorithm below.
10588 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10589
10590 if (!AddRec || AddRec->getLoop() != L)
10591 return getCouldNotCompute();
10592
10593 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10594 // the quadratic equation to solve it.
10595 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10596 // We can only use this value if the chrec ends up with an exact zero
10597 // value at this index. When solving for "X*X != 5", for example, we
10598 // should not accept a root of 2.
10599 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10600 const auto *R = cast<SCEVConstant>(getConstant(*S));
10601 return ExitLimit(R, R, R, false, Predicates);
10602 }
10603 return getCouldNotCompute();
10604 }
10605
10606 // Otherwise we can only handle this if it is affine.
10607 if (!AddRec->isAffine())
10608 return getCouldNotCompute();
10609
10610 // If this is an affine expression, the execution count of this branch is
10611 // the minimum unsigned root of the following equation:
10612 //
10613 // Start + Step*N = 0 (mod 2^BW)
10614 //
10615 // equivalent to:
10616 //
10617 // Step*N = -Start (mod 2^BW)
10618 //
10619 // where BW is the common bit width of Start and Step.
10620
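// For example (illustrative): an i8 IV {6,+,5} reaches 0 when 5*N = -6
// (mod 256); the minimum solution is N = 50, since 5 * 50 = 250 = -6 (mod 256).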
10621 // Get the initial value for the loop.
10622 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10623 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10624
10625 if (!isLoopInvariant(Step, L))
10626 return getCouldNotCompute();
10627
10628 LoopGuards Guards = LoopGuards::collect(L, *this);
10629 // Specialize step for this loop so we get context sensitive facts below.
10630 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10631
10632 // For positive steps (counting up until unsigned overflow):
10633 // N = -Start/Step (as unsigned)
10634 // For negative steps (counting down to zero):
10635 // N = Start/-Step
10636 // First compute the unsigned distance from zero in the direction of Step.
10637 bool CountDown = isKnownNegative(StepWLG);
10638 if (!CountDown && !isKnownNonNegative(StepWLG))
10639 return getCouldNotCompute();
10640
10641 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10642 // Handle unitary steps, which cannot wraparound.
10643 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10644 // N = Distance (as unsigned)
10645
10646 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10647 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10648 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10649
10650 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10651 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10652 // case, and see if we can improve the bound.
10653 //
10654 // Explicitly handling this here is necessary because getUnsignedRange
10655 // isn't context-sensitive; it doesn't know that we only care about the
10656 // range inside the loop.
10657 const SCEV *Zero = getZero(Distance->getType());
10658 const SCEV *One = getOne(Distance->getType());
10659 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10660 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10661 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10662 // as "unsigned_max(Distance + 1) - 1".
10663 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10664 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10665 }
10666 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10667 Predicates);
10668 }
10669
10670 // If the condition controls loop exit (the loop exits only if the expression
10671 // is true) and the addition is no-wrap we can use unsigned divide to
10672 // compute the backedge count. In this case, the step may not divide the
10673 // distance, but we don't care because if the condition is "missed" the loop
10674 // will have undefined behavior due to wrapping.
10675 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10676 loopHasNoAbnormalExits(AddRec->getLoop())) {
10677
10678 // If the stride is zero and the start is non-zero, the loop must be
10679 // infinite. In C++, most loops are finite by assumption, in which case the
10680 // step being zero implies UB must execute if the loop is entered.
10681 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10682 !isKnownNonZero(StepWLG))
10683 return getCouldNotCompute();
10684
10685 const SCEV *Exact =
10686 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10687 const SCEV *ConstantMax = getCouldNotCompute();
10688 if (Exact != getCouldNotCompute()) {
10689 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10690 ConstantMax =
10691 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10692 }
10693 const SCEV *SymbolicMax =
10694 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10695 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10696 }
10697
10698 // Solve the general equation.
10699 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10700 if (!StepC || StepC->getValue()->isZero())
10701 return getCouldNotCompute();
10702 const SCEV *E = SolveLinEquationWithOverflow(
10703 StepC->getAPInt(), getNegativeSCEV(Start),
10704 AllowPredicates ? &Predicates : nullptr, *this);
10705
10706 const SCEV *M = E;
10707 if (E != getCouldNotCompute()) {
10708 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10709 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10710 }
10711 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10712 return ExitLimit(E, M, S, false, Predicates);
10713}
10714
10715ScalarEvolution::ExitLimit
10716ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10717 // Loops that look like: while (X == 0) are very strange indeed. We don't
10718 // handle them yet except for the trivial case. This could be expanded in the
10719 // future as needed.
10720
10721 // If the value is a constant, check to see if it is known to be non-zero
10722 // already. If so, the backedge will execute zero times.
10723 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10724 if (!C->getValue()->isZero())
10725 return getZero(C->getType());
10726 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10727 }
10728
10729 // We could implement others, but I really doubt anyone writes loops like
10730 // this, and if they did, they would already be constant folded.
10731 return getCouldNotCompute();
10732}
10733
10734std::pair<const BasicBlock *, const BasicBlock *>
10735ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10736 const {
10737 // If the block has a unique predecessor, then there is no path from the
10738 // predecessor to the block that does not go through the direct edge
10739 // from the predecessor to the block.
10740 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10741 return {Pred, BB};
10742
10743 // A loop's header is defined to be a block that dominates the loop.
10744 // If the header has a unique predecessor outside the loop, it must be
10745 // a block that has exactly one successor that can reach the loop.
10746 if (const Loop *L = LI.getLoopFor(BB))
10747 return {L->getLoopPredecessor(), L->getHeader()};
10748
10749 return {nullptr, BB};
10750}
10751
10752/// SCEV structural equivalence is usually sufficient for testing whether two
10753/// expressions are equal, however for the purposes of looking for a condition
10754/// guarding a loop, it can be useful to be a little more general, since a
10755/// front-end may have replicated the controlling expression.
10756static bool HasSameValue(const SCEV *A, const SCEV *B) {
10757 // Quick check to see if they are the same SCEV.
10758 if (A == B) return true;
10759
10760 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10761 // Not all instructions that are "identical" compute the same value. For
10762 // instance, two distinct alloca instructions allocating the same type are
10763 // identical and do not read memory, but they compute distinct values.
10764 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10765 };
10766
10767 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10768 // two different instructions with the same value. Check for this case.
10769 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10770 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10771 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10772 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10773 if (ComputesEqualValues(AI, BI))
10774 return true;
10775
10776 // Otherwise assume they may have a different value.
10777 return false;
10778}
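// For illustration: two separate "xor i32 %x, %y" instructions are typically
// modeled as distinct SCEVUnknowns, yet they are isIdenticalTo() binary
// operators, so HasSameValue treats the two expressions as equal.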
10779
10780static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10781 const auto *Add = dyn_cast<SCEVAddExpr>(S);
10782 if (!Add || Add->getNumOperands() != 2)
10783 return false;
10784 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10785 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10786 LHS = Add->getOperand(1);
10787 RHS = ME->getOperand(1);
10788 return true;
10789 }
10790 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10791 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10792 LHS = Add->getOperand(0);
10793 RHS = ME->getOperand(1);
10794 return true;
10795 }
10796 return false;
10797}
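// For illustration: the SCEV for "%b - %a" is built as the add expression
// "(-1 * %a) + %b", from which MatchBinarySub recovers LHS = %b and RHS = %a.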
10798
10799bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10800 const SCEV *&RHS, unsigned Depth) {
10801 bool Changed = false;
10802 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10803 // '0 != 0'.
10804 auto TrivialCase = [&](bool TriviallyTrue) {
10805 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10806 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10807 return true;
10808 };
10809 // If we hit the max recursion limit bail out.
10810 if (Depth >= 3)
10811 return false;
10812
10813 // Canonicalize a constant to the right side.
10814 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10815 // Check for both operands constant.
10816 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10817 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10818 return TrivialCase(false);
10819 return TrivialCase(true);
10820 }
10821 // Otherwise swap the operands to put the constant on the right.
10822 std::swap(LHS, RHS);
10823 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10824 Changed = true;
10825 }
10826
10827 // If we're comparing an addrec with a value which is loop-invariant in the
10828 // addrec's loop, put the addrec on the left. Also make a dominance check,
10829 // as both operands could be addrecs loop-invariant in each other's loop.
10830 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10831 const Loop *L = AR->getLoop();
10832 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10833 std::swap(LHS, RHS);
10834 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10835 Changed = true;
10836 }
10837 }
10838
10839 // If there's a constant operand, canonicalize comparisons with boundary
10840 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10841 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10842 const APInt &RA = RC->getAPInt();
10843
10844 bool SimplifiedByConstantRange = false;
10845
10846 if (!ICmpInst::isEquality(Pred)) {
10847 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10848 if (ExactCR.isFullSet())
10849 return TrivialCase(true);
10850 if (ExactCR.isEmptySet())
10851 return TrivialCase(false);
10852
10853 APInt NewRHS;
10854 CmpInst::Predicate NewPred;
10855 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10856 ICmpInst::isEquality(NewPred)) {
10857 // We were able to convert an inequality to an equality.
10858 Pred = NewPred;
10859 RHS = getConstant(NewRHS);
10860 Changed = SimplifiedByConstantRange = true;
10861 }
10862 }
10863
10864 if (!SimplifiedByConstantRange) {
10865 switch (Pred) {
10866 default:
10867 break;
10868 case ICmpInst::ICMP_EQ:
10869 case ICmpInst::ICMP_NE:
10870 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10871 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10872 Changed = true;
10873 break;
10874
10875 // The "Should have been caught earlier!" messages refer to the fact
10876 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10877 // should have fired on the corresponding cases, and canonicalized the
10878 // check to a trivial case.
10879
10880 case ICmpInst::ICMP_UGE:
10881 assert(!RA.isMinValue() && "Should have been caught earlier!");
10882 Pred = ICmpInst::ICMP_UGT;
10883 RHS = getConstant(RA - 1);
10884 Changed = true;
10885 break;
10886 case ICmpInst::ICMP_ULE:
10887 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10888 Pred = ICmpInst::ICMP_ULT;
10889 RHS = getConstant(RA + 1);
10890 Changed = true;
10891 break;
10892 case ICmpInst::ICMP_SGE:
10893 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10894 Pred = ICmpInst::ICMP_SGT;
10895 RHS = getConstant(RA - 1);
10896 Changed = true;
10897 break;
10898 case ICmpInst::ICMP_SLE:
10899 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10900 Pred = ICmpInst::ICMP_SLT;
10901 RHS = getConstant(RA + 1);
10902 Changed = true;
10903 break;
10904 }
10905 }
10906 }
10907
10908 // Check for obvious equality.
10909 if (HasSameValue(LHS, RHS)) {
10910 if (ICmpInst::isTrueWhenEqual(Pred))
10911 return TrivialCase(true);
10912 if (ICmpInst::isFalseWhenEqual(Pred))
10913 return TrivialCase(false);
10914 }
10915
10916 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10917 // adding or subtracting 1 from one of the operands.
10918 switch (Pred) {
10919 case ICmpInst::ICMP_SLE:
10920 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10921 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10922 SCEV::FlagNSW);
10923 Pred = ICmpInst::ICMP_SLT;
10924 Changed = true;
10925 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10926 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10927 SCEV::FlagNSW);
10928 Pred = ICmpInst::ICMP_SLT;
10929 Changed = true;
10930 }
10931 break;
10932 case ICmpInst::ICMP_SGE:
10933 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10934 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10935 SCEV::FlagNSW);
10936 Pred = ICmpInst::ICMP_SGT;
10937 Changed = true;
10938 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10939 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10940 SCEV::FlagNSW);
10941 Pred = ICmpInst::ICMP_SGT;
10942 Changed = true;
10943 }
10944 break;
10945 case ICmpInst::ICMP_ULE:
10946 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10947 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10948 SCEV::FlagNUW);
10949 Pred = ICmpInst::ICMP_ULT;
10950 Changed = true;
10951 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10952 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10953 Pred = ICmpInst::ICMP_ULT;
10954 Changed = true;
10955 }
10956 break;
10957 case ICmpInst::ICMP_UGE:
10958 // If RHS is an op we can fold the -1, try that first.
10959 // Otherwise prefer LHS to preserve the nuw flag.
10960 if ((isa<SCEVConstant>(RHS) ||
10961 (isa<SCEVAddExpr>(RHS) &&
10962 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10963 !getUnsignedRangeMin(RHS).isMinValue()) {
10964 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10965 Pred = ICmpInst::ICMP_UGT;
10966 Changed = true;
10967 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10968 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10969 SCEV::FlagNUW);
10970 Pred = ICmpInst::ICMP_UGT;
10971 Changed = true;
10972 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
10973 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10974 Pred = ICmpInst::ICMP_UGT;
10975 Changed = true;
10976 }
10977 break;
10978 default:
10979 break;
10980 }
10981
10982 // TODO: More simplifications are possible here.
10983
10984 // Recursively simplify until we either hit a recursion limit or nothing
10985 // changes.
10986 if (Changed)
10987 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10988
10989 return Changed;
10990}
10991
10992bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10993  return getSignedRangeMax(S).isNegative();
10994}
10995
10996bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10997  return getSignedRangeMin(S).isStrictlyPositive();
10998}
10999
11000bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11001  return !getSignedRangeMin(S).isNegative();
11002}
11003
11004bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11005  return !getSignedRangeMax(S).isStrictlyPositive();
11006}
11007
11008bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11009 // Query push down for cases where the unsigned range is
11010 // less than sufficient.
11011 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11012 return isKnownNonZero(SExt->getOperand(0));
11013 return getUnsignedRangeMin(S) != 0;
11014}
11015
11016bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11017                                             bool OrNegative) {
11018 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11019 if (auto *C = dyn_cast<SCEVConstant>(S))
11020 return C->getAPInt().isPowerOf2() ||
11021 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11022
11023 // The vscale_range indicates vscale is a power-of-two.
11024 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11025 };
11026
11027 if (NonRecursive(S))
11028 return true;
11029
11030 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11031 if (!Mul)
11032 return false;
11033 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11034}
11035
11036bool ScalarEvolution::isKnownMultipleOf(
11037                                        const SCEV *S, uint64_t M,
11038                                        SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11039 if (M == 0)
11040 return false;
11041 if (M == 1)
11042 return true;
11043
11044 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11045 // starts with a multiple of M and at every iteration step S only adds
11046 // multiples of M.
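  // For example, {8,+,4} is a multiple of 4: its start (8) is a multiple of 4
  // and every step adds another multiple of 4.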
11047 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11048 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11049 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11050
11051 // For a constant, check that "S % M == 0".
11052 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11053 APInt C = Cst->getAPInt();
11054 return C.urem(M) == 0;
11055 }
11056
11057 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11058
11059 // Basic tests have failed.
11060 // Check "S % M == 0" at compile time and record runtime Assumptions.
11061 auto *STy = dyn_cast<IntegerType>(S->getType());
11062 const SCEV *SmodM =
11063 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11064 const SCEV *Zero = getZero(STy);
11065
11066 // Check whether "S % M == 0" is known at compile time.
11067 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11068 return true;
11069
11070 // Check whether "S % M != 0" is known at compile time.
11071 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11072 return false;
11073
11074  const SCEVPredicate *P = getEqualPredicate(SmodM, Zero);
11075
11076 // Detect redundant predicates.
11077 for (auto *A : Assumptions)
11078 if (A->implies(P, *this))
11079 return true;
11080
11081 // Only record non-redundant predicates.
11082 Assumptions.push_back(P);
11083 return true;
11084}
11085
11086std::pair<const SCEV *, const SCEV *>
11087ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11088 // Compute SCEV on entry of loop L.
11089 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11090 if (Start == getCouldNotCompute())
11091 return { Start, Start };
11092 // Compute post increment SCEV for loop L.
11093 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11094 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11095 return { Start, PostInc };
11096}
11097
11098bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11099                                          const SCEV *RHS) {
11100 // First collect all loops.
11101  SmallPtrSet<const Loop *, 8> LoopsUsed;
11102   getUsedLoops(LHS, LoopsUsed);
11103 getUsedLoops(RHS, LoopsUsed);
11104
11105 if (LoopsUsed.empty())
11106 return false;
11107
11108 // Domination relationship must be a linear order on collected loops.
11109#ifndef NDEBUG
11110 for (const auto *L1 : LoopsUsed)
11111 for (const auto *L2 : LoopsUsed)
11112 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11113 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11114 "Domination relationship is not a linear order");
11115#endif
11116
11117 const Loop *MDL =
11118 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11119 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11120 });
11121
11122 // Get init and post increment value for LHS.
11123 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11124   // If LHS contains an unknown non-invariant SCEV, bail out.
11125 if (SplitLHS.first == getCouldNotCompute())
11126 return false;
11127 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11128 // Get init and post increment value for RHS.
11129 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11130   // If RHS contains an unknown non-invariant SCEV, bail out.
11131 if (SplitRHS.first == getCouldNotCompute())
11132 return false;
11133 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11134 // It is possible that init SCEV contains an invariant load but it does
11135 // not dominate MDL and is not available at MDL loop entry, so we should
11136 // check it here.
11137 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11138 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11139 return false;
11140
11141   // The backedge guard check tends to be faster than the entry guard check,
11142   // so do it first; in some cases it short-circuits the whole estimation.
11143 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11144 SplitRHS.second) &&
11145 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11146}
11147
11148bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11149                                        const SCEV *RHS) {
11150 // Canonicalize the inputs first.
11151 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11152
11153 if (isKnownViaInduction(Pred, LHS, RHS))
11154 return true;
11155
11156 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11157 return true;
11158
11159 // Otherwise see what can be done with some simple reasoning.
11160 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11161}
11162
11163std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11164                                                      const SCEV *LHS,
11165 const SCEV *RHS) {
11166 if (isKnownPredicate(Pred, LHS, RHS))
11167     return true;
11168  if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11169 return false;
11170 return std::nullopt;
11171}
11172
11173bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11174                                          const SCEV *RHS,
11175 const Instruction *CtxI) {
11176 // TODO: Analyze guards and assumes from Context's block.
11177 return isKnownPredicate(Pred, LHS, RHS) ||
11178 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11179}
11180
11181std::optional<bool>
11182ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11183 const SCEV *RHS, const Instruction *CtxI) {
11184 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11185 if (KnownWithoutContext)
11186 return KnownWithoutContext;
11187
11188 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11189     return true;
11190  if (isBasicBlockEntryGuardedByCond(
11191 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11192 return false;
11193 return std::nullopt;
11194}
11195
11196bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11197                                              const SCEVAddRecExpr *LHS,
11198 const SCEV *RHS) {
11199 const Loop *L = LHS->getLoop();
11200 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11201 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11202}
11203
11204std::optional<ScalarEvolution::MonotonicPredicateType>
11205ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11206 ICmpInst::Predicate Pred) {
11207 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11208
11209#ifndef NDEBUG
11210 // Verify an invariant: inverting the predicate should turn a monotonically
11211 // increasing change to a monotonically decreasing one, and vice versa.
11212 if (Result) {
11213 auto ResultSwapped =
11214 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11215
11216 assert(*ResultSwapped != *Result &&
11217 "monotonicity should flip as we flip the predicate");
11218 }
11219#endif
11220
11221 return Result;
11222}
11223
11224std::optional<ScalarEvolution::MonotonicPredicateType>
11225ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11226 ICmpInst::Predicate Pred) {
11227 // A zero step value for LHS means the induction variable is essentially a
11228 // loop invariant value. We don't really depend on the predicate actually
11229 // flipping from false to true (for increasing predicates, and the other way
11230 // around for decreasing predicates), all we care about is that *if* the
11231 // predicate changes then it only changes from false to true.
11232 //
11233 // A zero step value in itself is not very useful, but there may be places
11234 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11235 // as general as possible.
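  // For example, for the addrec {0,+,1}<nsw> the predicate "{0,+,1} s> RHS"
  // can only switch from false to true as the induction variable grows, so it
  // is monotonically increasing; with "s<" it would be monotonically
  // decreasing.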
11236
11237 // Only handle LE/LT/GE/GT predicates.
11238 if (!ICmpInst::isRelational(Pred))
11239 return std::nullopt;
11240
11241 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11242 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11243 "Should be greater or less!");
11244
11245 // Check that AR does not wrap.
11246 if (ICmpInst::isUnsigned(Pred)) {
11247 if (!LHS->hasNoUnsignedWrap())
11248       return std::nullopt;
11249    return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11250 }
11251 assert(ICmpInst::isSigned(Pred) &&
11252 "Relational predicate is either signed or unsigned!");
11253 if (!LHS->hasNoSignedWrap())
11254 return std::nullopt;
11255
11256 const SCEV *Step = LHS->getStepRecurrence(*this);
11257
11258   if (isKnownNonNegative(Step))
11259    return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11260
11261   if (isKnownNonPositive(Step))
11262    return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11263
11264 return std::nullopt;
11265}
11266
11267std::optional<ScalarEvolution::LoopInvariantPredicate>
11268ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11269 const SCEV *RHS, const Loop *L,
11270 const Instruction *CtxI) {
11271 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11272 if (!isLoopInvariant(RHS, L)) {
11273 if (!isLoopInvariant(LHS, L))
11274 return std::nullopt;
11275
11276     std::swap(LHS, RHS);
11277    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11278 }
11279
11280 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11281 if (!ArLHS || ArLHS->getLoop() != L)
11282 return std::nullopt;
11283
11284 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11285 if (!MonotonicType)
11286 return std::nullopt;
11287 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11288 // true as the loop iterates, and the backedge is control dependent on
11289 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11290 //
11291 // * if the predicate was false in the first iteration then the predicate
11292 // is never evaluated again, since the loop exits without taking the
11293 // backedge.
11294 // * if the predicate was true in the first iteration then it will
11295 // continue to be true for all future iterations since it is
11296 // monotonically increasing.
11297 //
11298 // For both the above possibilities, we can replace the loop varying
11299 // predicate with its value on the first iteration of the loop (which is
11300 // loop invariant).
11301 //
11302 // A similar reasoning applies for a monotonically decreasing predicate, by
11303 // replacing true with false and false with true in the above two bullets.
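  // For example, if the backedge is guarded by "{0,+,1}<nuw> u< RHS" with RHS
  // loop-invariant, the loop-varying check can be replaced by the loop
  // invariant check "0 u< RHS", its value on the first iteration.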
11304  bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11305   auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11306
11307   if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11308    return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11309                                                    RHS);
11310
11310
11311 if (!CtxI)
11312 return std::nullopt;
11313 // Try to prove via context.
11314 // TODO: Support other cases.
11315 switch (Pred) {
11316 default:
11317 break;
11318 case ICmpInst::ICMP_ULE:
11319 case ICmpInst::ICMP_ULT: {
11320 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11321 // Given preconditions
11322 // (1) ArLHS does not cross the border of positive and negative parts of
11323 // range because of:
11324 // - Positive step; (TODO: lift this limitation)
11325 // - nuw - does not cross zero boundary;
11326 // - nsw - does not cross SINT_MAX boundary;
11327 // (2) ArLHS <s RHS
11328 // (3) RHS >=s 0
11329 // we can replace the loop variant ArLHS <u RHS condition with loop
11330 // invariant Start(ArLHS) <u RHS.
11331 //
11332 // Because of (1) there are two options:
11333 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11334 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11335 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11336 // Because of (2) ArLHS <u RHS is trivially true.
11337 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11338 // We can strengthen this to Start(ArLHS) <u RHS.
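    // As a concrete instance: for ArLHS = {0,+,1}<nuw><nsw> with a positive
    // step, RHS s>= 0 and "ArLHS s< RHS" known at the context instruction,
    // the loop-invariant condition "0 u< RHS" can be used instead.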
11339 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11340 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11341 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11342 isKnownNonNegative(RHS) &&
11343         isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11344      return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11345 RHS);
11346 }
11347 }
11348
11349 return std::nullopt;
11350}
11351
11352std::optional<ScalarEvolution::LoopInvariantPredicate>
11353ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11354 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11355     const Instruction *CtxI, const SCEV *MaxIter) {
11356  if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11357 Pred, LHS, RHS, L, CtxI, MaxIter))
11358 return LIP;
11359 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11360 // Number of iterations expressed as UMIN isn't always great for expressing
11361 // the value on the last iteration. If the straightforward approach didn't
11362     // work, try the following trick: if a predicate is invariant for X, it
11363 // is also invariant for umin(X, ...). So try to find something that works
11364 // among subexpressions of MaxIter expressed as umin.
11365     for (auto *Op : UMin->operands())
11366      if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11367 Pred, LHS, RHS, L, CtxI, Op))
11368 return LIP;
11369 return std::nullopt;
11370}
11371
11372std::optional<ScalarEvolution::LoopInvariantPredicate>
11373ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11374 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11375 const Instruction *CtxI, const SCEV *MaxIter) {
11376 // Try to prove the following set of facts:
11377 // - The predicate is monotonic in the iteration space.
11378 // - If the check does not fail on the 1st iteration:
11379 // - No overflow will happen during first MaxIter iterations;
11380 // - It will not fail on the MaxIter'th iteration.
11381 // If the check does fail on the 1st iteration, we leave the loop and no
11382 // other checks matter.
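  // For example, with AR = {0,+,1}, MaxIter = N and the condition "AR u< RHS":
  // if "N u< RHS" is guaranteed whenever the backedge is taken and "0 u<= N"
  // holds, the check cannot start failing part-way through the first N
  // iterations, so "0 u< RHS" can stand in for it there.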
11383
11384 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11385 if (!isLoopInvariant(RHS, L)) {
11386 if (!isLoopInvariant(LHS, L))
11387 return std::nullopt;
11388
11389     std::swap(LHS, RHS);
11390    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11391 }
11392
11393 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11394 if (!AR || AR->getLoop() != L)
11395 return std::nullopt;
11396
11397 // The predicate must be relational (i.e. <, <=, >=, >).
11398 if (!ICmpInst::isRelational(Pred))
11399 return std::nullopt;
11400
11401 // TODO: Support steps other than +/- 1.
11402 const SCEV *Step = AR->getStepRecurrence(*this);
11403 auto *One = getOne(Step->getType());
11404 auto *MinusOne = getNegativeSCEV(One);
11405 if (Step != One && Step != MinusOne)
11406 return std::nullopt;
11407
11408 // Type mismatch here means that MaxIter is potentially larger than max
11409   // unsigned value in start type, which means we cannot prove no wrap for the
11410 // indvar.
11411 if (AR->getType() != MaxIter->getType())
11412 return std::nullopt;
11413
11414 // Value of IV on suggested last iteration.
11415 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11416 // Does it still meet the requirement?
11417 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11418 return std::nullopt;
11419 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11420 // not exceed max unsigned value of this type), this effectively proves
11421 // that there is no wrap during the iteration. To prove that there is no
11422 // signed/unsigned wrap, we need to check that
11423 // Start <= Last for step = 1 or Start >= Last for step = -1.
11424   ICmpInst::Predicate NoOverflowPred =
11425      CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11426 if (Step == MinusOne)
11427 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11428 const SCEV *Start = AR->getStart();
11429 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11430 return std::nullopt;
11431
11432 // Everything is fine.
11433 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11434}
11435
11436bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11437 const SCEV *LHS,
11438 const SCEV *RHS) {
11439 if (HasSameValue(LHS, RHS))
11440 return ICmpInst::isTrueWhenEqual(Pred);
11441
11442 auto CheckRange = [&](bool IsSigned) {
11443 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11444 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11445 return RangeLHS.icmp(Pred, RangeRHS);
11446 };
11447
11448 // The check at the top of the function catches the case where the values are
11449 // known to be equal.
11450 if (Pred == CmpInst::ICMP_EQ)
11451 return false;
11452
11453 if (Pred == CmpInst::ICMP_NE) {
11454 if (CheckRange(true) || CheckRange(false))
11455 return true;
11456 auto *Diff = getMinusSCEV(LHS, RHS);
11457 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11458 }
11459
11460 return CheckRange(CmpInst::isSigned(Pred));
11461}
11462
11463bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11464 const SCEV *LHS,
11465 const SCEV *RHS) {
11466 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11467 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11468 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11469 // OutC1 and OutC2.
11470 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11471 APInt &OutC1, APInt &OutC2,
11472 SCEV::NoWrapFlags ExpectedFlags) {
11473 const SCEV *XNonConstOp, *XConstOp;
11474 const SCEV *YNonConstOp, *YConstOp;
11475 SCEV::NoWrapFlags XFlagsPresent;
11476 SCEV::NoWrapFlags YFlagsPresent;
11477
11478 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11479 XConstOp = getZero(X->getType());
11480 XNonConstOp = X;
11481 XFlagsPresent = ExpectedFlags;
11482 }
11483 if (!isa<SCEVConstant>(XConstOp))
11484 return false;
11485
11486 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11487 YConstOp = getZero(Y->getType());
11488 YNonConstOp = Y;
11489 YFlagsPresent = ExpectedFlags;
11490 }
11491
11492 if (YNonConstOp != XNonConstOp)
11493 return false;
11494
11495 if (!isa<SCEVConstant>(YConstOp))
11496 return false;
11497
11498 // When matching ADDs with NUW flags (and unsigned predicates), only the
11499 // second ADD (with the larger constant) requires NUW.
11500 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11501 return false;
11502 if (ExpectedFlags != SCEV::FlagNUW &&
11503 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11504 return false;
11505 }
11506
11507 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11508 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11509
11510 return true;
11511 };
11512
11513 APInt C1;
11514 APInt C2;
11515
11516 switch (Pred) {
11517 default:
11518 break;
11519
11520 case ICmpInst::ICMP_SGE:
11521 std::swap(LHS, RHS);
11522 [[fallthrough]];
11523 case ICmpInst::ICMP_SLE:
11524 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11525 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11526 return true;
11527
11528 break;
11529
11530 case ICmpInst::ICMP_SGT:
11531 std::swap(LHS, RHS);
11532 [[fallthrough]];
11533 case ICmpInst::ICMP_SLT:
11534 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11535 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11536 return true;
11537
11538 break;
11539
11540 case ICmpInst::ICMP_UGE:
11541 std::swap(LHS, RHS);
11542 [[fallthrough]];
11543 case ICmpInst::ICMP_ULE:
11544 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11545 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11546 return true;
11547
11548 break;
11549
11550 case ICmpInst::ICMP_UGT:
11551 std::swap(LHS, RHS);
11552 [[fallthrough]];
11553 case ICmpInst::ICMP_ULT:
11554 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11555 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11556 return true;
11557 break;
11558 }
11559
11560 return false;
11561}
11562
11563bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11564 const SCEV *LHS,
11565 const SCEV *RHS) {
11566 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11567 return false;
11568
11569 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11570 // the stack can result in exponential time complexity.
11571 SaveAndRestore Restore(ProvingSplitPredicate, true);
11572
11573 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11574 //
11575 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11576 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11577 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11578 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11579 // use isKnownPredicate later if needed.
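  // For example, to prove "I u< L" when L is known non-negative it is enough
  // to prove "I s>= 0" and "I s< L": once both sides are non-negative, the
  // signed and unsigned orderings agree.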
11580   return isKnownNonNegative(RHS) &&
11581         isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11582         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11583}
11584
11585bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11586 const SCEV *LHS, const SCEV *RHS) {
11587 // No need to even try if we know the module has no guards.
11588 if (!HasGuards)
11589 return false;
11590
11591 return any_of(*BB, [&](const Instruction &I) {
11592 using namespace llvm::PatternMatch;
11593
11594     Value *Condition;
11595    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11596 m_Value(Condition))) &&
11597 isImpliedCond(Pred, LHS, RHS, Condition, false);
11598 });
11599}
11600
11601/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11602/// protected by a conditional between LHS and RHS. This is used to
11603/// eliminate casts.
11604bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11605 CmpPredicate Pred,
11606 const SCEV *LHS,
11607 const SCEV *RHS) {
11608 // Interpret a null as meaning no loop, where there is obviously no guard
11609 // (interprocedural conditions notwithstanding). Do not bother about
11610 // unreachable loops.
11611 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11612 return true;
11613
11614 if (VerifyIR)
11615 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11616 "This cannot be done on broken IR!");
11617
11618
11619 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11620 return true;
11621
11622 BasicBlock *Latch = L->getLoopLatch();
11623 if (!Latch)
11624 return false;
11625
11626   BranchInst *LoopContinuePredicate =
11627      dyn_cast<BranchInst>(Latch->getTerminator());
11628 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11629 isImpliedCond(Pred, LHS, RHS,
11630 LoopContinuePredicate->getCondition(),
11631 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11632 return true;
11633
11634 // We don't want more than one activation of the following loops on the stack
11635 // -- that can lead to O(n!) time complexity.
11636 if (WalkingBEDominatingConds)
11637 return false;
11638
11639 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11640
11641 // See if we can exploit a trip count to prove the predicate.
11642 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11643 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11644 if (LatchBECount != getCouldNotCompute()) {
11645 // We know that Latch branches back to the loop header exactly
11646     // LatchBECount times. This means the backedge condition at Latch is
11647 // equivalent to "{0,+,1} u< LatchBECount".
11648 Type *Ty = LatchBECount->getType();
11649 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11650 const SCEV *LoopCounter =
11651 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11652 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11653 LatchBECount))
11654 return true;
11655 }
11656
11657 // Check conditions due to any @llvm.assume intrinsics.
11658 for (auto &AssumeVH : AC.assumptions()) {
11659 if (!AssumeVH)
11660 continue;
11661 auto *CI = cast<CallInst>(AssumeVH);
11662 if (!DT.dominates(CI, Latch->getTerminator()))
11663 continue;
11664
11665 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11666 return true;
11667 }
11668
11669 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11670 return true;
11671
11672 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11673 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11674 assert(DTN && "should reach the loop header before reaching the root!");
11675
11676 BasicBlock *BB = DTN->getBlock();
11677 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11678 return true;
11679
11680 BasicBlock *PBB = BB->getSinglePredecessor();
11681 if (!PBB)
11682 continue;
11683
11684 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11685 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11686 continue;
11687
11688 Value *Condition = ContinuePredicate->getCondition();
11689
11690 // If we have an edge `E` within the loop body that dominates the only
11691 // latch, the condition guarding `E` also guards the backedge. This
11692 // reasoning works only for loops with a single latch.
11693
11694 BasicBlockEdge DominatingEdge(PBB, BB);
11695 if (DominatingEdge.isSingleEdge()) {
11696 // We're constructively (and conservatively) enumerating edges within the
11697 // loop body that dominate the latch. The dominator tree better agree
11698 // with us on this:
11699 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11700
11701 if (isImpliedCond(Pred, LHS, RHS, Condition,
11702 BB != ContinuePredicate->getSuccessor(0)))
11703 return true;
11704 }
11705 }
11706
11707 return false;
11708}
11709
11710bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11711                                                       CmpPredicate Pred,
11712 const SCEV *LHS,
11713 const SCEV *RHS) {
11714 // Do not bother proving facts for unreachable code.
11715 if (!DT.isReachableFromEntry(BB))
11716 return true;
11717 if (VerifyIR)
11718 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11719 "This cannot be done on broken IR!");
11720
11721 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11722 // the facts (a >= b && a != b) separately. A typical situation is when the
11723 // non-strict comparison is known from ranges and non-equality is known from
11724 // dominating predicates. If we are proving strict comparison, we always try
11725 // to prove non-equality and non-strict comparison separately.
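  // For example, "a u> b" follows from "a u>= b" (say, known from ranges)
  // together with "a != b" (say, known from a dominating condition).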
11726 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11727 const bool ProvingStrictComparison =
11728 Pred != NonStrictPredicate.dropSameSign();
11729 bool ProvedNonStrictComparison = false;
11730 bool ProvedNonEquality = false;
11731
11732 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11733 if (!ProvedNonStrictComparison)
11734 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11735 if (!ProvedNonEquality)
11736 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11737 if (ProvedNonStrictComparison && ProvedNonEquality)
11738 return true;
11739 return false;
11740 };
11741
11742 if (ProvingStrictComparison) {
11743 auto ProofFn = [&](CmpPredicate P) {
11744 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11745 };
11746 if (SplitAndProve(ProofFn))
11747 return true;
11748 }
11749
11750 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11751 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11752 const Instruction *CtxI = &BB->front();
11753 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11754 return true;
11755 if (ProvingStrictComparison) {
11756 auto ProofFn = [&](CmpPredicate P) {
11757 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11758 };
11759 if (SplitAndProve(ProofFn))
11760 return true;
11761 }
11762 return false;
11763 };
11764
11765 // Starting at the block's predecessor, climb up the predecessor chain, as long
11766 // as there are predecessors that can be found that have unique successors
11767 // leading to the original block.
11768 const Loop *ContainingLoop = LI.getLoopFor(BB);
11769 const BasicBlock *PredBB;
11770 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11771 PredBB = ContainingLoop->getLoopPredecessor();
11772 else
11773 PredBB = BB->getSinglePredecessor();
11774 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11775 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11776 const BranchInst *BlockEntryPredicate =
11777 dyn_cast<BranchInst>(Pair.first->getTerminator());
11778 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11779 continue;
11780
11781 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11782 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11783 return true;
11784 }
11785
11786 // Check conditions due to any @llvm.assume intrinsics.
11787 for (auto &AssumeVH : AC.assumptions()) {
11788 if (!AssumeVH)
11789 continue;
11790 auto *CI = cast<CallInst>(AssumeVH);
11791 if (!DT.dominates(CI, BB))
11792 continue;
11793
11794 if (ProveViaCond(CI->getArgOperand(0), false))
11795 return true;
11796 }
11797
11798 // Check conditions due to any @llvm.experimental.guard intrinsics.
11799 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11800 F.getParent(), Intrinsic::experimental_guard);
11801 if (GuardDecl)
11802 for (const auto *GU : GuardDecl->users())
11803 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11804 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11805 if (ProveViaCond(Guard->getArgOperand(0), false))
11806 return true;
11807 return false;
11808}
11809
11810bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11811                                                const SCEV *LHS,
11812 const SCEV *RHS) {
11813 // Interpret a null as meaning no loop, where there is obviously no guard
11814 // (interprocedural conditions notwithstanding).
11815 if (!L)
11816 return false;
11817
11818   // Both LHS and RHS must be available at loop entry.
11819  assert(isAvailableAtLoopEntry(LHS, L) &&
11820          "LHS is not available at Loop Entry");
11821  assert(isAvailableAtLoopEntry(RHS, L) &&
11822          "RHS is not available at Loop Entry");
11823
11824 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11825 return true;
11826
11827 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11828}
11829
11830bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11831 const SCEV *RHS,
11832 const Value *FoundCondValue, bool Inverse,
11833 const Instruction *CtxI) {
11834   // A false condition implies anything. Do not bother analyzing it further.
11835 if (FoundCondValue ==
11836 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11837 return true;
11838
11839 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11840 return false;
11841
11842 auto ClearOnExit =
11843 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11844
11845 // Recursively handle And and Or conditions.
11846 const Value *Op0, *Op1;
11847 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11848 if (!Inverse)
11849 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11850 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11851 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11852 if (Inverse)
11853 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11854 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11855 }
11856
11857 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11858 if (!ICI) return false;
11859
11860   // We have found a conditional branch that dominates the loop or controls
11861   // the loop latch. Check to see if it is the comparison we are looking for.
11862 CmpPredicate FoundPred;
11863 if (Inverse)
11864 FoundPred = ICI->getInverseCmpPredicate();
11865 else
11866 FoundPred = ICI->getCmpPredicate();
11867
11868 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11869 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11870
11871 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11872}
11873
11874bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11875 const SCEV *RHS, CmpPredicate FoundPred,
11876 const SCEV *FoundLHS, const SCEV *FoundRHS,
11877 const Instruction *CtxI) {
11878 // Balance the types.
11879 if (getTypeSizeInBits(LHS->getType()) <
11880 getTypeSizeInBits(FoundLHS->getType())) {
11881 // For unsigned and equality predicates, try to prove that both found
11882 // operands fit into narrow unsigned range. If so, try to prove facts in
11883 // narrow types.
11884 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11885 !FoundRHS->getType()->isPointerTy()) {
11886 auto *NarrowType = LHS->getType();
11887 auto *WideType = FoundLHS->getType();
11888 auto BitWidth = getTypeSizeInBits(NarrowType);
11889       const SCEV *MaxValue = getZeroExtendExpr(
11890          getConstant(APInt::getMaxValue(BitWidth)), WideType);
11891 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11892 MaxValue) &&
11893 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11894 MaxValue)) {
11895 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11896 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11897 // We cannot preserve samesign after truncation.
11898 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11899 TruncFoundLHS, TruncFoundRHS, CtxI))
11900 return true;
11901 }
11902 }
11903
11904 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11905 return false;
11906 if (CmpInst::isSigned(Pred)) {
11907 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11908 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11909 } else {
11910 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11911 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11912 }
11913 } else if (getTypeSizeInBits(LHS->getType()) >
11914 getTypeSizeInBits(FoundLHS->getType())) {
11915 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11916 return false;
11917 if (CmpInst::isSigned(FoundPred)) {
11918 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11919 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11920 } else {
11921 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11922 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11923 }
11924 }
11925 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11926 FoundRHS, CtxI);
11927}
11928
11929bool ScalarEvolution::isImpliedCondBalancedTypes(
11930 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11931     const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11932  assert(getTypeSizeInBits(LHS->getType()) ==
11933 getTypeSizeInBits(FoundLHS->getType()) &&
11934 "Types should be balanced!");
11935 // Canonicalize the query to match the way instcombine will have
11936 // canonicalized the comparison.
11937 if (SimplifyICmpOperands(Pred, LHS, RHS))
11938 if (LHS == RHS)
11939 return CmpInst::isTrueWhenEqual(Pred);
11940 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11941 if (FoundLHS == FoundRHS)
11942 return CmpInst::isFalseWhenEqual(FoundPred);
11943
11944 // Check to see if we can make the LHS or RHS match.
11945 if (LHS == FoundRHS || RHS == FoundLHS) {
11946 if (isa<SCEVConstant>(RHS)) {
11947 std::swap(FoundLHS, FoundRHS);
11948 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11949 } else {
11950       std::swap(LHS, RHS);
11951      Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11952 }
11953 }
11954
11955 // Check whether the found predicate is the same as the desired predicate.
11956 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11957 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11958
11959 // Check whether swapping the found predicate makes it the same as the
11960 // desired predicate.
11961 if (auto P = CmpPredicate::getMatching(
11962 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
11963 // We can write the implication
11964 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11965 // using one of the following ways:
11966 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11967 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11968 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11969 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11970 // Forms 1. and 2. require swapping the operands of one condition. Don't
11971     // do this if it would break canonical constant/addrec ordering.
11972    if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11973 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
11974 LHS, FoundLHS, FoundRHS, CtxI);
11975 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11976 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11977
11978 // There's no clear preference between forms 3. and 4., try both. Avoid
11979 // forming getNotSCEV of pointer values as the resulting subtract is
11980 // not legal.
11981 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11982 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
11983 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
11984 FoundRHS, CtxI))
11985 return true;
11986
11987 if (!FoundLHS->getType()->isPointerTy() &&
11988 !FoundRHS->getType()->isPointerTy() &&
11989 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
11990 getNotSCEV(FoundRHS), CtxI))
11991 return true;
11992
11993 return false;
11994 }
11995
11996 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11997 CmpInst::Predicate P2) {
11998 assert(P1 != P2 && "Handled earlier!");
11999     return CmpInst::isRelational(P2) &&
12000           ICmpInst::getFlippedSignednessPredicate(P1) == P2;
12001 };
12002 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12003 // Unsigned comparison is the same as signed comparison when both the
12004 // operands are non-negative or negative.
12005 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
12006 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
12007 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12008 // Create local copies that we can freely swap and canonicalize our
12009 // conditions to "le/lt".
12010 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12011 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12012 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12013 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12014 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12015 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12016 std::swap(CanonicalLHS, CanonicalRHS);
12017 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12018 }
12019 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12020 "Must be!");
12021 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12022 ICmpInst::isLE(CanonicalFoundPred)) &&
12023 "Must be!");
12024 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12025 // Use implication:
12026 // x <u y && y >=s 0 --> x <s y.
12027 // If we can prove the left part, the right part is also proven.
12028 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12029 CanonicalRHS, CanonicalFoundLHS,
12030 CanonicalFoundRHS);
12031 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12032 // Use implication:
12033 // x <s y && y <s 0 --> x <u y.
12034 // If we can prove the left part, the right part is also proven.
12035 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12036 CanonicalRHS, CanonicalFoundLHS,
12037 CanonicalFoundRHS);
12038 }
12039
12040 // Check if we can make progress by sharpening ranges.
12041 if (FoundPred == ICmpInst::ICMP_NE &&
12042 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12043
12044 const SCEVConstant *C = nullptr;
12045 const SCEV *V = nullptr;
12046
12047 if (isa<SCEVConstant>(FoundLHS)) {
12048 C = cast<SCEVConstant>(FoundLHS);
12049 V = FoundRHS;
12050 } else {
12051 C = cast<SCEVConstant>(FoundRHS);
12052 V = FoundLHS;
12053 }
12054
12055 // The guarding predicate tells us that C != V. If the known range
12056 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12057 // range we consider has to correspond to same signedness as the
12058 // predicate we're interested in folding.
12059
12060     APInt Min = ICmpInst::isSigned(Pred) ?
12061        getSignedRangeMin(V) : getUnsignedRangeMin(V);
12062
12063 if (Min == C->getAPInt()) {
12064 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12065 // This is true even if (Min + 1) wraps around -- in case of
12066 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
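      // For example, for unsigned i8 values: if V != 0 and (trivially)
      // V u>= 0, then V u>= 1, so the range of V sharpens from [0, t) to
      // [1, t).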
12067
12068 APInt SharperMin = Min + 1;
12069
12070 switch (Pred) {
12071 case ICmpInst::ICMP_SGE:
12072 case ICmpInst::ICMP_UGE:
12073 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12074 // RHS, we're done.
12075 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12076 CtxI))
12077 return true;
12078 [[fallthrough]];
12079
12080 case ICmpInst::ICMP_SGT:
12081 case ICmpInst::ICMP_UGT:
12082 // We know from the range information that (V `Pred` Min ||
12083 // V == Min). We know from the guarding condition that !(V
12084 // == Min). This gives us
12085 //
12086 // V `Pred` Min || V == Min && !(V == Min)
12087 // => V `Pred` Min
12088 //
12089 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12090
12091 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12092 return true;
12093 break;
12094
12095 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12096 case ICmpInst::ICMP_SLE:
12097 case ICmpInst::ICMP_ULE:
12098 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12099 LHS, V, getConstant(SharperMin), CtxI))
12100 return true;
12101 [[fallthrough]];
12102
12103 case ICmpInst::ICMP_SLT:
12104 case ICmpInst::ICMP_ULT:
12105 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12106 LHS, V, getConstant(Min), CtxI))
12107 return true;
12108 break;
12109
12110 default:
12111 // No change
12112 break;
12113 }
12114 }
12115 }
12116
12117 // Check whether the actual condition is beyond sufficient.
12118 if (FoundPred == ICmpInst::ICMP_EQ)
12119 if (ICmpInst::isTrueWhenEqual(Pred))
12120 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12121 return true;
12122 if (Pred == ICmpInst::ICMP_NE)
12123 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12124 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12125 return true;
12126
12127 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12128 return true;
12129
12130 // Otherwise assume the worst.
12131 return false;
12132}
12133
12134bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12135 const SCEV *&L, const SCEV *&R,
12136 SCEV::NoWrapFlags &Flags) {
12137 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12138 if (!AE || AE->getNumOperands() != 2)
12139 return false;
12140
12141 L = AE->getOperand(0);
12142 R = AE->getOperand(1);
12143 Flags = AE->getNoWrapFlags();
12144 return true;
12145}
12146
12147std::optional<APInt>
12148ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12149 // We avoid subtracting expressions here because this function is usually
12150 // fairly deep in the call stack (i.e. is called many times).
12151
12152 unsigned BW = getTypeSizeInBits(More->getType());
12153 APInt Diff(BW, 0);
12154 APInt DiffMul(BW, 1);
12155 // Try various simplifications to reduce the difference to a constant. Limit
12156 // the number of allowed simplifications to keep compile-time low.
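  // For example, (%x + 5) and (%x + 3) reduce to the constant difference 2,
  // and {(%x + 4),+,1} vs. {%x,+,1} over the same loop reduces to 4 after
  // stripping the identical step.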
12157 for (unsigned I = 0; I < 8; ++I) {
12158 if (More == Less)
12159 return Diff;
12160
12161     // Reduce addrecs with identical steps to their start value.
12162    if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12163 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12164 const auto *MAR = cast<SCEVAddRecExpr>(More);
12165
12166 if (LAR->getLoop() != MAR->getLoop())
12167 return std::nullopt;
12168
12169 // We look at affine expressions only; not for correctness but to keep
12170 // getStepRecurrence cheap.
12171 if (!LAR->isAffine() || !MAR->isAffine())
12172 return std::nullopt;
12173
12174 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12175 return std::nullopt;
12176
12177 Less = LAR->getStart();
12178 More = MAR->getStart();
12179 continue;
12180 }
12181
12182 // Try to match a common constant multiply.
12183 auto MatchConstMul =
12184 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12185 auto *M = dyn_cast<SCEVMulExpr>(S);
12186 if (!M || M->getNumOperands() != 2 ||
12187 !isa<SCEVConstant>(M->getOperand(0)))
12188 return std::nullopt;
12189 return {
12190 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12191 };
12192 if (auto MatchedMore = MatchConstMul(More)) {
12193 if (auto MatchedLess = MatchConstMul(Less)) {
12194 if (MatchedMore->second == MatchedLess->second) {
12195 More = MatchedMore->first;
12196 Less = MatchedLess->first;
12197 DiffMul *= MatchedMore->second;
12198 continue;
12199 }
12200 }
12201 }
12202
12203     // Try to cancel out common factors in two add expressions.
12204    SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12205 auto Add = [&](const SCEV *S, int Mul) {
12206 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12207 if (Mul == 1) {
12208 Diff += C->getAPInt() * DiffMul;
12209 } else {
12210 assert(Mul == -1);
12211 Diff -= C->getAPInt() * DiffMul;
12212 }
12213 } else
12214 Multiplicity[S] += Mul;
12215 };
12216 auto Decompose = [&](const SCEV *S, int Mul) {
12217 if (isa<SCEVAddExpr>(S)) {
12218 for (const SCEV *Op : S->operands())
12219 Add(Op, Mul);
12220 } else
12221 Add(S, Mul);
12222 };
12223 Decompose(More, 1);
12224 Decompose(Less, -1);
12225
12226 // Check whether all the non-constants cancel out, or reduce to new
12227 // More/Less values.
12228 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12229 for (const auto &[S, Mul] : Multiplicity) {
12230 if (Mul == 0)
12231 continue;
12232 if (Mul == 1) {
12233 if (NewMore)
12234 return std::nullopt;
12235 NewMore = S;
12236 } else if (Mul == -1) {
12237 if (NewLess)
12238 return std::nullopt;
12239 NewLess = S;
12240 } else
12241 return std::nullopt;
12242 }
12243
12244 // Values stayed the same, no point in trying further.
12245 if (NewMore == More || NewLess == Less)
12246 return std::nullopt;
12247
12248 More = NewMore;
12249 Less = NewLess;
12250
12251 // Reduced to constant.
12252 if (!More && !Less)
12253 return Diff;
12254
12255 // Left with variable on only one side, bail out.
12256 if (!More || !Less)
12257 return std::nullopt;
12258 }
12259
12260 // Did not reduce to constant.
12261 return std::nullopt;
12262}
12263
12264bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12265 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12266 const SCEV *FoundRHS, const Instruction *CtxI) {
12267 // Try to recognize the following pattern:
12268 //
12269 // FoundRHS = ...
12270 // ...
12271 // loop:
12272 // FoundLHS = {Start,+,W}
12273 // context_bb: // Basic block from the same loop
12274 // known(Pred, FoundLHS, FoundRHS)
12275 //
12276 // If some predicate is known in the context of a loop, it is also known on
12277 // each iteration of this loop, including the first iteration. Therefore, in
12278 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12279 // prove the original pred using this fact.
12280 if (!CtxI)
12281 return false;
12282 const BasicBlock *ContextBB = CtxI->getParent();
12283 // Make sure AR varies in the context block.
12284 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12285 const Loop *L = AR->getLoop();
12286 // Make sure that context belongs to the loop and executes on 1st iteration
12287 // (if it ever executes at all).
12288 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12289 return false;
12290 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12291 return false;
12292 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12293 }
12294
12295 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12296 const Loop *L = AR->getLoop();
12297 // Make sure that context belongs to the loop and executes on 1st iteration
12298 // (if it ever executes at all).
12299 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12300 return false;
12301 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12302 return false;
12303 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12304 }
12305
12306 return false;
12307}
12308
12309bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12310 const SCEV *LHS,
12311 const SCEV *RHS,
12312 const SCEV *FoundLHS,
12313 const SCEV *FoundRHS) {
12314 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12315 return false;
12316
12317 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12318 if (!AddRecLHS)
12319 return false;
12320
12321 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12322 if (!AddRecFoundLHS)
12323 return false;
12324
12325 // We'd like to let SCEV reason about control dependencies, so we constrain
12326 // both the inequalities to be about add recurrences on the same loop. This
12327 // way we can use isLoopEntryGuardedByCond later.
12328
12329 const Loop *L = AddRecFoundLHS->getLoop();
12330 if (L != AddRecLHS->getLoop())
12331 return false;
12332
12333 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12334 //
12335 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12336 // ... (2)
12337 //
12338 // Informal proof for (2), assuming (1) [*]:
12339 //
12340 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12341 //
12342 // Then
12343 //
12344 // FoundLHS s< FoundRHS s< INT_MIN - C
12345 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12346 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12347 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12348 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12349 // <=> FoundLHS + C s< FoundRHS + C
12350 //
12351 // [*]: (1) can be proved by ruling out overflow.
12352 //
12353 // [**]: This can be proved by analyzing all the four possibilities:
12354 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12355 // (A s>= 0, B s>= 0).
12356 //
12357 // Note:
12358 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12359 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12360 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12361 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12362 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12363 // C)".
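  // A concrete instance of (1) for i8: with C = 100 (so -C = 156),
  // 10 u< 20 u< 156 implies (10 + 100) = 110 u< (20 + 100) = 120.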
12364
12365 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12366 if (!LDiff)
12367 return false;
12368 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12369 if (!RDiff || *LDiff != *RDiff)
12370 return false;
12371
12372 if (LDiff->isMinValue())
12373 return true;
12374
12375 APInt FoundRHSLimit;
12376
12377 if (Pred == CmpInst::ICMP_ULT) {
12378 FoundRHSLimit = -(*RDiff);
12379 } else {
12380 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12381 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12382 }
12383
12384 // Try to prove (1) or (2), as needed.
12385 return isAvailableAtLoopEntry(FoundRHS, L) &&
12386 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12387 getConstant(FoundRHSLimit));
12388}
12389
12390bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12391 const SCEV *RHS, const SCEV *FoundLHS,
12392 const SCEV *FoundRHS, unsigned Depth) {
12393 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12394
12395 auto ClearOnExit = make_scope_exit([&]() {
12396 if (LPhi) {
12397 bool Erased = PendingMerges.erase(LPhi);
12398 assert(Erased && "Failed to erase LPhi!");
12399 (void)Erased;
12400 }
12401 if (RPhi) {
12402 bool Erased = PendingMerges.erase(RPhi);
12403 assert(Erased && "Failed to erase RPhi!");
12404 (void)Erased;
12405 }
12406 });
12407
12408   // Find the respective Phis and check that they are not already pending.
12409 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12410 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12411 if (!PendingMerges.insert(Phi).second)
12412 return false;
12413 LPhi = Phi;
12414 }
12415 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12416 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12417 // If we detect a loop of Phi nodes being processed by this method, for
12418 // example:
12419 //
12420 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12421 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12422 //
12423       // we don't want to deal with a case that complex, so we conservatively
12424       // return false.
12425 if (!PendingMerges.insert(Phi).second)
12426 return false;
12427 RPhi = Phi;
12428 }
12429
12430 // If none of LHS, RHS is a Phi, nothing to do here.
12431 if (!LPhi && !RPhi)
12432 return false;
12433
12434 // If there is a SCEVUnknown Phi we are interested in, make it left.
12435 if (!LPhi) {
12436 std::swap(LHS, RHS);
12437 std::swap(FoundLHS, FoundRHS);
12438     std::swap(LPhi, RPhi);
12439    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12440 }
12441
12442 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12443 const BasicBlock *LBB = LPhi->getParent();
12444 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12445
12446 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12447 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12448 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12449 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12450 };
12451
12452 if (RPhi && RPhi->getParent() == LBB) {
12453 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12454 // If we compare two Phis from the same block, and for each entry block
12455 // the predicate is true for incoming values from this block, then the
12456 // predicate is also true for the Phis.
12457 for (const BasicBlock *IncBB : predecessors(LBB)) {
12458 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12459 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12460 if (!ProvedEasily(L, R))
12461 return false;
12462 }
12463 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12464 // Case two: RHS is also a Phi from the same basic block, and it is an
12465 // AddRec. It means that there is a loop which has both AddRec and Unknown
12466     // PHIs; in that case we can compare the AddRec's incoming values from above
12467     // the loop and from the latch with the respective incoming values of LPhi.
12468 // TODO: Generalize to handle loops with many inputs in a header.
12469 if (LPhi->getNumIncomingValues() != 2) return false;
12470
12471 auto *RLoop = RAR->getLoop();
12472 auto *Predecessor = RLoop->getLoopPredecessor();
12473 assert(Predecessor && "Loop with AddRec with no predecessor?");
12474 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12475 if (!ProvedEasily(L1, RAR->getStart()))
12476 return false;
12477 auto *Latch = RLoop->getLoopLatch();
12478 assert(Latch && "Loop with AddRec with no latch?");
12479 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12480 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12481 return false;
12482 } else {
12483 // In all other cases go over inputs of LHS and compare each of them to RHS,
12484 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12485 // At this point RHS is either a non-Phi, or it is a Phi from some block
12486 // different from LBB.
12487 for (const BasicBlock *IncBB : predecessors(LBB)) {
12488 // Check that RHS is available in this block.
12489 if (!dominates(RHS, IncBB))
12490 return false;
12491 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12492 // Make sure L does not refer to a value from a potentially previous
12493 // iteration of a loop.
12494 if (!properlyDominates(L, LBB))
12495 return false;
12496 // Addrecs are considered to properly dominate their loop, so are missed
12497 // by the previous check. Discard any values that have computable
12498 // evolution in this loop.
12499 if (auto *Loop = LI.getLoopFor(LBB))
12500 if (hasComputableLoopEvolution(L, Loop))
12501 return false;
12502 if (!ProvedEasily(L, RHS))
12503 return false;
12504 }
12505 }
12506 return true;
12507}
12508
12509bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12510 const SCEV *LHS,
12511 const SCEV *RHS,
12512 const SCEV *FoundLHS,
12513 const SCEV *FoundRHS) {
12514 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12515   // sure that we are dealing with the same LHS.
12516 if (RHS == FoundRHS) {
12517 std::swap(LHS, RHS);
12518     std::swap(FoundLHS, FoundRHS);
12519    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12520 }
12521 if (LHS != FoundLHS)
12522 return false;
12523
12524 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12525 if (!SUFoundRHS)
12526 return false;
12527
12528 Value *Shiftee, *ShiftValue;
12529
12530 using namespace PatternMatch;
12531 if (match(SUFoundRHS->getValue(),
12532 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12533 auto *ShifteeS = getSCEV(Shiftee);
12534 // Prove one of the following:
12535 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12536 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12537 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12538 // ---> LHS <s RHS
12539 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12540 // ---> LHS <=s RHS
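    // For example, the first rule holds because a logical shift right can
    // only decrease an unsigned value: (shiftee >> shiftvalue) u<= shiftee
    // u<= RHS, so LHS u< (shiftee >> shiftvalue) gives LHS u< RHS.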
12541 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12542 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12543 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12544 if (isKnownNonNegative(ShifteeS))
12545 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12546 }
12547
12548 return false;
12549}
12550
12551bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12552 const SCEV *RHS,
12553 const SCEV *FoundLHS,
12554 const SCEV *FoundRHS,
12555 const Instruction *CtxI) {
12556 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12557 FoundRHS) ||
12558 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12559 FoundRHS) ||
12560 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12561 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12562 CtxI) ||
12563 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12564}
12565
12566/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12567template <typename MinMaxExprType>
12568static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12569 const SCEV *Candidate) {
12570 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12571 if (!MinMaxExpr)
12572 return false;
12573
12574 return is_contained(MinMaxExpr->operands(), Candidate);
12575}
12576
12577static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12578 CmpPredicate Pred, const SCEV *LHS,
12579 const SCEV *RHS) {
12580 // If both sides are affine addrecs for the same loop, with equal
12581 // steps, and we know the recurrences don't wrap, then we only
12582 // need to check the predicate on the starting values.
12583
12584 if (!ICmpInst::isRelational(Pred))
12585 return false;
12586
12587 const SCEV *LStart, *RStart, *Step;
12588 const Loop *L;
12589 if (!match(LHS,
12590 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12591 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12592 m_SpecificLoop(L))))
12593 return false;
12598 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12599 return false;
12600
12601 return SE.isKnownPredicate(Pred, LStart, RStart);
12602}
12603
12604/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12605/// expression?
12606static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12607 const SCEV *LHS, const SCEV *RHS) {
12608 switch (Pred) {
12609 default:
12610 return false;
12611
12612 case ICmpInst::ICMP_SGE:
12613 std::swap(LHS, RHS);
12614 [[fallthrough]];
12615 case ICmpInst::ICMP_SLE:
12616 return
12617 // min(A, ...) <= A
12618 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12619 // A <= max(A, ...)
12620 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12621
12622 case ICmpInst::ICMP_UGE:
12623 std::swap(LHS, RHS);
12624 [[fallthrough]];
12625 case ICmpInst::ICMP_ULE:
12626 return
12627 // min(A, ...) <= A
12628 // FIXME: what about umin_seq?
12629 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12630 // A <= max(A, ...)
12631 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12632 }
12633
12634 llvm_unreachable("covered switch fell through?!");
12635}
12636
12637bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12638 const SCEV *RHS,
12639 const SCEV *FoundLHS,
12640 const SCEV *FoundRHS,
12641 unsigned Depth) {
12642 assert(getTypeSizeInBits(LHS->getType()) ==
12643 getTypeSizeInBits(RHS->getType()) &&
12644 "LHS and RHS have different sizes?");
12645 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12646 getTypeSizeInBits(FoundRHS->getType()) &&
12647 "FoundLHS and FoundRHS have different sizes?");
12648 // We want to avoid hurting the compile time with analysis of too big trees.
12649 if (Depth > MaxSCEVOperationsImplicationDepth)
12650 return false;
12651
12652 // We only want to work with GT comparison so far.
12653 if (ICmpInst::isLT(Pred)) {
12654 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12655 std::swap(LHS, RHS);
12656 std::swap(FoundLHS, FoundRHS);
12657 }
12658
12660
12661 // For unsigned, try to reduce it to corresponding signed comparison.
12662 if (P == ICmpInst::ICMP_UGT)
12663 // We can replace unsigned predicate with its signed counterpart if all
12664 // involved values are non-negative.
12665 // TODO: We could have better support for unsigned.
12666 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12667 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12668 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12669 // use this fact to prove that LHS and RHS are non-negative.
12670 const SCEV *MinusOne = getMinusOne(LHS->getType());
12671 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12672 FoundRHS) &&
12673 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12674 FoundRHS))
12675 return true;
12676 }
12677
12678 if (P != ICmpInst::ICMP_SGT)
12679 return false;
12680
12681 auto GetOpFromSExt = [&](const SCEV *S) {
12682 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12683 return Ext->getOperand();
12684 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12685 // the constant in some cases.
12686 return S;
12687 };
12688
12689 // Acquire values from extensions.
12690 auto *OrigLHS = LHS;
12691 auto *OrigFoundLHS = FoundLHS;
12692 LHS = GetOpFromSExt(LHS);
12693 FoundLHS = GetOpFromSExt(FoundLHS);
12694
12695 // Check whether the SGT predicate can be proved trivially or using the found context.
12696 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12697 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12698 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12699 FoundRHS, Depth + 1);
12700 };
12701
12702 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12703 // We want to avoid creation of any new non-constant SCEV. Since we are
12704 // going to compare the operands to RHS, we should be certain that we don't
12705 // need any size extensions for this. So let's decline all cases when the
12706 // sizes of types of LHS and RHS do not match.
12707 // TODO: Maybe try to get RHS from sext to catch more cases?
12708 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12709 return false;
12710
12711 // Should not overflow.
12712 if (!LHSAddExpr->hasNoSignedWrap())
12713 return false;
12714
12715 auto *LL = LHSAddExpr->getOperand(0);
12716 auto *LR = LHSAddExpr->getOperand(1);
12717 auto *MinusOne = getMinusOne(RHS->getType());
12718
12719 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12720 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12721 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12722 };
12723 // Try to prove the following rule:
12724 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12725 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12726 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12727 return true;
12728 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12729 Value *LL, *LR;
12730 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12731
12732 using namespace llvm::PatternMatch;
12733
12734 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12735 // Rules for division.
12736 // We are going to perform some comparisons with Denominator and its
12737 // derivative expressions. In the general case, creating a SCEV for it may
12738 // lead to a complex analysis of the entire graph, and in particular it
12739 // can request trip count recalculation for the same loop. This would
12740 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12741 // this, we only want to create SCEVs that are constants in this section.
12742 // So we bail if Denominator is not a constant.
12743 if (!isa<ConstantInt>(LR))
12744 return false;
12745
12746 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12747
12748 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12749 // then a SCEV for the numerator already exists and matches with FoundLHS.
12750 auto *Numerator = getExistingSCEV(LL);
12751 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12752 return false;
12753
12754 // Make sure that the numerator matches with FoundLHS and the denominator
12755 // is positive.
12756 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12757 return false;
12758
12759 auto *DTy = Denominator->getType();
12760 auto *FRHSTy = FoundRHS->getType();
12761 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12762 // One of types is a pointer and another one is not. We cannot extend
12763 // them properly to a wider type, so let us just reject this case.
12764 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12765 // to avoid this check.
12766 return false;
12767
12768 // Given that:
12769 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12770 auto *WTy = getWiderType(DTy, FRHSTy);
12771 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12772 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12773
12774 // Try to prove the following rule:
12775 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12776 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12777 // divide it by Denominator < 4, the result is at least 1.
12778 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12779 if (isKnownNonPositive(RHS) &&
12780 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12781 return true;
12782
12783 // Try to prove the following rule:
12784 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12785 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12786 // If we divide it by Denominator > 2, then:
12787 // 1. If FoundLHS is negative, then the result is 0.
12788 // 2. If FoundLHS is non-negative, then the result is non-negative.
12789 // Anyways, the result is non-negative.
12790 auto *MinusOne = getMinusOne(WTy);
12791 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12792 if (isKnownNegative(RHS) &&
12793 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12794 return true;
12795 }
12796 }
12797
12798 // If our expression contained SCEVUnknown Phis, and we split it down and now
12799 // need to prove something for them, try to prove the predicate for every
12800 // possible incoming value of those Phis.
12801 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12802 return true;
12803
12804 return false;
12805}
12806
12808 const SCEV *RHS) {
12809 // zext x u<= sext x, sext x s<= zext x
12810 const SCEV *Op;
12811 switch (Pred) {
12812 case ICmpInst::ICMP_SGE:
12813 std::swap(LHS, RHS);
12814 [[fallthrough]];
12815 case ICmpInst::ICMP_SLE: {
12816 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12817 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12818 match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12819 }
12820 case ICmpInst::ICMP_UGE:
12821 std::swap(LHS, RHS);
12822 [[fallthrough]];
12823 case ICmpInst::ICMP_ULE: {
12824 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12825 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12826 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12827 }
12828 default:
12829 return false;
12830 };
12831 llvm_unreachable("unhandled case");
12832}
12833
12834bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12835 const SCEV *LHS,
12836 const SCEV *RHS) {
12837 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12838 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12839 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12840 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12841 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12842}
12843
12844bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12845 const SCEV *LHS,
12846 const SCEV *RHS,
12847 const SCEV *FoundLHS,
12848 const SCEV *FoundRHS) {
12849 switch (Pred) {
12850 default:
12851 llvm_unreachable("Unexpected CmpPredicate value!");
12852 case ICmpInst::ICMP_EQ:
12853 case ICmpInst::ICMP_NE:
12854 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12855 return true;
12856 break;
12857 case ICmpInst::ICMP_SLT:
12858 case ICmpInst::ICMP_SLE:
12859 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12860 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12861 return true;
12862 break;
12863 case ICmpInst::ICMP_SGT:
12864 case ICmpInst::ICMP_SGE:
12865 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12866 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12867 return true;
12868 break;
12869 case ICmpInst::ICMP_ULT:
12870 case ICmpInst::ICMP_ULE:
12871 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12872 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12873 return true;
12874 break;
12875 case ICmpInst::ICMP_UGT:
12876 case ICmpInst::ICMP_UGE:
12877 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12878 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12879 return true;
12880 break;
12881 }
12882
12883 // Maybe it can be proved via operations?
12884 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12885 return true;
12886
12887 return false;
12888}
12889
12890bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12891 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12892 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12893 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12894 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12895 // reduce the compile time impact of this optimization.
12896 return false;
12897
12898 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12899 if (!Addend)
12900 return false;
12901
12902 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12903
12904 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12905 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12906 ConstantRange FoundLHSRange =
12907 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12908
12909 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12910 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12911
12912 // We can also compute the range of values for `LHS` that satisfy the
12913 // consequent, "`LHS` `Pred` `RHS`":
12914 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12915 // The antecedent implies the consequent if every value of `LHS` that
12916 // satisfies the antecedent also satisfies the consequent.
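 // Worked illustration: suppose FoundPred is ULT with ConstFoundRHS = 10, so
 // FoundLHSRange = [0, 10), and computeConstantDifference gives Addend = 5,
 // so LHSRange = [5, 15). The implication then holds for a consequent such as
 // "LHS <u 20", but not for "LHS <u 12", since 14 is still a possible value
 // of LHS.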
12917 return LHSRange.icmp(Pred, ConstRHS);
12918}
12919
12920bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12921 bool IsSigned) {
12922 assert(isKnownPositive(Stride) && "Positive stride expected!");
12923
12924 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12925 const SCEV *One = getOne(Stride->getType());
12926
12927 if (IsSigned) {
12928 APInt MaxRHS = getSignedRangeMax(RHS);
12929 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12930 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12931
12932 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
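 // For instance, with 8-bit signed values, MaxRHS = 120 and a stride whose
 // signed maximum is 10: MaxStrideMinusOne = 9 and 127 - 9 = 118 <s 120, so
 // the IV could step past RHS and wrap; we conservatively report overflow.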
12933 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12934 }
12935
12936 APInt MaxRHS = getUnsignedRangeMax(RHS);
12937 APInt MaxValue = APInt::getMaxValue(BitWidth);
12938 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12939
12940 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12941 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12942}
12943
12944bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12945 bool IsSigned) {
12946
12947 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12948 const SCEV *One = getOne(Stride->getType());
12949
12950 if (IsSigned) {
12951 APInt MinRHS = getSignedRangeMin(RHS);
12952 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12953 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12954
12955 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12956 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12957 }
12958
12959 APInt MinRHS = getUnsignedRangeMin(RHS);
12960 APInt MinValue = APInt::getMinValue(BitWidth);
12961 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12962
12963 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12964 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12965}
12966
12968 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12969 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12970 // expression fixes the case of N=0.
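 // E.g., N = 7, D = 3: umin(7, 1) = 1 and floor((7 - 1) / 3) = 2, giving 3,
 // which is ceil(7 / 3). For N = 0: umin(0, 1) = 0 and the result is 0,
 // whereas the naive "1 + floor((N - 1) / D)" would wrap N - 1 around and
 // produce a bogus large value.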
12971 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12972 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12973 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12974}
12975
12976const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12977 const SCEV *Stride,
12978 const SCEV *End,
12979 unsigned BitWidth,
12980 bool IsSigned) {
12981 // The logic in this function assumes we can represent a positive stride.
12982 // If we can't, the backedge-taken count must be zero.
12983 if (IsSigned && BitWidth == 1)
12984 return getZero(Stride->getType());
12985
12986 // The code below has only been closely audited for negative strides in the
12987 // unsigned comparison case; it may be correct for signed comparison, but
12988 // that needs to be established.
12989 if (IsSigned && isKnownNegative(Stride))
12990 return getCouldNotCompute();
12991
12992 // Calculate the maximum backedge count based on the range of values
12993 // permitted by Start, End, and Stride.
12994 APInt MinStart =
12995 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12996
12997 APInt MinStride =
12998 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12999
13000 // We assume either the stride is positive, or the backedge-taken count
13001 // is zero. So force StrideForMaxBECount to be at least one.
13002 APInt One(BitWidth, 1);
13003 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13004 : APIntOps::umax(One, MinStride);
13005
13006 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13007 : APInt::getMaxValue(BitWidth);
13008 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13009
13010 // Although End can be a MAX expression we estimate MaxEnd considering only
13011 // the case End = RHS of the loop termination condition. This is safe because
13012 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13013 // taken count.
13014 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13015 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13016
13017 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13018 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13019 : APIntOps::umax(MaxEnd, MinStart);
13020
13021 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13022 getConstant(StrideForMaxBECount) /* Step */);
13023}
13024
13026ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13027 const Loop *L, bool IsSigned,
13028 bool ControlsOnlyExit, bool AllowPredicates) {
13029 SmallVector<const SCEVPredicate *> Predicates;
13030
13031 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13032 bool PredicatedIV = false;
13033 if (!IV) {
13034 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13035 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13036 if (AR && AR->getLoop() == L && AR->isAffine()) {
13037 auto canProveNUW = [&]() {
13038 // We can use the comparison to infer no-wrap flags only if it fully
13039 // controls the loop exit.
13040 if (!ControlsOnlyExit)
13041 return false;
13042
13043 if (!isLoopInvariant(RHS, L))
13044 return false;
13045
13046 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13047 // We need the sequence defined by AR to strictly increase in the
13048 // unsigned integer domain for the logic below to hold.
13049 return false;
13050
13051 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13052 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13053 // If RHS <=u Limit, then there must exist a value V in the sequence
13054 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13055 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13056 // overflow occurs. This limit also implies that a signed comparison
13057 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13058 // the high bits on both sides must be zero.
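 // For illustration: with InnerBitWidth = 8, OuterBitWidth = 16 and
 // StrideMax = 3, Limit = 255 - 2 = 253, zero-extended to 16 bits. If the
 // loop guards show RHS <=u 253, the exit must be taken before the 8-bit IV
 // can wrap, so NUW can be inferred.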
13059 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13060 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13061 Limit = Limit.zext(OuterBitWidth);
13062 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13063 };
13064 auto Flags = AR->getNoWrapFlags();
13065 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13066 Flags = setFlags(Flags, SCEV::FlagNUW);
13067
13068 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13069 if (AR->hasNoUnsignedWrap()) {
13070 // Emulate what getZeroExtendExpr would have done during construction
13071 // if we'd been able to infer the fact just above at that time.
13072 const SCEV *Step = AR->getStepRecurrence(*this);
13073 Type *Ty = ZExt->getType();
13074 auto *S = getAddRecExpr(
13076 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13077 IV = dyn_cast<SCEVAddRecExpr>(S);
13078 }
13079 }
13080 }
13081 }
13082
13083
13084 if (!IV && AllowPredicates) {
13085 // Try to make this an AddRec using runtime tests, in the first X
13086 // iterations of this loop, where X is the SCEV expression found by the
13087 // algorithm below.
13088 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13089 PredicatedIV = true;
13090 }
13091
13092 // Avoid weird loops
13093 if (!IV || IV->getLoop() != L || !IV->isAffine())
13094 return getCouldNotCompute();
13095
13096 // A precondition of this method is that the condition being analyzed
13097 // reaches an exiting branch which dominates the latch. Given that, we can
13098 // assume that an increment which violates the nowrap specification and
13099 // produces poison must cause undefined behavior when the resulting poison
13100 // value is branched upon and thus we can conclude that the backedge is
13101 // taken no more often than would be required to produce that poison value.
13102 // Note that a well defined loop can exit on the iteration which violates
13103 // the nowrap specification if there is another exit (either explicit or
13104 // implicit/exceptional) which causes the loop to exit before the
13105 // exiting instruction we're analyzing would trigger UB.
13106 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13107 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13109
13110 const SCEV *Stride = IV->getStepRecurrence(*this);
13111
13112 bool PositiveStride = isKnownPositive(Stride);
13113
13114 // Avoid negative or zero stride values.
13115 if (!PositiveStride) {
13116 // We can compute the correct backedge taken count for loops with unknown
13117 // strides if we can prove that the loop is not an infinite loop with side
13118 // effects. Here's the loop structure we are trying to handle -
13119 //
13120 // i = start
13121 // do {
13122 // A[i] = i;
13123 // i += s;
13124 // } while (i < end);
13125 //
13126 // The backedge taken count for such loops is evaluated as -
13127 // (max(end, start + stride) - start - 1) /u stride
13128 //
13129 // The additional preconditions that we need to check to prove correctness
13130 // of the above formula are as follows -
13131 //
13132 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13133 // NoWrap flag).
13134 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13135 // no side effects within the loop)
13136 // c) loop has a single static exit (with no abnormal exits)
13137 //
13138 // Precondition a) implies that if the stride is negative, this is a single
13139 // trip loop. The backedge taken count formula reduces to zero in this case.
13140 //
13141 // Preconditions b) and c) combine to imply that if RHS is invariant in L,
13142 // then a zero stride means the backedge can't be taken without executing
13143 // undefined behavior.
13144 //
13145 // The positive stride case is the same as isKnownPositive(Stride) returning
13146 // true (original behavior of the function).
13147 //
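 // E.g., start = 0, end = 10, runtime stride s = 3:
 // (max(10, 0 + 3) - 0 - 1) /u 3 = 9 /u 3 = 3 backedges (i = 0, 3, 6, 9).
 // If end <= start + stride, the numerator is at most stride - 1, so the
 // division yields the expected zero backedges.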
13148 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13149 !loopHasNoAbnormalExits(L))
13150 return getCouldNotCompute();
13151
13152 if (!isKnownNonZero(Stride)) {
13153 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13154 // if it might eventually be greater than start and if so, on which
13155 // iteration. We can't even produce a useful upper bound.
13156 if (!isLoopInvariant(RHS, L))
13157 return getCouldNotCompute();
13158
13159 // We allow a potentially zero stride, but we need to divide by stride
13160 // below. Since the loop can't be infinite and this check must control
13161 // the sole exit, we can infer the exit must be taken on the first
13162 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13163 // we know the numerator in the divides below must be zero, so we can
13164 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13165 // and produce the right result.
13166 // FIXME: Handle the case where Stride is poison?
13167 auto wouldZeroStrideBeUB = [&]() {
13168 // Proof by contradiction. Suppose the stride were zero. If we can
13169 // prove that the backedge *is* taken on the first iteration, then since
13170 // we know this condition controls the sole exit, we must have an
13171 // infinite loop. We can't have a (well defined) infinite loop per
13172 // check just above.
13173 // Note: The (Start - Stride) term is used to get the start' term from
13174 // (start' + stride,+,stride). Remember that we only care about the
13175 // result of this expression when stride == 0 at runtime.
13176 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13177 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13178 };
13179 if (!wouldZeroStrideBeUB()) {
13180 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13181 }
13182 }
13183 } else if (!NoWrap) {
13184 // Avoid proven overflow cases: this will ensure that the backedge taken
13185 // count will not generate any unsigned overflow.
13186 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13187 return getCouldNotCompute();
13188 }
13189
13190 // On all paths just preceding, we established the following invariant:
13191 // IV can be assumed not to overflow up to and including the exiting
13192 // iteration. We proved this in one of two ways:
13193 // 1) We can show overflow doesn't occur before the exiting iteration
13194 // 1a) canIVOverflowOnLT, and b) step of one
13195 // 2) We can show that if overflow occurs, the loop must execute UB
13196 // before any possible exit.
13197 // Note that we have not yet proved RHS invariant (in general).
13198
13199 const SCEV *Start = IV->getStart();
13200
13201 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13202 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13203 // Use integer-typed versions for actual computation; we can't subtract
13204 // pointers in general.
13205 const SCEV *OrigStart = Start;
13206 const SCEV *OrigRHS = RHS;
13207 if (Start->getType()->isPointerTy()) {
13208 Start = getLosslessPtrToIntExpr(Start);
13209 if (isa<SCEVCouldNotCompute>(Start))
13210 return Start;
13211 }
13212 if (RHS->getType()->isPointerTy()) {
13213 RHS = getLosslessPtrToIntExpr(RHS);
13214 if (isa<SCEVCouldNotCompute>(RHS))
13215 return RHS;
13216 }
13217
13218 const SCEV *End = nullptr, *BECount = nullptr,
13219 *BECountIfBackedgeTaken = nullptr;
13220 if (!isLoopInvariant(RHS, L)) {
13221 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13222 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13223 RHSAddRec->getNoWrapFlags()) {
13224 // The structure of loop we are trying to calculate backedge count of:
13225 //
13226 // left = left_start
13227 // right = right_start
13228 //
13229 // while(left < right){
13230 // ... do something here ...
13231 // left += s1; // stride of left is s1 (s1 > 0)
13232 // right += s2; // stride of right is s2 (s2 < 0)
13233 // }
13234 //
13235
13236 const SCEV *RHSStart = RHSAddRec->getStart();
13237 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13238
13239 // If Stride - RHSStride is positive and does not overflow, we can write
13240 // backedge count as ->
13241 // ceil((End - Start) /u (Stride - RHSStride))
13242 // Where, End = max(RHSStart, Start)
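 // For illustration: left_start = 0, right_start = 100, s1 = 1, s2 = -1:
 // Stride - RHSStride = 2, End = max(100, 0) = 100, and the backedge is
 // taken ceil((100 - 0) /u 2) = 50 times (left and right meet at 50).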
13243
13244 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13245 if (isKnownNegative(RHSStride) &&
13246 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13247 RHSStride)) {
13248
13249 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13250 if (isKnownPositive(Denominator)) {
13251 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13252 : getUMaxExpr(RHSStart, Start);
13253
13254 // We can do this because End >= Start, as End = max(RHSStart, Start)
13255 const SCEV *Delta = getMinusSCEV(End, Start);
13256
13257 BECount = getUDivCeilSCEV(Delta, Denominator);
13258 BECountIfBackedgeTaken =
13259 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13260 }
13261 }
13262 }
13263 if (BECount == nullptr) {
13264 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13265 // given the start, stride and max value for the end bound of the
13266 // loop (RHS), and the fact that IV does not overflow (which is
13267 // checked above).
13268 const SCEV *MaxBECount = computeMaxBECountForLT(
13269 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13270 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13271 MaxBECount, false /*MaxOrZero*/, Predicates);
13272 }
13273 } else {
13274 // We use the expression (max(End,Start)-Start)/Stride to describe the
13275 // backedge count, as if the backedge is taken at least once
13276 // max(End,Start) is End and so the result is as above, and if not
13277 // max(End,Start) is Start so we get a backedge count of zero.
13278 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13279 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13280 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13281 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13282 // Can we prove max(RHS,Start) > Start - Stride?
13283 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13284 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13285 // In this case, we can use a refined formula for computing backedge
13286 // taken count. The general formula remains:
13287 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13288 // We want to use the alternate formula:
13289 // "((End - 1) - (Start - Stride)) /u Stride"
13290 // Let's do a quick case analysis to show these are equivalent under
13291 // our precondition that max(RHS,Start) > Start - Stride.
13292 // * For RHS <= Start, the backedge-taken count must be zero.
13293 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13294 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13295 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13296 // of Stride. For a zero stride, we've used umax(1, Stride) above,
13297 // reducing this to the stride of 1 case.
13298 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13299 // Stride".
13300 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13301 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13302 // "((RHS - (Start - Stride) - 1) /u Stride".
13303 // Our preconditions trivially imply no overflow in that form.
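 // Concrete check: Start = 2, Stride = 3, RHS = 10 gives
 // ((10 - 1) - (2 - 3)) /u 3 = 10 /u 3 = 3, matching ceil((10 - 2) / 3);
 // and for RHS = 2 <= Start = 2 it gives ((2 - 1) - (2 - 3)) /u 3 = 2 /u 3
 // = 0 backedges, as expected.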
13304 const SCEV *MinusOne = getMinusOne(Stride->getType());
13305 const SCEV *Numerator =
13306 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13307 BECount = getUDivExpr(Numerator, Stride);
13308 }
13309
13310 if (!BECount) {
13311 auto canProveRHSGreaterThanEqualStart = [&]() {
13312 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13313 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13314 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13315
13316 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13317 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13318 return true;
13319
13320 // (RHS > Start - 1) implies RHS >= Start.
13321 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13322 // "Start - 1" doesn't overflow.
13323 // * For signed comparison, if Start - 1 does overflow, it's equal
13324 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13325 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13326 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13327 //
13328 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13329 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13330 auto *StartMinusOne =
13331 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13332 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13333 };
13334
13335 // If we know that RHS >= Start in the context of loop, then we know
13336 // that max(RHS, Start) = RHS at this point.
13337 if (canProveRHSGreaterThanEqualStart()) {
13338 End = RHS;
13339 } else {
13340 // If RHS < Start, the backedge will be taken zero times. So in
13341 // general, we can write the backedge-taken count as:
13342 //
13343 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13344 //
13345 // We convert it to the following to make it more convenient for SCEV:
13346 //
13347 // ceil(max(RHS, Start) - Start) / Stride
13348 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13349
13350 // See what would happen if we assume the backedge is taken. This is
13351 // used to compute MaxBECount.
13352 BECountIfBackedgeTaken =
13353 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13354 }
13355
13356 // At this point, we know:
13357 //
13358 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13359 // 2. The index variable doesn't overflow.
13360 //
13361 // Therefore, we know N exists such that
13362 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13363 // doesn't overflow.
13364 //
13365 // Using this information, try to prove whether the addition in
13366 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13367 const SCEV *One = getOne(Stride->getType());
13368 bool MayAddOverflow = [&] {
13369 if (isKnownToBeAPowerOfTwo(Stride)) {
13370 // Suppose Stride is a power of two, and Start/End are unsigned
13371 // integers. Let UMAX be the largest representable unsigned
13372 // integer.
13373 //
13374 // By the preconditions of this function, we know
13375 // "(Start + Stride * N) >= End", and this doesn't overflow.
13376 // As a formula:
13377 //
13378 // End <= (Start + Stride * N) <= UMAX
13379 //
13380 // Subtracting Start from all the terms:
13381 //
13382 // End - Start <= Stride * N <= UMAX - Start
13383 //
13384 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13385 //
13386 // End - Start <= Stride * N <= UMAX
13387 //
13388 // Stride * N is a multiple of Stride. Therefore,
13389 //
13390 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13391 //
13392 // Since Stride is a power of two, UMAX + 1 is divisible by
13393 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13394 // write:
13395 //
13396 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13397 //
13398 // Dropping the middle term:
13399 //
13400 // End - Start <= UMAX - (Stride - 1)
13401 //
13402 // Adding Stride - 1 to both sides:
13403 //
13404 // (End - Start) + (Stride - 1) <= UMAX
13405 //
13406 // In other words, the addition doesn't have unsigned overflow.
13407 //
13408 // A similar proof works if we treat Start/End as signed values.
13409 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13410 // to use signed max instead of unsigned max. Note that we're
13411 // trying to prove a lack of unsigned overflow in either case.
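 // Concrete instance: 8-bit values with Stride = 4, so UMAX = 255 and
 // UMAX mod Stride = 3. Then End - Start is bounded by a multiple of 4
 // that is <= 252, so End - Start <= 252 and (End - Start) + 3 <= 255
 // stays in range.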
13412 return false;
13413 }
13414 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13415 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13416 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13417 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13418 // 1 <s End.
13419 //
13420 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13421 // End.
13422 return false;
13423 }
13424 return true;
13425 }();
13426
13427 const SCEV *Delta = getMinusSCEV(End, Start);
13428 if (!MayAddOverflow) {
13429 // floor((D + (S - 1)) / S)
13430 // We prefer this formulation if it's legal because it's fewer
13431 // operations.
13432 BECount =
13433 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13434 } else {
13435 BECount = getUDivCeilSCEV(Delta, Stride);
13436 }
13437 }
13438 }
13439
13440 const SCEV *ConstantMaxBECount;
13441 bool MaxOrZero = false;
13442 if (isa<SCEVConstant>(BECount)) {
13443 ConstantMaxBECount = BECount;
13444 } else if (BECountIfBackedgeTaken &&
13445 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13446 // If we know exactly how many times the backedge will be taken if it's
13447 // taken at least once, then the backedge count will either be that or
13448 // zero.
13449 ConstantMaxBECount = BECountIfBackedgeTaken;
13450 MaxOrZero = true;
13451 } else {
13452 ConstantMaxBECount = computeMaxBECountForLT(
13453 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13454 }
13455
13456 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13457 !isa<SCEVCouldNotCompute>(BECount))
13458 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13459
13460 const SCEV *SymbolicMaxBECount =
13461 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13462 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13463 Predicates);
13464}
13465
13466ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13467 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13468 bool ControlsOnlyExit, bool AllowPredicates) {
13469 SmallVector<const SCEVPredicate *> Predicates;
13470 // We handle only IV > Invariant
13471 if (!isLoopInvariant(RHS, L))
13472 return getCouldNotCompute();
13473
13474 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13475 if (!IV && AllowPredicates)
13476 // Try to make this an AddRec using runtime tests, in the first X
13477 // iterations of this loop, where X is the SCEV expression found by the
13478 // algorithm below.
13479 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13480
13481 // Avoid weird loops
13482 if (!IV || IV->getLoop() != L || !IV->isAffine())
13483 return getCouldNotCompute();
13484
13485 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13486 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13488
13489 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13490
13491 // Avoid negative or zero stride values
13492 if (!isKnownPositive(Stride))
13493 return getCouldNotCompute();
13494
13495 // Avoid proven overflow cases: this will ensure that the backedge taken count
13496 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13497 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13498 // behaviors like the case of C language.
13499 if (!Stride->isOne() && !NoWrap)
13500 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13501 return getCouldNotCompute();
13502
13503 const SCEV *Start = IV->getStart();
13504 const SCEV *End = RHS;
13505 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13506 // If we know that Start >= RHS in the context of loop, then we know that
13507 // min(RHS, Start) = RHS at this point.
13508 if (isLoopEntryGuardedByCond(
13509 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13510 End = RHS;
13511 else
13512 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13513 }
13514
13515 if (Start->getType()->isPointerTy()) {
13516 Start = getLosslessPtrToIntExpr(Start);
13517 if (isa<SCEVCouldNotCompute>(Start))
13518 return Start;
13519 }
13520 if (End->getType()->isPointerTy()) {
13521 End = getLosslessPtrToIntExpr(End);
13522 if (isa<SCEVCouldNotCompute>(End))
13523 return End;
13524 }
13525
13526 // Compute ((Start - End) + (Stride - 1)) / Stride.
13527 // FIXME: This can overflow. Holding off on fixing this for now;
13528 // howManyGreaterThans will hopefully be gone soon.
13529 const SCEV *One = getOne(Stride->getType());
13530 const SCEV *BECount = getUDivExpr(
13531 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13532
13533 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13534 : getUnsignedRangeMax(Start);
13535
13536 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13537 : getUnsignedRangeMin(Stride);
13538
13539 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13540 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13541 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13542
13543 // Although End can be a MIN expression we estimate MinEnd considering only
13544 // the case End = RHS. This is safe because in the other case (Start - End)
13545 // is zero, leading to a zero maximum backedge taken count.
13546 APInt MinEnd =
13547 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13548 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13549
13550 const SCEV *ConstantMaxBECount =
13551 isa<SCEVConstant>(BECount)
13552 ? BECount
13553 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13554 getConstant(MinStride));
13555
13556 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13557 ConstantMaxBECount = BECount;
13558 const SCEV *SymbolicMaxBECount =
13559 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13560
13561 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13562 Predicates);
13563}
13564
13565const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13566 ScalarEvolution &SE) const {
13567 if (Range.isFullSet()) // Infinite loop.
13568 return SE.getCouldNotCompute();
13569
13570 // If the start is a non-zero constant, shift the range to simplify things.
13571 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13572 if (!SC->getValue()->isZero()) {
13573 SmallVector<const SCEV *, 4> Operands(operands());
13574 Operands[0] = SE.getZero(SC->getType());
13575 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13576 FlagAnyWrap);
13577 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13578 return ShiftedAddRec->getNumIterationsInRange(
13579 Range.subtract(SC->getAPInt()), SE);
13580 // This is strange and shouldn't happen.
13581 return SE.getCouldNotCompute();
13582 }
13583
13584 // The only time we can solve this is when we have all constant indices.
13585 // Otherwise, we cannot determine the overflow conditions.
13586 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13587 return SE.getCouldNotCompute();
13588
13589 // Okay at this point we know that all elements of the chrec are constants and
13590 // that the start element is zero.
13591
13592 // First check to see if the range contains zero. If not, the first
13593 // iteration exits.
13594 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13595 if (!Range.contains(APInt(BitWidth, 0)))
13596 return SE.getZero(getType());
13597
13598 if (isAffine()) {
13599 // If this is an affine expression then we have this situation:
13600 // Solve {0,+,A} in Range === Ax in Range
13601
13602 // We know that zero is in the range. If A is positive then we know that
13603 // the upper value of the range must be the first possible exit value.
13604 // If A is negative then the lower of the range is the last possible loop
13605 // value. Also note that we already checked for a full range.
13606 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13607 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13608
13609 // The exit value should be (End+A)/A.
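 // E.g., Range = [0, 14) with A = 3: End = 13, ExitVal = (13 + 3) / 3 = 5,
 // and evaluating {0,+,3} at 5 gives 15, which is outside the range, while
 // the value 12 at iteration 4 is still inside, so the chrec leaves the
 // range after exactly 5 steps.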
13610 APInt ExitVal = (End + A).udiv(A);
13611 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13612
13613 // Evaluate at the exit value. If we really did fall out of the valid
13614 // range, then we computed our trip count, otherwise wrap around or other
13615 // things must have happened.
13616 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13617 if (Range.contains(Val->getValue()))
13618 return SE.getCouldNotCompute(); // Something strange happened
13619
13620 // Ensure that the previous value is in the range.
13621 assert(Range.contains(
13623 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13624 "Linear scev computation is off in a bad way!");
13625 return SE.getConstant(ExitValue);
13626 }
13627
13628 if (isQuadratic()) {
13629 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13630 return SE.getConstant(*S);
13631 }
13632
13633 return SE.getCouldNotCompute();
13634}
13635
13636const SCEVAddRecExpr *
13637SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13638 assert(getNumOperands() > 1 && "AddRec with zero step?");
13639 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13640 // but in this case we cannot guarantee that the value returned will be an
13641 // AddRec because SCEV does not have a fixed point where it stops
13642 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13643 // may happen if we reach arithmetic depth limit while simplifying. So we
13644 // construct the returned value explicitly.
13645 SmallVector<const SCEV *, 3> Ops;
13646 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13647 // (this + Step) is {A+B,+,B+C,+...,+,N}.
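 // E.g., for {1,+,2,+,3} the step is {2,+,3} and the post-increment value is
 // {1+2,+,2+3,+,3} = {3,+,5,+,3}: it produces 3, 8, 16, ..., i.e. the
 // original sequence 1, 3, 8, 16, ... shifted by one iteration.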
13648 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13649 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13650 // We know that the last operand is not a constant zero (otherwise it would
13651 // have been popped out earlier). This guarantees us that if the result has
13652 // the same last operand, then it will also not be popped out, meaning that
13653 // the returned value will be an AddRec.
13654 const SCEV *Last = getOperand(getNumOperands() - 1);
13655 assert(!Last->isZero() && "Recurrency with zero step?");
13656 Ops.push_back(Last);
13657 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13658 SCEV::FlagAnyWrap));
13659}
13660
13661// Return true when S contains at least an undef value.
13663 return SCEVExprContains(S, [](const SCEV *S) {
13664 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13665 return isa<UndefValue>(SU->getValue());
13666 return false;
13667 });
13668}
13669
13670// Return true when S contains a value that is a nullptr.
13672 return SCEVExprContains(S, [](const SCEV *S) {
13673 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13674 return SU->getValue() == nullptr;
13675 return false;
13676 });
13677}
13678
13679/// Return the size of an element read or written by Inst.
13681 Type *Ty;
13682 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13683 Ty = Store->getValueOperand()->getType();
13684 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13685 Ty = Load->getType();
13686 else
13687 return nullptr;
13688
13690 return getSizeOfExpr(ETy, Ty);
13691}
13692
13693//===----------------------------------------------------------------------===//
13694// SCEVCallbackVH Class Implementation
13695//===----------------------------------------------------------------------===//
13696
13698 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13699 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13700 SE->ConstantEvolutionLoopExitValue.erase(PN);
13701 SE->eraseValueFromMap(getValPtr());
13702 // this now dangles!
13703}
13704
13705void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13706 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13707
13708 // Forget all the expressions associated with users of the old value,
13709 // so that future queries will recompute the expressions using the new
13710 // value.
13711 SE->forgetValue(getValPtr());
13712 // this now dangles!
13713}
13714
13715ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13716 : CallbackVH(V), SE(se) {}
13717
13718//===----------------------------------------------------------------------===//
13719// ScalarEvolution Class Implementation
13720//===----------------------------------------------------------------------===//
13721
13724 LoopInfo &LI)
13725 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13726 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13727 LoopDispositions(64), BlockDispositions(64) {
13728 // To use guards for proving predicates, we need to scan every instruction in
13729 // relevant basic blocks, and not just terminators. Doing this is a waste of
13730 // time if the IR does not actually contain any calls to
13731 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13732 //
13733 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13734 // to _add_ guards to the module when there weren't any before, and wants
13735 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13736 // efficient in lieu of being smart in that rather obscure case.
13737
13738 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13739 F.getParent(), Intrinsic::experimental_guard);
13740 HasGuards = GuardDecl && !GuardDecl->use_empty();
13741}
13742
13744 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13745 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13746 ValueExprMap(std::move(Arg.ValueExprMap)),
13747 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13748 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13749 PendingMerges(std::move(Arg.PendingMerges)),
13750 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13751 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13752 PredicatedBackedgeTakenCounts(
13753 std::move(Arg.PredicatedBackedgeTakenCounts)),
13754 BECountUsers(std::move(Arg.BECountUsers)),
13755 ConstantEvolutionLoopExitValue(
13756 std::move(Arg.ConstantEvolutionLoopExitValue)),
13757 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13758 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13759 LoopDispositions(std::move(Arg.LoopDispositions)),
13760 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13761 BlockDispositions(std::move(Arg.BlockDispositions)),
13762 SCEVUsers(std::move(Arg.SCEVUsers)),
13763 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13764 SignedRanges(std::move(Arg.SignedRanges)),
13765 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13766 UniquePreds(std::move(Arg.UniquePreds)),
13767 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13768 LoopUsers(std::move(Arg.LoopUsers)),
13769 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13770 FirstUnknown(Arg.FirstUnknown) {
13771 Arg.FirstUnknown = nullptr;
13772}
13773
13775 // Iterate through all the SCEVUnknown instances and call their
13776 // destructors, so that they release their references to their values.
13777 for (SCEVUnknown *U = FirstUnknown; U;) {
13778 SCEVUnknown *Tmp = U;
13779 U = U->Next;
13780 Tmp->~SCEVUnknown();
13781 }
13782 FirstUnknown = nullptr;
13783
13784 ExprValueMap.clear();
13785 ValueExprMap.clear();
13786 HasRecMap.clear();
13787 BackedgeTakenCounts.clear();
13788 PredicatedBackedgeTakenCounts.clear();
13789
13790 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13791 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13792 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13793 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13794 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13795}
13796
13800
13801/// When printing a top-level SCEV for trip counts, it's helpful to include
13802/// a type for constants which are otherwise hard to disambiguate.
13803static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13804 if (isa<SCEVConstant>(S))
13805 OS << *S->getType() << " ";
13806 OS << *S;
13807}
13808
13810 const Loop *L) {
13811 // Print all inner loops first
13812 for (Loop *I : *L)
13813 PrintLoopInfo(OS, SE, I);
13814
13815 OS << "Loop ";
13816 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13817 OS << ": ";
13818
13819 SmallVector<BasicBlock *, 8> ExitingBlocks;
13820 L->getExitingBlocks(ExitingBlocks);
13821 if (ExitingBlocks.size() != 1)
13822 OS << "<multiple exits> ";
13823
13824 auto *BTC = SE->getBackedgeTakenCount(L);
13825 if (!isa<SCEVCouldNotCompute>(BTC)) {
13826 OS << "backedge-taken count is ";
13827 PrintSCEVWithTypeHint(OS, BTC);
13828 } else
13829 OS << "Unpredictable backedge-taken count.";
13830 OS << "\n";
13831
13832 if (ExitingBlocks.size() > 1)
13833 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13834 OS << " exit count for " << ExitingBlock->getName() << ": ";
13835 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13836 PrintSCEVWithTypeHint(OS, EC);
13837 if (isa<SCEVCouldNotCompute>(EC)) {
13838 // Retry with predicates.
13840 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13841 if (!isa<SCEVCouldNotCompute>(EC)) {
13842 OS << "\n predicated exit count for " << ExitingBlock->getName()
13843 << ": ";
13844 PrintSCEVWithTypeHint(OS, EC);
13845 OS << "\n Predicates:\n";
13846 for (const auto *P : Predicates)
13847 P->print(OS, 4);
13848 }
13849 }
13850 OS << "\n";
13851 }
13852
13853 OS << "Loop ";
13854 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13855 OS << ": ";
13856
13857 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13858 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13859 OS << "constant max backedge-taken count is ";
13860 PrintSCEVWithTypeHint(OS, ConstantBTC);
13862 OS << ", actual taken count either this or zero.";
13863 } else {
13864 OS << "Unpredictable constant max backedge-taken count. ";
13865 }
13866
13867 OS << "\n"
13868 "Loop ";
13869 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13870 OS << ": ";
13871
13872 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13873 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13874 OS << "symbolic max backedge-taken count is ";
13875 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13877 OS << ", actual taken count either this or zero.";
13878 } else {
13879 OS << "Unpredictable symbolic max backedge-taken count. ";
13880 }
13881 OS << "\n";
13882
13883 if (ExitingBlocks.size() > 1)
13884 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13885 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13886 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13888 PrintSCEVWithTypeHint(OS, ExitBTC);
13889 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13890 // Retry with predicates.
13892 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13894 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13895 OS << "\n predicated symbolic max exit count for "
13896 << ExitingBlock->getName() << ": ";
13897 PrintSCEVWithTypeHint(OS, ExitBTC);
13898 OS << "\n Predicates:\n";
13899 for (const auto *P : Predicates)
13900 P->print(OS, 4);
13901 }
13902 }
13903 OS << "\n";
13904 }
13905
13907 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13908 if (PBT != BTC) {
13909 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13910 OS << "Loop ";
13911 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13912 OS << ": ";
13913 if (!isa<SCEVCouldNotCompute>(PBT)) {
13914 OS << "Predicated backedge-taken count is ";
13915 PrintSCEVWithTypeHint(OS, PBT);
13916 } else
13917 OS << "Unpredictable predicated backedge-taken count.";
13918 OS << "\n";
13919 OS << " Predicates:\n";
13920 for (const auto *P : Preds)
13921 P->print(OS, 4);
13922 }
13923 Preds.clear();
13924
13925 auto *PredConstantMax =
13927 if (PredConstantMax != ConstantBTC) {
13928 assert(!Preds.empty() &&
13929 "different predicated constant max BTC but no predicates");
13930 OS << "Loop ";
13931 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13932 OS << ": ";
13933 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13934 OS << "Predicated constant max backedge-taken count is ";
13935 PrintSCEVWithTypeHint(OS, PredConstantMax);
13936 } else
13937 OS << "Unpredictable predicated constant max backedge-taken count.";
13938 OS << "\n";
13939 OS << " Predicates:\n";
13940 for (const auto *P : Preds)
13941 P->print(OS, 4);
13942 }
13943 Preds.clear();
13944
13945 auto *PredSymbolicMax =
13947 if (SymbolicBTC != PredSymbolicMax) {
13948 assert(!Preds.empty() &&
13949 "Different predicated symbolic max BTC, but no predicates");
13950 OS << "Loop ";
13951 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13952 OS << ": ";
13953 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13954 OS << "Predicated symbolic max backedge-taken count is ";
13955 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13956 } else
13957 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13958 OS << "\n";
13959 OS << " Predicates:\n";
13960 for (const auto *P : Preds)
13961 P->print(OS, 4);
13962 }
13963
13965 OS << "Loop ";
13966 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13967 OS << ": ";
13968 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13969 }
13970}
13971
13972namespace llvm {
13973raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13974 switch (LD) {
13975 case ScalarEvolution::LoopVariant:
13976 OS << "Variant";
13977 break;
13978 case ScalarEvolution::LoopInvariant:
13979 OS << "Invariant";
13980 break;
13981 case ScalarEvolution::LoopComputable:
13982 OS << "Computable";
13983 break;
13984 }
13985 return OS;
13986}
13987
13988raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13989 switch (BD) {
13990 case ScalarEvolution::DoesNotDominateBlock:
13991 OS << "DoesNotDominate";
13992 break;
13993 case ScalarEvolution::DominatesBlock:
13994 OS << "Dominates";
13995 break;
13996 case ScalarEvolution::ProperlyDominatesBlock:
13997 OS << "ProperlyDominates";
13998 break;
13999 }
14000 return OS;
14001}
14002} // namespace llvm
14003
14004void ScalarEvolution::print(raw_ostream &OS) const {
14005 // ScalarEvolution's implementation of the print method is to print
14006 // out SCEV values of all instructions that are interesting. Doing
14007 // this potentially causes it to create new SCEV objects though,
14008 // which technically conflicts with the const qualifier. This isn't
14009 // observable from outside the class though, so casting away the
14010 // const isn't dangerous.
14011 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14012
14013 if (ClassifyExpressions) {
14014 OS << "Classifying expressions for: ";
14015 F.printAsOperand(OS, /*PrintType=*/false);
14016 OS << "\n";
14017 for (Instruction &I : instructions(F))
14018 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14019 OS << I << '\n';
14020 OS << " --> ";
14021 const SCEV *SV = SE.getSCEV(&I);
14022 SV->print(OS);
14023 if (!isa<SCEVCouldNotCompute>(SV)) {
14024 OS << " U: ";
14025 SE.getUnsignedRange(SV).print(OS);
14026 OS << " S: ";
14027 SE.getSignedRange(SV).print(OS);
14028 }
14029
14030 const Loop *L = LI.getLoopFor(I.getParent());
14031
14032 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14033 if (AtUse != SV) {
14034 OS << " --> ";
14035 AtUse->print(OS);
14036 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14037 OS << " U: ";
14038 SE.getUnsignedRange(AtUse).print(OS);
14039 OS << " S: ";
14040 SE.getSignedRange(AtUse).print(OS);
14041 }
14042 }
14043
14044 if (L) {
14045 OS << "\t\t" "Exits: ";
14046 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14047 if (!SE.isLoopInvariant(ExitValue, L)) {
14048 OS << "<<Unknown>>";
14049 } else {
14050 OS << *ExitValue;
14051 }
14052
14053 bool First = true;
14054 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14055 if (First) {
14056 OS << "\t\t" "LoopDispositions: { ";
14057 First = false;
14058 } else {
14059 OS << ", ";
14060 }
14061
14062 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14063 OS << ": " << SE.getLoopDisposition(SV, Iter);
14064 }
14065
14066 for (const auto *InnerL : depth_first(L)) {
14067 if (InnerL == L)
14068 continue;
14069 if (First) {
14070 OS << "\t\t" "LoopDispositions: { ";
14071 First = false;
14072 } else {
14073 OS << ", ";
14074 }
14075
14076 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14077 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14078 }
14079
14080 OS << " }";
14081 }
14082
14083 OS << "\n";
14084 }
14085 }
14086
14087 OS << "Determining loop execution counts for: ";
14088 F.printAsOperand(OS, /*PrintType=*/false);
14089 OS << "\n";
14090 for (Loop *I : LI)
14091 PrintLoopInfo(OS, &SE, I);
14092}
14093
14094ScalarEvolution::LoopDisposition
14095ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14096 auto &Values = LoopDispositions[S];
14097 for (auto &V : Values) {
14098 if (V.getPointer() == L)
14099 return V.getInt();
14100 }
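  // Note: a provisional LoopVariant entry is inserted before calling
  // computeLoopDisposition so that any re-entrant query on the same (S, L)
  // pair terminates; the entry is then updated with the computed result.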
14101 Values.emplace_back(L, LoopVariant);
14102 LoopDisposition D = computeLoopDisposition(S, L);
14103 auto &Values2 = LoopDispositions[S];
14104 for (auto &V : llvm::reverse(Values2)) {
14105 if (V.getPointer() == L) {
14106 V.setInt(D);
14107 break;
14108 }
14109 }
14110 return D;
14111}
14112
14113ScalarEvolution::LoopDisposition
14114ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14115 switch (S->getSCEVType()) {
14116 case scConstant:
14117 case scVScale:
14118 return LoopInvariant;
14119 case scAddRecExpr: {
14120 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14121
14122 // If L is the addrec's loop, it's computable.
14123 if (AR->getLoop() == L)
14124 return LoopComputable;
14125
14126 // Add recurrences are never invariant in the function-body (null loop).
14127 if (!L)
14128 return LoopVariant;
14129
14130 // Everything that is not defined at loop entry is variant.
14131 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14132 return LoopVariant;
14133 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14134 " dominate the contained loop's header?");
14135
14136 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14137 if (AR->getLoop()->contains(L))
14138 return LoopInvariant;
14139
14140 // This recurrence is variant w.r.t. L if any of its operands
14141 // are variant.
14142 for (const auto *Op : AR->operands())
14143 if (!isLoopInvariant(Op, L))
14144 return LoopVariant;
14145
14146 // Otherwise it's loop-invariant.
14147 return LoopInvariant;
14148 }
14149 case scTruncate:
14150 case scZeroExtend:
14151 case scSignExtend:
14152 case scPtrToInt:
14153 case scAddExpr:
14154 case scMulExpr:
14155 case scUDivExpr:
14156 case scUMaxExpr:
14157 case scSMaxExpr:
14158 case scUMinExpr:
14159 case scSMinExpr:
14160 case scSequentialUMinExpr: {
14161 bool HasVarying = false;
14162 for (const auto *Op : S->operands()) {
14163 LoopDisposition D = getLoopDisposition(Op, L);
14164 if (D == LoopVariant)
14165 return LoopVariant;
14166 if (D == LoopComputable)
14167 HasVarying = true;
14168 }
14169 return HasVarying ? LoopComputable : LoopInvariant;
14170 }
14171 case scUnknown:
14172 // All non-instruction values are loop invariant. All instructions are loop
14173 // invariant if they are not contained in the specified loop.
14174 // Instructions are never considered invariant in the function body
14175 // (null loop) because they are defined within the "loop".
14176 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14177 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14178 return LoopInvariant;
14179 case scCouldNotCompute:
14180 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14181 }
14182 llvm_unreachable("Unknown SCEV kind!");
14183}
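// For example (illustrative): with nested loops %outer and %inner, the SCEV
// {0,+,1}<%inner> is LoopComputable w.r.t. %inner, LoopVariant w.r.t. %outer
// (it takes a fresh sequence of values each time %inner is re-entered), and
// LoopInvariant w.r.t. any loop nested inside %inner. A SCEVUnknown for an
// instruction defined outside %outer is LoopInvariant w.r.t. both loops.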
14184
14185bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14186 return getLoopDisposition(S, L) == LoopInvariant;
14187}
14188
14189bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14190 return getLoopDisposition(S, L) == LoopComputable;
14191}
14192
14193ScalarEvolution::BlockDisposition
14194ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14195 auto &Values = BlockDispositions[S];
14196 for (auto &V : Values) {
14197 if (V.getPointer() == BB)
14198 return V.getInt();
14199 }
14200 Values.emplace_back(BB, DoesNotDominateBlock);
14201 BlockDisposition D = computeBlockDisposition(S, BB);
14202 auto &Values2 = BlockDispositions[S];
14203 for (auto &V : llvm::reverse(Values2)) {
14204 if (V.getPointer() == BB) {
14205 V.setInt(D);
14206 break;
14207 }
14208 }
14209 return D;
14210}
14211
14212ScalarEvolution::BlockDisposition
14213ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14214 switch (S->getSCEVType()) {
14215 case scConstant:
14216 case scVScale:
14217 return ProperlyDominatesBlock;
14218 case scAddRecExpr: {
14219 // This uses a "dominates" query instead of "properly dominates" query
14220 // to test for proper dominance too, because the instruction which
14221 // produces the addrec's value is a PHI, and a PHI effectively properly
14222 // dominates its entire containing block.
14223 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14224 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14225 return DoesNotDominateBlock;
14226
14227 // Fall through into SCEVNAryExpr handling.
14228 [[fallthrough]];
14229 }
14230 case scTruncate:
14231 case scZeroExtend:
14232 case scSignExtend:
14233 case scPtrToInt:
14234 case scAddExpr:
14235 case scMulExpr:
14236 case scUDivExpr:
14237 case scUMaxExpr:
14238 case scSMaxExpr:
14239 case scUMinExpr:
14240 case scSMinExpr:
14241 case scSequentialUMinExpr: {
14242 bool Proper = true;
14243 for (const SCEV *NAryOp : S->operands()) {
14244 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14245 if (D == DoesNotDominateBlock)
14246 return DoesNotDominateBlock;
14247 if (D == DominatesBlock)
14248 Proper = false;
14249 }
14250 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14251 }
14252 case scUnknown:
14253 if (Instruction *I =
14254 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14255 if (I->getParent() == BB)
14256 return DominatesBlock;
14257 if (DT.properlyDominates(I->getParent(), BB))
14258 return ProperlyDominatesBlock;
14259 return DoesNotDominateBlock;
14260 }
14261 return ProperlyDominatesBlock;
14262 case scCouldNotCompute:
14263 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14264 }
14265 llvm_unreachable("Unknown SCEV kind!");
14266}
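// For example (illustrative): if %x is defined in block %a and %y in block %b,
// where %a properly dominates %b, then the SCEV (%x + %y) is DominatesBlock
// for %b itself, ProperlyDominatesBlock for every block properly dominated by
// %b, and DoesNotDominateBlock for any block not dominated by %b.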
14267
14268bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14269 return getBlockDisposition(S, BB) >= DominatesBlock;
14270}
14271
14272bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14273 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14274}
14275
14276bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14277 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14278}
14279
14280void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14281 bool Predicated) {
14282 auto &BECounts =
14283 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14284 auto It = BECounts.find(L);
14285 if (It != BECounts.end()) {
14286 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14287 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14288 if (!isa<SCEVConstant>(S)) {
14289 auto UserIt = BECountUsers.find(S);
14290 assert(UserIt != BECountUsers.end());
14291 UserIt->second.erase({L, Predicated});
14292 }
14293 }
14294 }
14295 BECounts.erase(It);
14296 }
14297}
14298
14299void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14300 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14301 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14302
14303 while (!Worklist.empty()) {
14304 const SCEV *Curr = Worklist.pop_back_val();
14305 auto Users = SCEVUsers.find(Curr);
14306 if (Users != SCEVUsers.end())
14307 for (const auto *User : Users->second)
14308 if (ToForget.insert(User).second)
14309 Worklist.push_back(User);
14310 }
14311
14312 for (const auto *S : ToForget)
14313 forgetMemoizedResultsImpl(S);
14314
14315 for (auto I = PredicatedSCEVRewrites.begin();
14316 I != PredicatedSCEVRewrites.end();) {
14317 std::pair<const SCEV *, const Loop *> Entry = I->first;
14318 if (ToForget.count(Entry.first))
14319 PredicatedSCEVRewrites.erase(I++);
14320 else
14321 ++I;
14322 }
14323}
14324
14325void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14326 LoopDispositions.erase(S);
14327 BlockDispositions.erase(S);
14328 UnsignedRanges.erase(S);
14329 SignedRanges.erase(S);
14330 HasRecMap.erase(S);
14331 ConstantMultipleCache.erase(S);
14332
14333 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14334 UnsignedWrapViaInductionTried.erase(AR);
14335 SignedWrapViaInductionTried.erase(AR);
14336 }
14337
14338 auto ExprIt = ExprValueMap.find(S);
14339 if (ExprIt != ExprValueMap.end()) {
14340 for (Value *V : ExprIt->second) {
14341 auto ValueIt = ValueExprMap.find_as(V);
14342 if (ValueIt != ValueExprMap.end())
14343 ValueExprMap.erase(ValueIt);
14344 }
14345 ExprValueMap.erase(ExprIt);
14346 }
14347
14348 auto ScopeIt = ValuesAtScopes.find(S);
14349 if (ScopeIt != ValuesAtScopes.end()) {
14350 for (const auto &Pair : ScopeIt->second)
14351 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14352 llvm::erase(ValuesAtScopesUsers[Pair.second],
14353 std::make_pair(Pair.first, S));
14354 ValuesAtScopes.erase(ScopeIt);
14355 }
14356
14357 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14358 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14359 for (const auto &Pair : ScopeUserIt->second)
14360 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14361 ValuesAtScopesUsers.erase(ScopeUserIt);
14362 }
14363
14364 auto BEUsersIt = BECountUsers.find(S);
14365 if (BEUsersIt != BECountUsers.end()) {
14366 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14367 auto Copy = BEUsersIt->second;
14368 for (const auto &Pair : Copy)
14369 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14370 BECountUsers.erase(BEUsersIt);
14371 }
14372
14373 auto FoldUser = FoldCacheUser.find(S);
14374 if (FoldUser != FoldCacheUser.end())
14375 for (auto &KV : FoldUser->second)
14376 FoldCache.erase(KV);
14377 FoldCacheUser.erase(S);
14378}
14379
14380void
14381ScalarEvolution::getUsedLoops(const SCEV *S,
14382 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14383 struct FindUsedLoops {
14384 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14385 : LoopsUsed(LoopsUsed) {}
14386 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14387 bool follow(const SCEV *S) {
14388 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14389 LoopsUsed.insert(AR->getLoop());
14390 return true;
14391 }
14392
14393 bool isDone() const { return false; }
14394 };
14395
14396 FindUsedLoops F(LoopsUsed);
14397 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14398}
14399
14400void ScalarEvolution::getReachableBlocks(
14401 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14402 SmallVector<BasicBlock *> Worklist;
14403 Worklist.push_back(&F.getEntryBlock());
14404 while (!Worklist.empty()) {
14405 BasicBlock *BB = Worklist.pop_back_val();
14406 if (!Reachable.insert(BB).second)
14407 continue;
14408
14409 Value *Cond;
14410 BasicBlock *TrueBB, *FalseBB;
14411 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14412 m_BasicBlock(FalseBB)))) {
14413 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14414 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14415 continue;
14416 }
14417
14418 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14419 const SCEV *L = getSCEV(Cmp->getOperand(0));
14420 const SCEV *R = getSCEV(Cmp->getOperand(1));
14421 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14422 Worklist.push_back(TrueBB);
14423 continue;
14424 }
14425 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14426 R)) {
14427 Worklist.push_back(FalseBB);
14428 continue;
14429 }
14430 }
14431 }
14432
14433 append_range(Worklist, successors(BB));
14434 }
14435}
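// For example (illustrative): a conditional branch on "icmp ult i32 %i, 0"
// only enqueues the false successor, since the inverse predicate is provable
// from constant ranges; a branch on the literal "i1 true" only enqueues the
// true successor. All other terminators enqueue every successor.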
14436
14437void ScalarEvolution::verify() const {
14438 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14439 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14440
14441 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14442
14443 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14444 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14445 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14446
14447 const SCEV *visitConstant(const SCEVConstant *Constant) {
14448 return SE.getConstant(Constant->getAPInt());
14449 }
14450
14451 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14452 return SE.getUnknown(Expr->getValue());
14453 }
14454
14455 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14456 return SE.getCouldNotCompute();
14457 }
14458 };
14459
14460 SCEVMapper SCM(SE2);
14461 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14462 SE2.getReachableBlocks(ReachableBlocks, F);
14463
14464 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14465 if (containsUndefs(Old) || containsUndefs(New)) {
14466 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14467 // not propagate undef aggressively). This means we can (and do) fail
14468 // verification in cases where a transform makes a value go from "undef"
14469 // to "undef+1" (say). The transform is fine, since in both cases the
14470 // result is "undef", but SCEV thinks the value increased by 1.
14471 return nullptr;
14472 }
14473
14474 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14475 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14476 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14477 return nullptr;
14478
14479 return Delta;
14480 };
14481
14482 while (!LoopStack.empty()) {
14483 auto *L = LoopStack.pop_back_val();
14484 llvm::append_range(LoopStack, *L);
14485
14486 // Only verify BECounts in reachable loops. For an unreachable loop,
14487 // any BECount is legal.
14488 if (!ReachableBlocks.contains(L->getHeader()))
14489 continue;
14490
14491 // Only verify cached BECounts. Computing new BECounts may change the
14492 // results of subsequent SCEV uses.
14493 auto It = BackedgeTakenCounts.find(L);
14494 if (It == BackedgeTakenCounts.end())
14495 continue;
14496
14497 auto *CurBECount =
14498 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14499 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14500
14501 if (CurBECount == SE2.getCouldNotCompute() ||
14502 NewBECount == SE2.getCouldNotCompute()) {
14503 // NB! This situation is legal, but is very suspicious -- whatever pass
14504 // changed the loop to make a trip count go from could-not-compute to
14505 // computable or vice-versa *should have* invalidated SCEV. However, we
14506 // choose not to assert here (for now) since we don't want false
14507 // positives.
14508 continue;
14509 }
14510
14511 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14512 SE.getTypeSizeInBits(NewBECount->getType()))
14513 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14514 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14515 SE.getTypeSizeInBits(NewBECount->getType()))
14516 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14517
14518 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14519 if (Delta && !Delta->isZero()) {
14520 dbgs() << "Trip Count for " << *L << " Changed!\n";
14521 dbgs() << "Old: " << *CurBECount << "\n";
14522 dbgs() << "New: " << *NewBECount << "\n";
14523 dbgs() << "Delta: " << *Delta << "\n";
14524 std::abort();
14525 }
14526 }
14527
14528 // Collect all valid loops currently in LoopInfo.
14529 SmallPtrSet<Loop *, 32> ValidLoops;
14530 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14531 while (!Worklist.empty()) {
14532 Loop *L = Worklist.pop_back_val();
14533 if (ValidLoops.insert(L).second)
14534 Worklist.append(L->begin(), L->end());
14535 }
14536 for (const auto &KV : ValueExprMap) {
14537#ifndef NDEBUG
14538 // Check for SCEV expressions referencing invalid/deleted loops.
14539 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14540 assert(ValidLoops.contains(AR->getLoop()) &&
14541 "AddRec references invalid loop");
14542 }
14543#endif
14544
14545 // Check that the value is also part of the reverse map.
14546 auto It = ExprValueMap.find(KV.second);
14547 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14548 dbgs() << "Value " << *KV.first
14549 << " is in ValueExprMap but not in ExprValueMap\n";
14550 std::abort();
14551 }
14552
14553 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14554 if (!ReachableBlocks.contains(I->getParent()))
14555 continue;
14556 const SCEV *OldSCEV = SCM.visit(KV.second);
14557 const SCEV *NewSCEV = SE2.getSCEV(I);
14558 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14559 if (Delta && !Delta->isZero()) {
14560 dbgs() << "SCEV for value " << *I << " changed!\n"
14561 << "Old: " << *OldSCEV << "\n"
14562 << "New: " << *NewSCEV << "\n"
14563 << "Delta: " << *Delta << "\n";
14564 std::abort();
14565 }
14566 }
14567 }
14568
14569 for (const auto &KV : ExprValueMap) {
14570 for (Value *V : KV.second) {
14571 const SCEV *S = ValueExprMap.lookup(V);
14572 if (!S) {
14573 dbgs() << "Value " << *V
14574 << " is in ExprValueMap but not in ValueExprMap\n";
14575 std::abort();
14576 }
14577 if (S != KV.first) {
14578 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14579 << *KV.first << "\n";
14580 std::abort();
14581 }
14582 }
14583 }
14584
14585 // Verify integrity of SCEV users.
14586 for (const auto &S : UniqueSCEVs) {
14587 for (const auto *Op : S.operands()) {
14588 // We do not store dependencies of constants.
14589 if (isa<SCEVConstant>(Op))
14590 continue;
14591 auto It = SCEVUsers.find(Op);
14592 if (It != SCEVUsers.end() && It->second.count(&S))
14593 continue;
14594 dbgs() << "Use of operand " << *Op << " by user " << S
14595 << " is not being tracked!\n";
14596 std::abort();
14597 }
14598 }
14599
14600 // Verify integrity of ValuesAtScopes users.
14601 for (const auto &ValueAndVec : ValuesAtScopes) {
14602 const SCEV *Value = ValueAndVec.first;
14603 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14604 const Loop *L = LoopAndValueAtScope.first;
14605 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14606 if (!isa<SCEVConstant>(ValueAtScope)) {
14607 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14608 if (It != ValuesAtScopesUsers.end() &&
14609 is_contained(It->second, std::make_pair(L, Value)))
14610 continue;
14611 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14612 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14613 std::abort();
14614 }
14615 }
14616 }
14617
14618 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14619 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14620 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14621 const Loop *L = LoopAndValue.first;
14622 const SCEV *Value = LoopAndValue.second;
14624 auto It = ValuesAtScopes.find(Value);
14625 if (It != ValuesAtScopes.end() &&
14626 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14627 continue;
14628 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14629 << *ValueAtScope << " missing in ValuesAtScopes\n";
14630 std::abort();
14631 }
14632 }
14633
14634 // Verify integrity of BECountUsers.
14635 auto VerifyBECountUsers = [&](bool Predicated) {
14636 auto &BECounts =
14637 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14638 for (const auto &LoopAndBEInfo : BECounts) {
14639 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14640 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14641 if (!isa<SCEVConstant>(S)) {
14642 auto UserIt = BECountUsers.find(S);
14643 if (UserIt != BECountUsers.end() &&
14644 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14645 continue;
14646 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14647 << " missing from BECountUsers\n";
14648 std::abort();
14649 }
14650 }
14651 }
14652 }
14653 };
14654 VerifyBECountUsers(/* Predicated */ false);
14655 VerifyBECountUsers(/* Predicated */ true);
14656
14657 // Verify integrity of the loop disposition cache.
14658 for (auto &[S, Values] : LoopDispositions) {
14659 for (auto [Loop, CachedDisposition] : Values) {
14660 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14661 if (CachedDisposition != RecomputedDisposition) {
14662 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14663 << " is incorrect: cached " << CachedDisposition << ", actual "
14664 << RecomputedDisposition << "\n";
14665 std::abort();
14666 }
14667 }
14668 }
14669
14670 // Verify integrity of the block disposition cache.
14671 for (auto &[S, Values] : BlockDispositions) {
14672 for (auto [BB, CachedDisposition] : Values) {
14673 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14674 if (CachedDisposition != RecomputedDisposition) {
14675 dbgs() << "Cached disposition of " << *S << " for block %"
14676 << BB->getName() << " is incorrect: cached " << CachedDisposition
14677 << ", actual " << RecomputedDisposition << "\n";
14678 std::abort();
14679 }
14680 }
14681 }
14682
14683 // Verify FoldCache/FoldCacheUser caches.
14684 for (auto [FoldID, Expr] : FoldCache) {
14685 auto I = FoldCacheUser.find(Expr);
14686 if (I == FoldCacheUser.end()) {
14687 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14688 << "!\n";
14689 std::abort();
14690 }
14691 if (!is_contained(I->second, FoldID)) {
14692 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14693 std::abort();
14694 }
14695 }
14696 for (auto [Expr, IDs] : FoldCacheUser) {
14697 for (auto &FoldID : IDs) {
14698 const SCEV *S = FoldCache.lookup(FoldID);
14699 if (!S) {
14700 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14701 << "!\n";
14702 std::abort();
14703 }
14704 if (S != Expr) {
14705 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14706 << " != " << *Expr << "!\n";
14707 std::abort();
14708 }
14709 }
14710 }
14711
14712 // Verify that ConstantMultipleCache computations are correct. We check that
14713 // cached multiples and recomputed multiples are multiples of each other to
14714 // verify correctness. It is possible that a recomputed multiple is different
14715 // from the cached multiple due to strengthened no wrap flags or changes in
14716 // KnownBits computations.
14717 for (auto [S, Multiple] : ConstantMultipleCache) {
14718 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14719 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14720 Multiple.urem(RecomputedMultiple) != 0 &&
14721 RecomputedMultiple.urem(Multiple) != 0)) {
14722 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14723 << *S << " : Computed " << RecomputedMultiple
14724 << " but cache contains " << Multiple << "!\n";
14725 std::abort();
14726 }
14727 }
14728}
14729
14730bool ScalarEvolution::invalidate(
14731 Function &F, const PreservedAnalyses &PA,
14732 FunctionAnalysisManager::Invalidator &Inv) {
14733 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14734 // of its dependencies is invalidated.
14735 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14736 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14737 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14738 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14739 Inv.invalidate<LoopAnalysis>(F, PA);
14740}
14741
14742AnalysisKey ScalarEvolutionAnalysis::Key;
14743
14744ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14745 FunctionAnalysisManager &AM) {
14746 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14747 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14748 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14749 auto &LI = AM.getResult<LoopAnalysis>(F);
14750 return ScalarEvolution(F, TLI, AC, DT, LI);
14751}
14752
14753PreservedAnalyses
14754ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14755 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14756 return PreservedAnalyses::all();
14757}
14758
14759PreservedAnalyses
14760ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14761 // For compatibility with opt's -analyze feature under legacy pass manager
14762 // which was not ported to NPM. This keeps tests using
14763 // update_analyze_test_checks.py working.
14764 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14765 << F.getName() << "':\n";
14766 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14767 return PreservedAnalyses::all();
14768}
14769
14771 "Scalar Evolution Analysis", false, true)
14777 "Scalar Evolution Analysis", false, true)
14778
14780
14782
14783bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14784 SE.reset(new ScalarEvolution(
14785 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14786 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14787 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14788 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14789 return false;
14790}
14791
14793
14794void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14795 SE->print(OS);
14796}
14797
14798void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14799 if (!VerifySCEV)
14800 return;
14801
14802 SE->verify();
14803}
14804
14812
14813const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14814 const SCEV *RHS) {
14815 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14816}
14817
14818const SCEVPredicate *
14819ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14820 const SCEV *LHS, const SCEV *RHS) {
14821 FoldingSetNodeID ID;
14822 assert(LHS->getType() == RHS->getType() &&
14823 "Type mismatch between LHS and RHS");
14824 // Unique this node based on the arguments
14825 ID.AddInteger(SCEVPredicate::P_Compare);
14826 ID.AddInteger(Pred);
14827 ID.AddPointer(LHS);
14828 ID.AddPointer(RHS);
14829 void *IP = nullptr;
14830 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14831 return S;
14832 SCEVComparePredicate *Eq = new (SCEVAllocator)
14833 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14834 UniquePreds.InsertNode(Eq, IP);
14835 return Eq;
14836}
14837
14838const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14839 const SCEVAddRecExpr *AR,
14840 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14841 FoldingSetNodeID ID;
14842 // Unique this node based on the arguments
14843 ID.AddInteger(SCEVPredicate::P_Wrap);
14844 ID.AddPointer(AR);
14845 ID.AddInteger(AddedFlags);
14846 void *IP = nullptr;
14847 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14848 return S;
14849 auto *OF = new (SCEVAllocator)
14850 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14851 UniquePreds.InsertNode(OF, IP);
14852 return OF;
14853}
14854
14855namespace {
14856
14857class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14858public:
14859
14860 /// Rewrites \p S in the context of a loop L and the SCEV predication
14861 /// infrastructure.
14862 ///
14863 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14864 /// equivalences present in \p Pred.
14865 ///
14866 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14867 /// \p NewPreds such that the result will be an AddRecExpr.
14868 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14869 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14870 const SCEVPredicate *Pred) {
14871 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14872 return Rewriter.visit(S);
14873 }
14874
14875 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14876 if (Pred) {
14877 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14878 for (const auto *Pred : U->getPredicates())
14879 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14880 if (IPred->getLHS() == Expr &&
14881 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14882 return IPred->getRHS();
14883 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14884 if (IPred->getLHS() == Expr &&
14885 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14886 return IPred->getRHS();
14887 }
14888 }
14889 return convertToAddRecWithPreds(Expr);
14890 }
14891
14892 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14893 const SCEV *Operand = visit(Expr->getOperand());
14894 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14895 if (AR && AR->getLoop() == L && AR->isAffine()) {
14896 // This couldn't be folded because the operand didn't have the nuw
14897 // flag. Add the nusw flag as an assumption that we could make.
14898 const SCEV *Step = AR->getStepRecurrence(SE);
14899 Type *Ty = Expr->getType();
14900 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14901 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14902 SE.getSignExtendExpr(Step, Ty), L,
14903 AR->getNoWrapFlags());
14904 }
14905 return SE.getZeroExtendExpr(Operand, Expr->getType());
14906 }
14907
14908 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14909 const SCEV *Operand = visit(Expr->getOperand());
14910 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14911 if (AR && AR->getLoop() == L && AR->isAffine()) {
14912 // This couldn't be folded because the operand didn't have the nsw
14913 // flag. Add the nssw flag as an assumption that we could make.
14914 const SCEV *Step = AR->getStepRecurrence(SE);
14915 Type *Ty = Expr->getType();
14916 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14917 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14918 SE.getSignExtendExpr(Step, Ty), L,
14919 AR->getNoWrapFlags());
14920 }
14921 return SE.getSignExtendExpr(Operand, Expr->getType());
14922 }
14923
14924private:
14925 explicit SCEVPredicateRewriter(
14926 const Loop *L, ScalarEvolution &SE,
14927 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14928 const SCEVPredicate *Pred)
14929 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14930
14931 bool addOverflowAssumption(const SCEVPredicate *P) {
14932 if (!NewPreds) {
14933 // Check if we've already made this assumption.
14934 return Pred && Pred->implies(P, SE);
14935 }
14936 NewPreds->push_back(P);
14937 return true;
14938 }
14939
14940 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14941 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14942 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14943 return addOverflowAssumption(A);
14944 }
14945
14946 // If \p Expr represents a PHINode, we try to see if it can be represented
14947 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14948 // to add this predicate as a runtime overflow check, we return the AddRec.
14949 // If \p Expr does not meet these conditions (is not a PHI node, or we
14950 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14951 // return \p Expr.
14952 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14953 if (!isa<PHINode>(Expr->getValue()))
14954 return Expr;
14955 std::optional<
14956 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14957 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14958 if (!PredicatedRewrite)
14959 return Expr;
14960 for (const auto *P : PredicatedRewrite->second){
14961 // Wrap predicates from outer loops are not supported.
14962 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14963 if (L != WP->getExpr()->getLoop())
14964 return Expr;
14965 }
14966 if (!addOverflowAssumption(P))
14967 return Expr;
14968 }
14969 return PredicatedRewrite->first;
14970 }
14971
14972 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14973 const SCEVPredicate *Pred;
14974 const Loop *L;
14975};
14976
14977} // end anonymous namespace
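// Illustrative use of SCEVPredicateRewriter: for a 32-bit induction PHI %iv
// whose recurrence {%start,+,1}<%loop> has no provable <nusw>, rewriting
// (zext i32 %iv to i64) with a non-null NewPreds vector yields the i64 AddRec
// {zext(%start),+,1}<%loop> and records the corresponding <nusw>
// SCEVWrapPredicate in NewPreds as a runtime assumption.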
14978
14979const SCEV *
14980ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14981 const SCEVPredicate &Preds) {
14982 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14983}
14984
14985const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14986 const SCEV *S, const Loop *L,
14987 SmallVectorImpl<const SCEVPredicate *> &Preds) {
14988 SmallVector<const SCEVPredicate *, 4> TransformPreds;
14989 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14990 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14991
14992 if (!AddRec)
14993 return nullptr;
14994
14995 // Check if any of the transformed predicates is known to be false. In that
14996 // case, it doesn't make sense to convert to a predicated AddRec, as the
14997 // versioned loop will never execute.
14998 for (const SCEVPredicate *Pred : TransformPreds) {
14999 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15000 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15001 continue;
15002
15003 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15004 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15005 if (isa<SCEVCouldNotCompute>(ExitCount))
15006 continue;
15007
15008 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15009 if (!Step->isOne())
15010 continue;
15011
15012 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15013 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15014 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15015 return nullptr;
15016 }
15017
15018 // Since the transformation was successful, we can now transfer the SCEV
15019 // predicates.
15020 Preds.append(TransformPreds.begin(), TransformPreds.end());
15021
15022 return AddRec;
15023}
15024
15025/// SCEV predicates
15029
15030SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15031 const ICmpInst::Predicate Pred,
15032 const SCEV *LHS, const SCEV *RHS)
15033 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15034 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15035 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15036}
15037
15038bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15039 ScalarEvolution &SE) const {
15040 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15041
15042 if (!Op)
15043 return false;
15044
15045 if (Pred != ICmpInst::ICMP_EQ)
15046 return false;
15047
15048 return Op->LHS == LHS && Op->RHS == RHS;
15049}
15050
15051bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15052
15053void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15054 if (Pred == ICmpInst::ICMP_EQ)
15055 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15056 else
15057 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15058 << *RHS << "\n";
15059
15060}
15061
15062SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15063 const SCEVAddRecExpr *AR,
15064 IncrementWrapFlags Flags)
15065 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15066
15067const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15068
15069bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15070 ScalarEvolution &SE) const {
15071 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15072 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15073 return false;
15074
15075 if (Op->AR == AR)
15076 return true;
15077
15078 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15079 Flags != SCEVWrapPredicate::IncrementNUSW)
15080 return false;
15081
15082 const SCEV *Start = AR->getStart();
15083 const SCEV *OpStart = Op->AR->getStart();
15084 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15085 return false;
15086
15087 // Reject pointers to different address spaces.
15088 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15089 return false;
15090
15091 const SCEV *Step = AR->getStepRecurrence(SE);
15092 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15093 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15094 return false;
15095
15096 // If both steps are positive, this implies N if N's start and step are
15097 // ULE/SLE (for NUSW/NSSW respectively) compared to this predicate's.
15098 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15099 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15100 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15101
15102 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15103 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15104 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15105 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15106 : SE.getNoopOrSignExtend(Start, WiderTy);
15107 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15108 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15109 SE.isKnownPredicate(Pred, OpStart, Start);
15110}
15111
15112bool SCEVWrapPredicate::isAlwaysTrue() const {
15113 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15114 IncrementWrapFlags IFlags = Flags;
15115
15116 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15117 IFlags = clearFlags(IFlags, IncrementNSSW);
15118
15119 return IFlags == IncrementAnyWrap;
15120}
15121
15122void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15123 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15125 OS << "<nusw>";
15127 OS << "<nssw>";
15128 OS << "\n";
15129}
15130
15131SCEVWrapPredicate::IncrementWrapFlags
15132SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15133 ScalarEvolution &SE) {
15134 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15135 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15136
15137 // We can safely transfer the NSW flag as NSSW.
15138 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15139 ImpliedFlags = IncrementNSSW;
15140
15141 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15142 // If the increment is positive, the SCEV NUW flag will also imply the
15143 // WrapPredicate NUSW flag.
15144 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15145 if (Step->getValue()->getValue().isNonNegative())
15146 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15147 }
15148
15149 return ImpliedFlags;
15150}
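// For example (illustrative): an AddRec already carrying <nsw> implies
// IncrementNSSW, and one carrying <nuw> with a constant non-negative step also
// implies IncrementNUSW, so such flags never need a runtime wrap check.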
15151
15152/// Union predicates don't get cached so create a dummy set ID for it.
15153SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15154 ScalarEvolution &SE)
15155 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15156 for (const auto *P : Preds)
15157 add(P, SE);
15158}
15159
15160bool SCEVUnionPredicate::isAlwaysTrue() const {
15161 return all_of(Preds,
15162 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15163}
15164
15165bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15166 ScalarEvolution &SE) const {
15167 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15168 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15169 return this->implies(I, SE);
15170 });
15171
15172 return any_of(Preds,
15173 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15174}
15175
15176void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15177 for (const auto *Pred : Preds)
15178 Pred->print(OS, Depth);
15179}
15180
15181void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15182 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15183 for (const auto *Pred : Set->Preds)
15184 add(Pred, SE);
15185 return;
15186 }
15187
15188 // Implication checks are quadratic in the number of predicates. Stop doing
15189 // them if there are many predicates, as they would be too expensive to be
15190 // worthwhile at that point.
15191 bool CheckImplies = Preds.size() < 16;
15192
15193 // Only add predicate if it is not already implied by this union predicate.
15194 if (CheckImplies && implies(N, SE))
15195 return;
15196
15197 // Build a new vector containing the current predicates, except the ones that
15198 // are implied by the new predicate N.
15199 SmallVector<const SCEVPredicate *> PrunedPreds;
15200 for (auto *P : Preds) {
15201 if (CheckImplies && N->implies(P, SE))
15202 continue;
15203 PrunedPreds.push_back(P);
15204 }
15205 Preds = std::move(PrunedPreds);
15206 Preds.push_back(N);
15207}
15208
15209PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15210 Loop &L)
15211 : SE(SE), L(L) {
15212 SmallVector<const SCEVPredicate *, 4> Empty;
15213 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15214}
15215
15216void ScalarEvolution::registerUser(const SCEV *User,
15217 ArrayRef<const SCEV *> Ops) {
15218 for (const auto *Op : Ops)
15219 // We do not expect that forgetting cached data for SCEVConstants will ever
15220 // open any prospects for sharpening or introduce any correctness issues,
15221 // so we don't bother storing their dependencies.
15222 if (!isa<SCEVConstant>(Op))
15223 SCEVUsers[Op].insert(User);
15224}
15225
15226const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15227 const SCEV *Expr = SE.getSCEV(V);
15228 RewriteEntry &Entry = RewriteMap[Expr];
15229
15230 // If we already have an entry and the version matches, return it.
15231 if (Entry.second && Generation == Entry.first)
15232 return Entry.second;
15233
15234 // We found an entry but it's stale. Rewrite the stale entry
15235 // according to the current predicate.
15236 if (Entry.second)
15237 Expr = Entry.second;
15238
15239 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15240 Entry = {Generation, NewSCEV};
15241
15242 return NewSCEV;
15243}
15244
15245const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15246 if (!BackedgeCount) {
15247 SmallVector<const SCEVPredicate *, 4> Preds;
15248 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15249 for (const auto *P : Preds)
15250 addPredicate(*P);
15251 }
15252 return BackedgeCount;
15253}
15254
15255const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15256 if (!SymbolicMaxBackedgeCount) {
15257 SmallVector<const SCEVPredicate *, 4> Preds;
15258 SymbolicMaxBackedgeCount =
15259 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15260 for (const auto *P : Preds)
15261 addPredicate(*P);
15262 }
15263 return SymbolicMaxBackedgeCount;
15264}
15265
15266unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15267 if (!SmallConstantMaxTripCount) {
15268 SmallVector<const SCEVPredicate *, 4> Preds;
15269 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15270 for (const auto *P : Preds)
15271 addPredicate(*P);
15272 }
15273 return *SmallConstantMaxTripCount;
15274}
15275
15276void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15277 if (Preds->implies(&Pred, SE))
15278 return;
15279
15280 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15281 NewPreds.push_back(&Pred);
15282 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15283 updateGeneration();
15284}
15285
15286const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15287 return *Preds;
15288}
15289
15290void PredicatedScalarEvolution::updateGeneration() {
15291 // If the generation number wrapped recompute everything.
15292 if (++Generation == 0) {
15293 for (auto &II : RewriteMap) {
15294 const SCEV *Rewritten = II.second.second;
15295 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15296 }
15297 }
15298}
15299
15300void PredicatedScalarEvolution::setNoOverflow(
15301 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15302 const SCEV *Expr = getSCEV(V);
15303 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15304
15305 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15306
15307 // Clear the statically implied flags.
15308 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15309 addPredicate(*SE.getWrapPredicate(AR, Flags));
15310
15311 auto II = FlagsMap.insert({V, Flags});
15312 if (!II.second)
15313 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15314}
15315
15316bool PredicatedScalarEvolution::hasNoOverflow(
15317 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15318 const SCEV *Expr = getSCEV(V);
15319 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15320
15321 Flags = SCEVWrapPredicate::clearFlags(
15322 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15323
15324 auto II = FlagsMap.find(V);
15325
15326 if (II != FlagsMap.end())
15327 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15328
15328
15329 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15330}
15331
15332const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15333 const SCEV *Expr = this->getSCEV(V);
15334 SmallVector<const SCEVPredicate *, 4> NewPreds;
15335 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15336
15337 if (!New)
15338 return nullptr;
15339
15340 for (const auto *P : NewPreds)
15341 addPredicate(*P);
15342
15343 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15344 return New;
15345}
15346
15347PredicatedScalarEvolution::PredicatedScalarEvolution(
15348 const PredicatedScalarEvolution &Init)
15349 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15350 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15351 SE)),
15352 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15353 for (auto I : Init.FlagsMap)
15354 FlagsMap.insert(I);
15355}
15356
15357void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15358 // For each block.
15359 for (auto *BB : L.getBlocks())
15360 for (auto &I : *BB) {
15361 if (!SE.isSCEVable(I.getType()))
15362 continue;
15363
15364 auto *Expr = SE.getSCEV(&I);
15365 auto II = RewriteMap.find(Expr);
15366
15367 if (II == RewriteMap.end())
15368 continue;
15369
15370 // Don't print things that are not interesting.
15371 if (II->second.second == Expr)
15372 continue;
15373
15374 OS.indent(Depth) << "[PSE]" << I << ":\n";
15375 OS.indent(Depth + 2) << *Expr << "\n";
15376 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15377 }
15378}
15379
15380// Match the mathematical pattern A - (A / B) * B, where A and B can be
15381// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15382// for URem with constant power-of-2 second operands.
15383// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15384// 4, A / B becomes X / 8).
15385bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15386 const SCEV *&RHS) {
15387 if (Expr->getType()->isPointerTy())
15388 return false;
15389
15390 // Try to match 'zext (trunc A to iB) to iY', which is used
15391 // for URem with constant power-of-2 second operands. Make sure the size of
15392 // the operand A matches the size of the whole expressions.
15393 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15394 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15395 LHS = Trunc->getOperand();
15396 // Bail out if the type of the LHS is larger than the type of the
15397 // expression for now.
15398 if (getTypeSizeInBits(LHS->getType()) >
15399 getTypeSizeInBits(Expr->getType()))
15400 return false;
15401 if (LHS->getType() != Expr->getType())
15402 LHS = getZeroExtendExpr(LHS, Expr->getType());
15403 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15404 << getTypeSizeInBits(Trunc->getType()));
15405 return true;
15406 }
15407 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15408 if (Add == nullptr || Add->getNumOperands() != 2)
15409 return false;
15410
15411 const SCEV *A = Add->getOperand(1);
15412 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15413
15414 if (Mul == nullptr)
15415 return false;
15416
15417 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15418 // (SomeExpr + (-(SomeExpr / B) * B)).
15419 if (Expr == getURemExpr(A, B)) {
15420 LHS = A;
15421 RHS = B;
15422 return true;
15423 }
15424 return false;
15425 };
15426
15427 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15428 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15429 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15430 MatchURemWithDivisor(Mul->getOperand(2));
15431
15432 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15433 if (Mul->getNumOperands() == 2)
15434 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15435 MatchURemWithDivisor(Mul->getOperand(0)) ||
15436 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15437 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15438 return false;
15439}
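// Illustrative examples: for A = %n and B = 4, SCEV expresses "%n urem 4" as
// (%n + (-4 * (%n /u 4))), and matchURem recovers LHS = %n, RHS = 4. The
// zext(trunc) form handles power-of-two remainders of the full width, e.g.
// zext(trunc i32 %n to i3) to i32 matches as %n urem 8.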
15440
15441ScalarEvolution::LoopGuards
15442ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15443 BasicBlock *Header = L->getHeader();
15444 BasicBlock *Pred = L->getLoopPredecessor();
15445 LoopGuards Guards(SE);
15446 if (!Pred)
15447 return Guards;
15448 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15449 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15450 return Guards;
15451}
15452
15453void ScalarEvolution::LoopGuards::collectFromPHI(
15454 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15455 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15456 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15457 unsigned Depth) {
15458 if (!SE.isSCEVable(Phi.getType()))
15459 return;
15460
15461 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15462 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15463 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15464 if (!VisitedBlocks.insert(InBlock).second)
15465 return {nullptr, scCouldNotCompute};
15466
15467 // Avoid analyzing unreachable blocks so that we don't get trapped
15468 // traversing cycles with ill-formed dominance or infinite cycles.
15469 if (!SE.DT.isReachableFromEntry(InBlock))
15470 return {nullptr, scCouldNotCompute};
15471
15472 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15473 if (Inserted)
15474 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15475 Depth + 1);
15476 auto &RewriteMap = G->second.RewriteMap;
15477 if (RewriteMap.empty())
15478 return {nullptr, scCouldNotCompute};
15479 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15480 if (S == RewriteMap.end())
15481 return {nullptr, scCouldNotCompute};
15482 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15483 if (!SM)
15484 return {nullptr, scCouldNotCompute};
15485 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15486 return {C0, SM->getSCEVType()};
15487 return {nullptr, scCouldNotCompute};
15488 };
15489 auto MergeMinMaxConst = [](MinMaxPattern P1,
15490 MinMaxPattern P2) -> MinMaxPattern {
15491 auto [C1, T1] = P1;
15492 auto [C2, T2] = P2;
15493 if (!C1 || !C2 || T1 != T2)
15494 return {nullptr, scCouldNotCompute};
15495 switch (T1) {
15496 case scUMaxExpr:
15497 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15498 case scSMaxExpr:
15499 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15500 case scUMinExpr:
15501 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15502 case scSMinExpr:
15503 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15504 default:
15505 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15506 }
15507 };
15508 auto P = GetMinMaxConst(0);
15509 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15510 if (!P.first)
15511 break;
15512 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15513 }
15514 if (P.first) {
15515 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15516 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15517 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15518 Guards.RewriteMap.insert({LHS, RHS});
15519 }
15520}
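// Illustrative example: for %p = phi [ %a, %bb1 ], [ %b, %bb2 ] where the
// guards collected from %bb1 rewrite %a to umax(16, ...) and those from %bb2
// rewrite %b to umax(8, ...), the constants are merged to the weaker bound and
// the phi receives the guard %p -> umax(8, %p).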
15521
15522void ScalarEvolution::LoopGuards::collectFromBlock(
15523 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15524 const BasicBlock *Block, const BasicBlock *Pred,
15525 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15526
15528
15529 SmallVector<const SCEV *> ExprsToRewrite;
15530 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15531 const SCEV *RHS,
15532 DenseMap<const SCEV *, const SCEV *>
15533 &RewriteMap) {
15534 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15535 // replacement SCEV which isn't directly implied by the structure of that
15536 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15537 // legal. See the scoping rules for flags in the header to understand why.
15538
15539 // If LHS is a constant, apply information to the other expression.
15540 if (isa<SCEVConstant>(LHS)) {
15541 std::swap(LHS, RHS);
15542 Predicate = CmpInst::getSwappedPredicate(Predicate);
15543 }
15544
15545 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15546 // create this form when combining two checks of the form (X u< C2 + C1) and
15547 // (X >=u C1).
15548 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15549 &ExprsToRewrite]() {
15550 const SCEVConstant *C1;
15551 const SCEVUnknown *LHSUnknown;
15552 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15553 if (!match(LHS,
15554 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15555 !C2)
15556 return false;
15557
15558 auto ExactRegion =
15559 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15560 .sub(C1->getAPInt());
15561
15562 // Bail out, unless we have a non-wrapping, monotonic range.
15563 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15564 return false;
15565 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15566 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15567 I->second = SE.getUMaxExpr(
15568 SE.getConstant(ExactRegion.getUnsignedMin()),
15569 SE.getUMinExpr(RewrittenLHS,
15570 SE.getConstant(ExactRegion.getUnsignedMax())));
15571 ExprsToRewrite.push_back(LHSUnknown);
15572 return true;
15573 };
15574 if (MatchRangeCheckIdiom())
15575 return;
15576
15577 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15578 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15579 // the non-constant operand and in \p LHS the constant operand.
15580 auto IsMinMaxSCEVWithNonNegativeConstant =
15581 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15582 const SCEV *&RHS) {
15583 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15584 if (MinMax->getNumOperands() != 2)
15585 return false;
15586 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15587 if (C->getAPInt().isNegative())
15588 return false;
15589 SCTy = MinMax->getSCEVType();
15590 LHS = MinMax->getOperand(0);
15591 RHS = MinMax->getOperand(1);
15592 return true;
15593 }
15594 }
15595 return false;
15596 };
15597
15598 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15599 // constant, and returns their APInt in ExprVal and in DivisorVal.
15600 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15601 APInt &ExprVal, APInt &DivisorVal) {
15602 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15603 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15604 if (!ConstExpr || !ConstDivisor)
15605 return false;
15606 ExprVal = ConstExpr->getAPInt();
15607 DivisorVal = ConstDivisor->getAPInt();
15608 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15609 };
15610
15611 // Return a new SCEV that modifies \p Expr to the closest number that is
15612 // divisible by \p Divisor and greater than or equal to \p Expr.
15613 // For now, only handle constant Expr and Divisor.
15614 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15615 const SCEV *Divisor) {
15616 APInt ExprVal;
15617 APInt DivisorVal;
15618 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15619 return Expr;
15620 APInt Rem = ExprVal.urem(DivisorVal);
15621 if (!Rem.isZero())
15622 // return the SCEV: Expr + Divisor - Expr % Divisor
15623 return SE.getConstant(ExprVal + DivisorVal - Rem);
15624 return Expr;
15625 };
15626
15627 // Return a new SCEV that modifies \p Expr to the closest number that is
15628 // divisible by \p Divisor and less than or equal to \p Expr.
15629 // For now, only handle constant Expr and Divisor.
15630 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15631 const SCEV *Divisor) {
15632 APInt ExprVal;
15633 APInt DivisorVal;
15634 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15635 return Expr;
15636 APInt Rem = ExprVal.urem(DivisorVal);
15637 // return the SCEV: Expr - Expr % Divisor
15638 return SE.getConstant(ExprVal - Rem);
15639 };
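    // For example (illustrative): with Expr = 10 and Divisor = 8, the "next"
    // helper above returns 16 (10 + 8 - 10 % 8) and the "previous" helper
    // returns 8 (10 - 10 % 8); if Expr is already a multiple of Divisor, both
    // return Expr unchanged.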
15640
15641 // Apply divisibility by \p Divisor to a MinMaxExpr with constant operands,
15642 // recursively. This is done by aligning the constant operand up/down to a
15643 // multiple of \p Divisor.
15644 std::function<const SCEV *(const SCEV *, const SCEV *)>
15645 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15646 const SCEV *Divisor) {
15647 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15648 SCEVTypes SCTy;
15649 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15650 MinMaxRHS))
15651 return MinMaxExpr;
15652 auto IsMin =
15653 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15654 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15655 "Expected non-negative operand!");
15656 auto *DivisibleExpr =
15657 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15658 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15659 SmallVector<const SCEV *> Ops = {
15660 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15661 return SE.getMinMaxExpr(SCTy, Ops);
15662 };
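    // For example (illustrative): with Divisor = 8, umin(30, %a) is rewritten
    // to umin(24, %a) (constant aligned down for a min) and umax(7, %a) to
    // umax(8, %a) (constant aligned up for a max).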
15663
15664 // If we have LHS == 0, check if LHS is computing a property of some unknown
15665 // SCEV %v which we can rewrite %v to express explicitly.
15666 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15667 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15668 // explicitly express that.
15669 const SCEV *URemLHS = nullptr;
15670 const SCEV *URemRHS = nullptr;
15671 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15672 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15673 auto I = RewriteMap.find(LHSUnknown);
15674 const SCEV *RewrittenLHS =
15675 I != RewriteMap.end() ? I->second : LHSUnknown;
15676 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15677 const auto *Multiple =
15678 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15679 RewriteMap[LHSUnknown] = Multiple;
15680 ExprsToRewrite.push_back(LHSUnknown);
15681 return;
15682 }
15683 }
15684 }
15685
15686 // Do not apply information for constants or if RHS contains an AddRec.
15687 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15688 return;
15689
15690 // If RHS is SCEVUnknown, make sure the information is applied to it.
15691 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15692 std::swap(LHS, RHS);
15693 Predicate = CmpInst::getSwappedPredicate(Predicate);
15694 }
15695
15696 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15697 // and \p FromRewritten are the same (i.e. there has been no rewrite
15698 // registered for \p From), then puts this value in the list of rewritten
15699 // expressions.
15700 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15701 const SCEV *To) {
15702 if (From == FromRewritten)
15703 ExprsToRewrite.push_back(From);
15704 RewriteMap[From] = To;
15705 };
15706
15707 // Checks whether \p S has already been rewritten. In that case returns the
15708 // existing rewrite because we want to chain further rewrites onto the
15709 // already rewritten value. Otherwise returns \p S.
15710 auto GetMaybeRewritten = [&](const SCEV *S) {
15711 return RewriteMap.lookup_or(S, S);
15712 };
15713
15714   // Check for the SCEV expression (A /u B) * B where B is a constant, inside
15715   // \p Expr. The check is done recursively on \p Expr, which is assumed to
15716 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15717 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15718 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15719 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15720 // DividesBy.
15721 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15722 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15723 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15724 if (Mul->getNumOperands() != 2)
15725 return false;
15726 auto *MulLHS = Mul->getOperand(0);
15727 auto *MulRHS = Mul->getOperand(1);
15728 if (isa<SCEVConstant>(MulLHS))
15729 std::swap(MulLHS, MulRHS);
15730 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15731 if (Div->getOperand(1) == MulRHS) {
15732 DividesBy = MulRHS;
15733 return true;
15734 }
15735 }
15736 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15737 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15738 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15739 return false;
15740 };
15741
15742   // Return true if \p Expr is known to be divisible by \p DividesBy.
15743 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15744 [&](const SCEV *Expr, const SCEV *DividesBy) {
15745 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15746 return true;
15747 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15748 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15749 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15750 return false;
15751 };
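  // For example, constants such as 16 or 64 satisfy the urem check directly,
  // while a min/max node is only accepted if both checked operands are known
  // to be divisible by DividesBy.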
15752
15753 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15754 const SCEV *DividesBy = nullptr;
15755 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15756     // Check that the whole expression is divisible by DividesBy.
15757 DividesBy =
15758 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15759
15760 // Collect rewrites for LHS and its transitive operands based on the
15761 // condition.
15762 // For min/max expressions, also apply the guard to its operands:
15763 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15764 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15765 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15766 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15767
15768 // We cannot express strict predicates in SCEV, so instead we replace them
15769 // with non-strict ones against plus or minus one of RHS depending on the
15770 // predicate.
15771 const SCEV *One = SE.getOne(RHS->getType());
15772 switch (Predicate) {
15773 case CmpInst::ICMP_ULT:
15774 if (RHS->getType()->isPointerTy())
15775 return;
15776 RHS = SE.getUMaxExpr(RHS, One);
15777 [[fallthrough]];
15778 case CmpInst::ICMP_SLT: {
15779 RHS = SE.getMinusSCEV(RHS, One);
15780 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15781 break;
15782 }
15783 case CmpInst::ICMP_UGT:
15784 case CmpInst::ICMP_SGT:
15785 RHS = SE.getAddExpr(RHS, One);
15786 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15787 break;
15788 case CmpInst::ICMP_ULE:
15789 case CmpInst::ICMP_SLE:
15790 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15791 break;
15792 case CmpInst::ICMP_UGE:
15793 case CmpInst::ICMP_SGE:
15794 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15795 break;
15796 default:
15797 break;
15798 }
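  // For example, given 'x <u 14' with DividesBy = 4, RHS becomes
  // umax(14, 1) - 1 = 13 and is then aligned down to 12, since a multiple of
  // 4 that is less than 14 can be at most 12.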
15799
15800     SmallVector<const SCEV *, 16> Worklist(1, LHS);
15801     SmallPtrSet<const SCEV *, 16> Visited;
15802
15803 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15804 append_range(Worklist, S->operands());
15805 };
15806
15807 while (!Worklist.empty()) {
15808 const SCEV *From = Worklist.pop_back_val();
15809 if (isa<SCEVConstant>(From))
15810 continue;
15811 if (!Visited.insert(From).second)
15812 continue;
15813 const SCEV *FromRewritten = GetMaybeRewritten(From);
15814 const SCEV *To = nullptr;
15815
15816 switch (Predicate) {
15817 case CmpInst::ICMP_ULT:
15818 case CmpInst::ICMP_ULE:
15819 To = SE.getUMinExpr(FromRewritten, RHS);
15820 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15821 EnqueueOperands(UMax);
15822 break;
15823 case CmpInst::ICMP_SLT:
15824 case CmpInst::ICMP_SLE:
15825 To = SE.getSMinExpr(FromRewritten, RHS);
15826 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15827 EnqueueOperands(SMax);
15828 break;
15829 case CmpInst::ICMP_UGT:
15830 case CmpInst::ICMP_UGE:
15831 To = SE.getUMaxExpr(FromRewritten, RHS);
15832 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15833 EnqueueOperands(UMin);
15834 break;
15835 case CmpInst::ICMP_SGT:
15836 case CmpInst::ICMP_SGE:
15837 To = SE.getSMaxExpr(FromRewritten, RHS);
15838 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15839 EnqueueOperands(SMin);
15840 break;
15841 case CmpInst::ICMP_EQ:
15842       if (isa<SCEVConstant>(RHS))
15843         To = RHS;
15844 break;
15845 case CmpInst::ICMP_NE:
15846 if (match(RHS, m_scev_Zero())) {
15847 const SCEV *OneAlignedUp =
15848 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15849 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15850 }
15851 break;
15852 default:
15853 break;
15854 }
15855
15856 if (To)
15857 AddRewrite(From, FromRewritten, To);
15858 }
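    // For example, the condition 'umin(%a, %b) >= 16' records the rewrite
    // umax(umin(%a, %b), 16) for the umin itself and, because the umin's
    // operands are enqueued above, umax(%a, 16) and umax(%b, 16) as well.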
15859 };
15860
15861   SmallVector<std::pair<Value *, bool>> Terms;
15862   // First, collect information from assumptions dominating the loop.
15863 for (auto &AssumeVH : SE.AC.assumptions()) {
15864 if (!AssumeVH)
15865 continue;
15866 auto *AssumeI = cast<CallInst>(AssumeVH);
15867 if (!SE.DT.dominates(AssumeI, Block))
15868 continue;
15869 Terms.emplace_back(AssumeI->getOperand(0), true);
15870 }
15871
15872 // Second, collect information from llvm.experimental.guards dominating the loop.
15873 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15874 SE.F.getParent(), Intrinsic::experimental_guard);
15875 if (GuardDecl)
15876 for (const auto *GU : GuardDecl->users())
15877 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15878 if (Guard->getFunction() == Block->getParent() &&
15879 SE.DT.dominates(Guard, Block))
15880 Terms.emplace_back(Guard->getArgOperand(0), true);
15881
15882 // Third, collect conditions from dominating branches. Starting at the loop
15883   // predecessor, climb up the predecessor chain as long as we can find
15884   // predecessors that have unique successors leading to the original
15885   // header.
15886 // TODO: share this logic with isLoopEntryGuardedByCond.
15887 unsigned NumCollectedConditions = 0;
15888 VisitedBlocks.insert(Block);
15889 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15890 for (; Pair.first;
15891 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15892 VisitedBlocks.insert(Pair.second);
15893 const BranchInst *LoopEntryPredicate =
15894 dyn_cast<BranchInst>(Pair.first->getTerminator());
15895 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15896 continue;
15897
15898 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15899 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15900 NumCollectedConditions++;
15901
15902     // If we are recursively collecting guards, stop after 2
15903 // conditions to limit compile-time impact for now.
15904 if (Depth > 0 && NumCollectedConditions == 2)
15905 break;
15906 }
15907 // Finally, if we stopped climbing the predecessor chain because
15908 // there wasn't a unique one to continue, try to collect conditions
15909 // for PHINodes by recursively following all of their incoming
15910 // blocks and try to merge the found conditions to build a new one
15911 // for the Phi.
15912 if (Pair.second->hasNPredecessorsOrMore(2) &&
15913       Depth < MaxLoopGuardCollectionDepth) {
15914     SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15915 for (auto &Phi : Pair.second->phis())
15916 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15917 }
15918
15919 // Now apply the information from the collected conditions to
15920 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15921   // earliest conditions are processed first. This ensures the SCEVs with the
15922 // shortest dependency chains are constructed first.
15923 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15924 SmallVector<Value *, 8> Worklist;
15925 SmallPtrSet<Value *, 8> Visited;
15926 Worklist.push_back(Term);
15927 while (!Worklist.empty()) {
15928 Value *Cond = Worklist.pop_back_val();
15929 if (!Visited.insert(Cond).second)
15930 continue;
15931
15932 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15933 auto Predicate =
15934 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15935 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15936 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15937 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15938 continue;
15939 }
15940
15941 Value *L, *R;
15942 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15943 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15944 Worklist.push_back(L);
15945 Worklist.push_back(R);
15946 }
15947 }
15948 }
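  // For example, a dominating branch on '(icmp ult %i, %n) and (icmp ne %n, 0)'
  // that must be true to enter the loop is split by the worklist above, and
  // each icmp is passed to CollectCondition with the branch-taken predicate.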
15949
15950 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15951 // the replacement expressions are contained in the ranges of the replaced
15952 // expressions.
15953 Guards.PreserveNUW = true;
15954 Guards.PreserveNSW = true;
15955 for (const SCEV *Expr : ExprsToRewrite) {
15956 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15957 Guards.PreserveNUW &=
15958 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15959 Guards.PreserveNSW &=
15960 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15961 }
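  // For example, if %n has a full unsigned/signed range, rewriting it to
  // umin(%n, 64) passes both containment checks and the flags stay preserved.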
15962
15963   // Now that all rewrite information is collected, rewrite the collected
15964 // expressions with the information in the map. This applies information to
15965 // sub-expressions.
15966 if (ExprsToRewrite.size() > 1) {
15967 for (const SCEV *Expr : ExprsToRewrite) {
15968 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15969 Guards.RewriteMap.erase(Expr);
15970 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15971 }
15972 }
15973}
15974
15975 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15976   /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15977 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15978 /// replacement is loop invariant in the loop of the AddRec.
15979 class SCEVLoopGuardRewriter
15980 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15981     const DenseMap<const SCEV *, const SCEV *> &Map;
15982
15983     SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15984
15985 public:
15986 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15987 const ScalarEvolution::LoopGuards &Guards)
15988 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15989 if (Guards.PreserveNUW)
15990 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15991 if (Guards.PreserveNSW)
15992 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15993 }
15994
15995 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15996
15997 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15998 return Map.lookup_or(Expr, Expr);
15999 }
16000
16001 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16002 if (const SCEV *S = Map.lookup(Expr))
16003 return S;
16004
16005     // If we didn't find the exact ZExt expr in the map, check if there's
16006 // an entry for a smaller ZExt we can use instead.
16007 Type *Ty = Expr->getType();
16008 const SCEV *Op = Expr->getOperand(0);
16009 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16010 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16011 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16012 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16013 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16014 if (const SCEV *S = Map.lookup(NarrowExt))
16015 return SE.getZeroExtendExpr(S, Ty);
16016 Bitwidth = Bitwidth / 2;
16017 }
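    // For example, when rewriting (zext i16 %x to i64) and the map only has an
    // entry for (zext i16 %x to i32), the search above finds that entry at
    // Bitwidth = 32 and returns the mapped value zero-extended to i64.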
16018
16019       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
16020           Expr);
16021 }
16022
16023 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16024 if (const SCEV *S = Map.lookup(Expr))
16025 return S;
16026       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16027           Expr);
16028 }
16029
16030 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16031 if (const SCEV *S = Map.lookup(Expr))
16032 return S;
16033       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16034     }
16035
16036 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16037 if (const SCEV *S = Map.lookup(Expr))
16038 return S;
16039       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16040     }
16041
16042 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16043 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16044 // (Const + A + B). There may be guard info for A + B, and if so, apply
16045 // it.
16046 // TODO: Could more generally apply guards to Add sub-expressions.
16047 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16048 Expr->getNumOperands() == 3) {
16049 if (const SCEV *S = Map.lookup(
16050 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
16051 return SE.getAddExpr(Expr->getOperand(0), S);
16052 }
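    // For example, for (4 + %n + %m) with a map entry for (%n + %m), the
    // lookup on the two non-constant operands succeeds and the result is
    // 4 plus the mapped value.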
16053     SmallVector<const SCEV *, 2> Operands;
16054     bool Changed = false;
16055 for (const auto *Op : Expr->operands()) {
16056 Operands.push_back(
16057           SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16058       Changed |= Op != Operands.back();
16059 }
16060 // We are only replacing operands with equivalent values, so transfer the
16061 // flags from the original expression.
16062 return !Changed ? Expr
16063 : SE.getAddExpr(Operands,
16064                                     ScalarEvolution::maskFlags(
16065                                         Expr->getNoWrapFlags(), FlagMask));
16066 }
16067
16068 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16069     SmallVector<const SCEV *, 2> Operands;
16070     bool Changed = false;
16071 for (const auto *Op : Expr->operands()) {
16072 Operands.push_back(
16073           SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16074       Changed |= Op != Operands.back();
16075 }
16076 // We are only replacing operands with equivalent values, so transfer the
16077 // flags from the original expression.
16078 return !Changed ? Expr
16079 : SE.getMulExpr(Operands,
16080                                   ScalarEvolution::maskFlags(
16081                                       Expr->getNoWrapFlags(), FlagMask));
16082 }
16083 };
16084
16085 if (RewriteMap.empty())
16086 return Expr;
16087
16088 SCEVLoopGuardRewriter Rewriter(SE, *this);
16089 return Rewriter.visit(Expr);
16090}
16091
16092const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16093 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16094}
16095
16096 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16097                                              const LoopGuards &Guards) {
16098 return Guards.rewrite(Expr);
16099}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define G(x, y, z)
Definition MD5.cpp:56
mir Rename Register Operands
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1971
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1540
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1391
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1512
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:206
APInt abs() const
Get the absolute value.
Definition APInt.h:1795
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1201
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:371
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:466
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1488
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:219
unsigned countTrailingZeros() const
Definition APInt.h:1647
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:356
unsigned logBase2() const
Definition APInt.h:1761
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:827
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1150
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:873
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:440
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1221
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
iterator end() const
Definition ArrayRef.h:136
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
iterator begin() const
Definition ArrayRef.h:135
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:482
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition Allocator.h:149
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:950
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:678
@ ICMP_SLT
signed less than
Definition InstrTypes.h:707
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:708
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:702
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:701
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:705
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:703
@ ICMP_NE
not equal
Definition InstrTypes.h:700
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:706
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:704
bool isSigned() const
Definition InstrTypes.h:932
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:829
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:944
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:791
bool isUnsigned() const
Definition InstrTypes.h:938
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:928
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1274
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:163
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:154
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:669
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:187
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:165
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:229
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:173
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:161
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:156
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:214
void swap(DenseMap &RHS)
Definition DenseMap.h:746
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:293
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:330
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1077
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:107
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visit(const SCEV *S)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will interact with.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM Value.
This class represents the value of vscale, as used when defining the length of a scalable vector or returned by the llvm.vscale() intrinsic.
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memory when it is done with it.
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis information.
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
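Collecting the guards once and reusing them for several expressions is the intended pattern. A small sketch, assuming SE, L and Expr are valid:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Rewrite Expr using conditions known to hold on entry to L.
static const SCEV *refineWithGuards(ScalarEvolution &SE, const Loop *L,
                                    const SCEV *Expr) {
  ScalarEvolution::LoopGuards Guards =
      ScalarEvolution::LoopGuards::collect(L, SE);
  return Guards.rewrite(Expr); // Comparable to SE.applyLoopGuards(Expr, L).
}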
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCouldNotCompute object.
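A sketch of the usual query pattern, assuming SE and L are valid; an exact count is only usable when it is not SCEVCouldNotCompute:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Return the exact backedge-taken count of L, or nullptr if unknown.
static const SCEV *getKnownBackedgeCount(ScalarEvolution &SE, const Loop *L) {
  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(BTC))
    return nullptr; // Analysis could not produce an exact count.
  return BTC;
}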
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operation with them.
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Determine whether the operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (as selected by Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value, if possible.
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which cannot result in overflow.
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCount or zero.
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect ScalarEvolution's ability to compute a trip count, or if the loop is deleted.
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the given type.
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the same instruction.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its value.
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvariantPredicate.
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operation with them.
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the given type.
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly computable, return SCEVCouldNotCompute.
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a single pointer operand.
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as required.
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution namespace.
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn't.
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS, and RHS in the given Context.
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
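The folders return the canonical (possibly simplified) node rather than a fresh object. A sketch that builds Base + 4*Index, assuming SE is valid and both operands are integer SCEVs of the same type:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static const SCEV *buildScaledAdd(ScalarEvolution &SE, const SCEV *Base,
                                  const SCEV *Index) {
  const SCEV *Four = SE.getConstant(Base->getType(), 4);
  const SCEV *Scaled = SE.getMulExpr(Index, Four); // Canonical multiply.
  return SE.getAddExpr(Base, Scaled);              // Canonical add.
}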
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS, and RHS.
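A sketch of a typical proof query, assuming SE is valid and I and N are integer Values of the same type:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Return true if SCEV can prove that the expression of I is signed-less-than
// the expression of N.
static bool provablyLessThan(ScalarEvolution &SE, Value *I, Value *N) {
  return SE.isKnownPredicate(ICmpInst::ICMP_SLT, SE.getSCEV(I), SE.getSCEV(N));
}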
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout structure.
Definition DataLayout.h:621
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:652
TypeSize getSizeInBits() const
Definition DataLayout.h:632
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:295
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:294
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:301
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:292
Use & Op()
Definition User.h:196
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1101
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:169
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2248
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2253
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2258
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g. 32 for i32).
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2263
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
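For reference, a standalone sketch exercising the APIntOps helpers listed above on small constants:

#include "llvm/ADT/APInt.h"
using namespace llvm;

static void apintHelpersDemo() {
  APInt A(32, 12), B(32, 18);
  APInt G = APIntOps::GreatestCommonDivisor(A, B); // 6
  const APInt &SMin = APIntOps::smin(A, B);        // 12 under signed compare
  const APInt &UMax = APIntOps::umax(A, B);        // 18 under unsigned compare
  (void)G; (void)SMin; (void)UMax;
}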
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
class_match< const SCEV > m_SCEV()
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
Definition MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:318
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2038
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1705
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case of optionals) value is accepted.
Definition Casting.h:738
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2116
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:252
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2033
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:682
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:186
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:95
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:759
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2108
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1712
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:336
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOne bit sets.
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:288
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range function attribute.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:548
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:71
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the given range.
Definition STLExtras.h:1934
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2010
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1847
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one of its successors.
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given predicate occurs in a range.
Definition STLExtras.h:1941
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:257
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1877
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
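SCEVExprContains is the ad-hoc counterpart of the cached ScalarEvolution::containsAddRecurrence query. A sketch, assuming S is a valid expression:

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

// True if any node of the expression tree rooted at S is an add recurrence.
static bool mentionsAddRec(const SCEV *S) {
  return SCEVExprContains(S, [](const SCEV *Node) {
    return isa<SCEVAddRecExpr>(Node);
  });
}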
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:851
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:853
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis pass type.
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:294
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:101
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:189
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:138
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:122
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:98
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
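A self-contained sketch of the KnownBits helpers above on 8-bit constants, not tied to any particular SCEV query:

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

static APInt maxOfShiftedConstant() {
  KnownBits L = KnownBits::makeConstant(APInt(8, 3)); // Exactly 0b00000011.
  KnownBits R = KnownBits::makeConstant(APInt(8, 2)); // Shift amount 2.
  KnownBits Shifted = KnownBits::shl(L, R);           // Fully known result: 12.
  return Shifted.getMaxValue();
}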
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.