LLVM 21.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
85#include "llvm/Config/llvm-config.h"
86#include "llvm/IR/Argument.h"
87#include "llvm/IR/BasicBlock.h"
88#include "llvm/IR/CFG.h"
89#include "llvm/IR/Constant.h"
91#include "llvm/IR/Constants.h"
92#include "llvm/IR/DataLayout.h"
94#include "llvm/IR/Dominators.h"
95#include "llvm/IR/Function.h"
96#include "llvm/IR/GlobalAlias.h"
97#include "llvm/IR/GlobalValue.h"
99#include "llvm/IR/InstrTypes.h"
100#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Instructions.h"
103#include "llvm/IR/Intrinsics.h"
104#include "llvm/IR/LLVMContext.h"
105#include "llvm/IR/Operator.h"
106#include "llvm/IR/PatternMatch.h"
107#include "llvm/IR/Type.h"
108#include "llvm/IR/Use.h"
109#include "llvm/IR/User.h"
110#include "llvm/IR/Value.h"
111#include "llvm/IR/Verifier.h"
113#include "llvm/Pass.h"
114#include "llvm/Support/Casting.h"
117#include "llvm/Support/Debug.h"
122#include <algorithm>
123#include <cassert>
124#include <climits>
125#include <cstdint>
126#include <cstdlib>
127#include <map>
128#include <memory>
129#include <numeric>
130#include <optional>
131#include <tuple>
132#include <utility>
133#include <vector>
134
135using namespace llvm;
136using namespace PatternMatch;
137using namespace SCEVPatternMatch;
138
139#define DEBUG_TYPE "scalar-evolution"
140
141STATISTIC(NumExitCountsComputed,
142 "Number of loop exits with predictable exit counts");
143STATISTIC(NumExitCountsNotComputed,
144 "Number of loop exits without predictable exit counts");
145STATISTIC(NumBruteForceTripCountsComputed,
146 "Number of loops with trip counts computed by force");
147
148#ifdef EXPENSIVE_CHECKS
149bool llvm::VerifySCEV = true;
150#else
151bool llvm::VerifySCEV = false;
152#endif
153
155 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
156 cl::desc("Maximum number of iterations SCEV will "
157 "symbolically execute a constant "
158 "derived loop"),
159 cl::init(100));
160
162 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
163 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
165 "verify-scev-strict", cl::Hidden,
166 cl::desc("Enable stricter verification with -verify-scev is passed"));
167
169 "scev-verify-ir", cl::Hidden,
170 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
171 cl::init(false));
172
174 "scev-mulops-inline-threshold", cl::Hidden,
175 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
176 cl::init(32));
177
179 "scev-addops-inline-threshold", cl::Hidden,
180 cl::desc("Threshold for inlining addition operands into a SCEV"),
181 cl::init(500));
182
184 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
185 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
186 cl::init(32));
187
189 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
190 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
191 cl::init(2));
192
194 "scalar-evolution-max-value-compare-depth", cl::Hidden,
195 cl::desc("Maximum depth of recursive value complexity comparisons"),
196 cl::init(2));
197
199 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
200 cl::desc("Maximum depth of recursive arithmetics"),
201 cl::init(32));
202
204 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
205 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
206
208 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
209 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
210 cl::init(8));
211
213 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
214 cl::desc("Max coefficients in AddRec during evolving"),
215 cl::init(8));
216
218 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
219 cl::desc("Size of the expression which is considered huge"),
220 cl::init(4096));
221
223 "scev-range-iter-threshold", cl::Hidden,
224 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
225 cl::init(32));
226
228 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
229 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
230
231static cl::opt<bool>
232ClassifyExpressions("scalar-evolution-classify-expressions",
233 cl::Hidden, cl::init(true),
234 cl::desc("When printing analysis, include information on every instruction"));
235
237 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
238 cl::init(false),
239 cl::desc("Use more powerful methods of sharpening expression ranges. May "
240 "be costly in terms of compile time"));
241
243 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
244 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
245 "Phi strongly connected components"),
246 cl::init(8));
247
248static cl::opt<bool>
249 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
250 cl::desc("Handle <= and >= in finite loops"),
251 cl::init(true));
252
254 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
255 cl::desc("Infer nuw/nsw flags using context where suitable"),
256 cl::init(true));
257
258//===----------------------------------------------------------------------===//
259// SCEV class definitions
260//===----------------------------------------------------------------------===//
261
262//===----------------------------------------------------------------------===//
263// Implementation of the SCEV class.
264//
265
266#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
268 print(dbgs());
269 dbgs() << '\n';
270}
271#endif
272
274 switch (getSCEVType()) {
275 case scConstant:
276 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
277 return;
278 case scVScale:
279 OS << "vscale";
280 return;
281 case scPtrToInt: {
282 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
283 const SCEV *Op = PtrToInt->getOperand();
284 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
285 << *PtrToInt->getType() << ")";
286 return;
287 }
288 case scTruncate: {
289 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
290 const SCEV *Op = Trunc->getOperand();
291 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
292 << *Trunc->getType() << ")";
293 return;
294 }
295 case scZeroExtend: {
296 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
297 const SCEV *Op = ZExt->getOperand();
298 OS << "(zext " << *Op->getType() << " " << *Op << " to "
299 << *ZExt->getType() << ")";
300 return;
301 }
302 case scSignExtend: {
303 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
304 const SCEV *Op = SExt->getOperand();
305 OS << "(sext " << *Op->getType() << " " << *Op << " to "
306 << *SExt->getType() << ")";
307 return;
308 }
309 case scAddRecExpr: {
310 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
311 OS << "{" << *AR->getOperand(0);
312 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
313 OS << ",+," << *AR->getOperand(i);
314 OS << "}<";
315 if (AR->hasNoUnsignedWrap())
316 OS << "nuw><";
317 if (AR->hasNoSignedWrap())
318 OS << "nsw><";
319 if (AR->hasNoSelfWrap() &&
321 OS << "nw><";
322 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
323 OS << ">";
324 return;
325 }
326 case scAddExpr:
327 case scMulExpr:
328 case scUMaxExpr:
329 case scSMaxExpr:
330 case scUMinExpr:
331 case scSMinExpr:
333 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
334 const char *OpStr = nullptr;
335 switch (NAry->getSCEVType()) {
336 case scAddExpr: OpStr = " + "; break;
337 case scMulExpr: OpStr = " * "; break;
338 case scUMaxExpr: OpStr = " umax "; break;
339 case scSMaxExpr: OpStr = " smax "; break;
340 case scUMinExpr:
341 OpStr = " umin ";
342 break;
343 case scSMinExpr:
344 OpStr = " smin ";
345 break;
347 OpStr = " umin_seq ";
348 break;
349 default:
350 llvm_unreachable("There are no other nary expression types.");
351 }
352 OS << "(";
353 ListSeparator LS(OpStr);
354 for (const SCEV *Op : NAry->operands())
355 OS << LS << *Op;
356 OS << ")";
357 switch (NAry->getSCEVType()) {
358 case scAddExpr:
359 case scMulExpr:
360 if (NAry->hasNoUnsignedWrap())
361 OS << "<nuw>";
362 if (NAry->hasNoSignedWrap())
363 OS << "<nsw>";
364 break;
365 default:
366 // Nothing to print for other nary expressions.
367 break;
368 }
369 return;
370 }
371 case scUDivExpr: {
372 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
373 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
374 return;
375 }
376 case scUnknown:
377 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
378 return;
380 OS << "***COULDNOTCOMPUTE***";
381 return;
382 }
383 llvm_unreachable("Unknown SCEV kind!");
384}
385
387 switch (getSCEVType()) {
388 case scConstant:
389 return cast<SCEVConstant>(this)->getType();
390 case scVScale:
391 return cast<SCEVVScale>(this)->getType();
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
407 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
415 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
421 switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToInt:
427 case scTruncate:
428 case scZeroExtend:
429 case scSignExtend:
430 return cast<SCEVCastExpr>(this)->operands();
431 case scAddRecExpr:
432 case scAddExpr:
433 case scMulExpr:
434 case scUMaxExpr:
435 case scSMaxExpr:
436 case scUMinExpr:
437 case scSMinExpr:
439 return cast<SCEVNAryExpr>(this)->operands();
440 case scUDivExpr:
441 return cast<SCEVUDivExpr>(this)->operands();
443 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
444 }
445 llvm_unreachable("Unknown SCEV kind!");
446}
447
448bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
449
450bool SCEV::isOne() const { return match(this, m_scev_One()); }
451
452bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
453
455 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
456 if (!Mul) return false;
457
458 // If there is a constant factor, it will be first.
459 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
460 if (!SC) return false;
461
462 // Return true if the value is negative, this matches things like (-42 * V).
463 return SC->getAPInt().isNegative();
464}
465
468
470 return S->getSCEVType() == scCouldNotCompute;
471}
472
475 ID.AddInteger(scConstant);
476 ID.AddPointer(V);
477 void *IP = nullptr;
478 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
479 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
480 UniqueSCEVs.InsertNode(S, IP);
481 return S;
482}
483
485 return getConstant(ConstantInt::get(getContext(), Val));
486}
487
488const SCEV *
490 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
491 return getConstant(ConstantInt::get(ITy, V, isSigned));
492}
493
496 ID.AddInteger(scVScale);
497 ID.AddPointer(Ty);
498 void *IP = nullptr;
499 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
500 return S;
501 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
502 UniqueSCEVs.InsertNode(S, IP);
503 return S;
504}
505
507 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
508 if (EC.isScalable())
509 Res = getMulExpr(Res, getVScale(Ty));
510 return Res;
511}
512
514 const SCEV *op, Type *ty)
515 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
516
517SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
518 Type *ITy)
519 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
521 "Must be a non-bit-width-changing pointer-to-integer cast!");
522}
523
525 SCEVTypes SCEVTy, const SCEV *op,
526 Type *ty)
527 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
528
529SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
530 Type *ty)
532 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
533 "Cannot truncate non-integer value!");
534}
535
536SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
537 const SCEV *op, Type *ty)
539 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
540 "Cannot zero extend non-integer value!");
541}
542
543SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
544 const SCEV *op, Type *ty)
546 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
547 "Cannot sign extend non-integer value!");
548}
549
550void SCEVUnknown::deleted() {
551 // Clear this SCEVUnknown from various maps.
552 SE->forgetMemoizedResults(this);
553
554 // Remove this SCEVUnknown from the uniquing map.
555 SE->UniqueSCEVs.RemoveNode(this);
556
557 // Release the value.
558 setValPtr(nullptr);
559}
560
561void SCEVUnknown::allUsesReplacedWith(Value *New) {
562 // Clear this SCEVUnknown from various maps.
563 SE->forgetMemoizedResults(this);
564
565 // Remove this SCEVUnknown from the uniquing map.
566 SE->UniqueSCEVs.RemoveNode(this);
567
568 // Replace the value pointer in case someone is still using this SCEVUnknown.
569 setValPtr(New);
570}
571
572//===----------------------------------------------------------------------===//
573// SCEV Utilities
574//===----------------------------------------------------------------------===//
575
576/// Compare the two values \p LV and \p RV in terms of their "complexity" where
577/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
578/// operands in SCEV expressions.
579static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
580 Value *RV, unsigned Depth) {
582 return 0;
583
584 // Order pointer values after integer values. This helps SCEVExpander form
585 // GEPs.
586 bool LIsPointer = LV->getType()->isPointerTy(),
587 RIsPointer = RV->getType()->isPointerTy();
588 if (LIsPointer != RIsPointer)
589 return (int)LIsPointer - (int)RIsPointer;
590
591 // Compare getValueID values.
592 unsigned LID = LV->getValueID(), RID = RV->getValueID();
593 if (LID != RID)
594 return (int)LID - (int)RID;
595
596 // Sort arguments by their position.
597 if (const auto *LA = dyn_cast<Argument>(LV)) {
598 const auto *RA = cast<Argument>(RV);
599 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
600 return (int)LArgNo - (int)RArgNo;
601 }
602
603 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
604 const auto *RGV = cast<GlobalValue>(RV);
605
606 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
607 auto LT = GV->getLinkage();
608 return !(GlobalValue::isPrivateLinkage(LT) ||
610 };
611
612 // Use the names to distinguish the two values, but only if the
613 // names are semantically important.
614 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
615 return LGV->getName().compare(RGV->getName());
616 }
617
618 // For instructions, compare their loop depth, and their operand count. This
619 // is pretty loose.
620 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
621 const auto *RInst = cast<Instruction>(RV);
622
623 // Compare loop depths.
624 const BasicBlock *LParent = LInst->getParent(),
625 *RParent = RInst->getParent();
626 if (LParent != RParent) {
627 unsigned LDepth = LI->getLoopDepth(LParent),
628 RDepth = LI->getLoopDepth(RParent);
629 if (LDepth != RDepth)
630 return (int)LDepth - (int)RDepth;
631 }
632
633 // Compare the number of operands.
634 unsigned LNumOps = LInst->getNumOperands(),
635 RNumOps = RInst->getNumOperands();
636 if (LNumOps != RNumOps)
637 return (int)LNumOps - (int)RNumOps;
638
639 for (unsigned Idx : seq(LNumOps)) {
640 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
641 RInst->getOperand(Idx), Depth + 1);
642 if (Result != 0)
643 return Result;
644 }
645 }
646
647 return 0;
648}
649
650// Return negative, zero, or positive, if LHS is less than, equal to, or greater
651// than RHS, respectively. A three-way result allows recursive comparisons to be
652// more efficient.
653// If the max analysis depth was reached, return std::nullopt, assuming we do
654// not know if they are equivalent for sure.
655static std::optional<int>
657 const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (EqCacheSCEV.isEquivalent(LHS, RHS))
669 return 0;
670
672 return std::nullopt;
673
674 // Aside from the getSCEVType() ordering, the particular ordering
675 // isn't very important except that it's beneficial to be consistent,
676 // so that (a + b) and (b + a) don't end up as different expressions.
677 switch (LType) {
678 case scUnknown: {
679 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
680 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
681
682 int X =
683 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
684 if (X == 0)
685 EqCacheSCEV.unionSets(LHS, RHS);
686 return X;
687 }
688
689 case scConstant: {
690 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
691 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
692
693 // Compare constant values.
694 const APInt &LA = LC->getAPInt();
695 const APInt &RA = RC->getAPInt();
696 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
697 if (LBitWidth != RBitWidth)
698 return (int)LBitWidth - (int)RBitWidth;
699 return LA.ult(RA) ? -1 : 1;
700 }
701
702 case scVScale: {
703 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
704 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
705 return LTy->getBitWidth() - RTy->getBitWidth();
706 }
707
708 case scAddRecExpr: {
709 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
710 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
711
712 // There is always a dominance between two recs that are used by one SCEV,
713 // so we can safely sort recs by loop header dominance. We require such
714 // order in getAddExpr.
715 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
716 if (LLoop != RLoop) {
717 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
718 assert(LHead != RHead && "Two loops share the same header?");
719 if (DT.dominates(LHead, RHead))
720 return 1;
721 assert(DT.dominates(RHead, LHead) &&
722 "No dominance between recurrences used by one SCEV?");
723 return -1;
724 }
725
726 [[fallthrough]];
727 }
728
729 case scTruncate:
730 case scZeroExtend:
731 case scSignExtend:
732 case scPtrToInt:
733 case scAddExpr:
734 case scMulExpr:
735 case scUDivExpr:
736 case scSMaxExpr:
737 case scUMaxExpr:
738 case scSMinExpr:
739 case scUMinExpr:
741 ArrayRef<const SCEV *> LOps = LHS->operands();
742 ArrayRef<const SCEV *> ROps = RHS->operands();
743
744 // Lexicographically compare n-ary-like expressions.
745 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
746 if (LNumOps != RNumOps)
747 return (int)LNumOps - (int)RNumOps;
748
749 for (unsigned i = 0; i != LNumOps; ++i) {
750 auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
751 Depth + 1);
752 if (X != 0)
753 return X;
754 }
755 EqCacheSCEV.unionSets(LHS, RHS);
756 return 0;
757 }
758
760 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
761 }
762 llvm_unreachable("Unknown SCEV kind!");
763}
764
765/// Given a list of SCEV objects, order them by their complexity, and group
766/// objects of the same complexity together by value. When this routine is
767/// finished, we know that any duplicates in the vector are consecutive and that
768/// complexity is monotonically increasing.
769///
770/// Note that we take special precautions to ensure that we get deterministic
771/// results from this routine. In other words, we don't want the results of
772/// this to depend on where the addresses of various SCEV objects happened to
773/// land in memory.
775 LoopInfo *LI, DominatorTree &DT) {
776 if (Ops.size() < 2) return; // Noop
777
779
780 // Whether LHS has provably less complexity than RHS.
781 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
782 auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
783 return Complexity && *Complexity < 0;
784 };
785 if (Ops.size() == 2) {
786 // This is the common case, which also happens to be trivially simple.
787 // Special case it.
788 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
789 if (IsLessComplex(RHS, LHS))
790 std::swap(LHS, RHS);
791 return;
792 }
793
794 // Do the rough sort by complexity.
795 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
796 return IsLessComplex(LHS, RHS);
797 });
798
799 // Now that we are sorted by complexity, group elements of the same
800 // complexity. Note that this is, at worst, N^2, but the vector is likely to
801 // be extremely short in practice. Note that we take this approach because we
802 // do not want to depend on the addresses of the objects we are grouping.
803 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
804 const SCEV *S = Ops[i];
805 unsigned Complexity = S->getSCEVType();
806
807 // If there are any objects of the same complexity and same value as this
808 // one, group them.
809 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
810 if (Ops[j] == S) { // Found a duplicate.
811 // Move it to immediately after i'th element.
812 std::swap(Ops[i+1], Ops[j]);
813 ++i; // no need to rescan it.
814 if (i == e-2) return; // Done!
815 }
816 }
817 }
818}
819
820/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
821/// least HugeExprThreshold nodes).
823 return any_of(Ops, [](const SCEV *S) {
825 });
826}
827
828/// Performs a number of common optimizations on the passed \p Ops. If the
829/// whole expression reduces down to a single operand, it will be returned.
830///
831/// The following optimizations are performed:
832/// * Fold constants using the \p Fold function.
833/// * Remove identity constants satisfying \p IsIdentity.
834/// * If a constant satisfies \p IsAbsorber, return it.
835/// * Sort operands by complexity.
836template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
837static const SCEV *
839 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
840 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
841 const SCEVConstant *Folded = nullptr;
842 for (unsigned Idx = 0; Idx < Ops.size();) {
843 const SCEV *Op = Ops[Idx];
844 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
845 if (!Folded)
846 Folded = C;
847 else
848 Folded = cast<SCEVConstant>(
849 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
850 Ops.erase(Ops.begin() + Idx);
851 continue;
852 }
853 ++Idx;
854 }
855
856 if (Ops.empty()) {
857 assert(Folded && "Must have folded value");
858 return Folded;
859 }
860
861 if (Folded && IsAbsorber(Folded->getAPInt()))
862 return Folded;
863
864 GroupByComplexity(Ops, &LI, DT);
865 if (Folded && !IsIdentity(Folded->getAPInt()))
866 Ops.insert(Ops.begin(), Folded);
867
868 return Ops.size() == 1 ? Ops[0] : nullptr;
869}
870
871//===----------------------------------------------------------------------===//
872// Simple SCEV method implementations
873//===----------------------------------------------------------------------===//
874
875/// Compute BC(It, K). The result has width W. Assume, K > 0.
876static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
877 ScalarEvolution &SE,
878 Type *ResultTy) {
879 // Handle the simplest case efficiently.
880 if (K == 1)
881 return SE.getTruncateOrZeroExtend(It, ResultTy);
882
883 // We are using the following formula for BC(It, K):
884 //
885 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
886 //
887 // Suppose, W is the bitwidth of the return value. We must be prepared for
888 // overflow. Hence, we must assure that the result of our computation is
889 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
890 // safe in modular arithmetic.
891 //
892 // However, this code doesn't use exactly that formula; the formula it uses
893 // is something like the following, where T is the number of factors of 2 in
894 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
895 // exponentiation:
896 //
897 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
898 //
899 // This formula is trivially equivalent to the previous formula. However,
900 // this formula can be implemented much more efficiently. The trick is that
901 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
902 // arithmetic. To do exact division in modular arithmetic, all we have
903 // to do is multiply by the inverse. Therefore, this step can be done at
904 // width W.
905 //
906 // The next issue is how to safely do the division by 2^T. The way this
907 // is done is by doing the multiplication step at a width of at least W + T
908 // bits. This way, the bottom W+T bits of the product are accurate. Then,
909 // when we perform the division by 2^T (which is equivalent to a right shift
910 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
911 // truncated out after the division by 2^T.
912 //
913 // In comparison to just directly using the first formula, this technique
914 // is much more efficient; using the first formula requires W * K bits,
915 // but this formula less than W + K bits. Also, the first formula requires
916 // a division step, whereas this formula only requires multiplies and shifts.
917 //
918 // It doesn't matter whether the subtraction step is done in the calculation
919 // width or the input iteration count's width; if the subtraction overflows,
920 // the result must be zero anyway. We prefer here to do it in the width of
921 // the induction variable because it helps a lot for certain cases; CodeGen
922 // isn't smart enough to ignore the overflow, which leads to much less
923 // efficient code if the width of the subtraction is wider than the native
924 // register width.
925 //
926 // (It's possible to not widen at all by pulling out factors of 2 before
927 // the multiplication; for example, K=2 can be calculated as
928 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
929 // extra arithmetic, so it's not an obvious win, and it gets
930 // much more complicated for K > 3.)
931
932 // Protection from insane SCEVs; this bound is conservative,
933 // but it probably doesn't matter.
934 if (K > 1000)
935 return SE.getCouldNotCompute();
936
937 unsigned W = SE.getTypeSizeInBits(ResultTy);
938
939 // Calculate K! / 2^T and T; we divide out the factors of two before
940 // multiplying for calculating K! / 2^T to avoid overflow.
941 // Other overflow doesn't matter because we only care about the bottom
942 // W bits of the result.
943 APInt OddFactorial(W, 1);
944 unsigned T = 1;
945 for (unsigned i = 3; i <= K; ++i) {
946 unsigned TwoFactors = countr_zero(i);
947 T += TwoFactors;
948 OddFactorial *= (i >> TwoFactors);
949 }
950
951 // We need at least W + T bits for the multiplication step
952 unsigned CalculationBits = W + T;
953
954 // Calculate 2^T, at width T+W.
955 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
956
957 // Calculate the multiplicative inverse of K! / 2^T;
958 // this multiplication factor will perform the exact division by
959 // K! / 2^T.
960 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
961
962 // Calculate the product, at width T+W
963 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
964 CalculationBits);
965 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
966 for (unsigned i = 1; i != K; ++i) {
967 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
968 Dividend = SE.getMulExpr(Dividend,
969 SE.getTruncateOrZeroExtend(S, CalculationTy));
970 }
971
972 // Divide by 2^T
973 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
974
975 // Truncate the result, and divide by K! / 2^T.
976
977 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
978 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
979}
980
981/// Return the value of this chain of recurrences at the specified iteration
982/// number. We can evaluate this recurrence by multiplying each element in the
983/// chain by the binomial coefficient corresponding to it. In other words, we
984/// can evaluate {A,+,B,+,C,+,D} as:
985///
986/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
987///
988/// where BC(It, k) stands for binomial coefficient.
///
/// This const member performs no arithmetic itself: it passes operands(),
/// \p It and \p SE on to evaluateAtIteration and returns that result
/// unchanged.
990 ScalarEvolution &SE) const {
991 return evaluateAtIteration(operands(), It, SE);
992}
993
/// Evaluate the recurrence described by the list \p Operands at iteration
/// \p It.
///
/// The result starts as Operands[0]; for every i >= 1 the term
/// Operands[i] * BinomialCoefficient(It, i, SE, <type of the running result>)
/// is added to it through \p SE.
///
/// \p Operands must not be empty (asserted). If any binomial coefficient is a
/// SCEVCouldNotCompute, that coefficient is returned immediately instead of a
/// sum.
994const SCEV *
996 const SCEV *It, ScalarEvolution &SE) {
997 assert(Operands.size() > 0);
998 const SCEV *Result = Operands[0];
999 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1000 // The computation is correct in the face of overflow provided that the
1001 // multiplication is performed _after_ the evaluation of the binomial
1002 // coefficient.
1003 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
  // Propagate the failure as-is; no partial sum is returned.
1004 if (isa<SCEVCouldNotCompute>(Coeff))
1005 return Coeff;
1006
1007 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1008 }
1009 return Result;
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// SCEV Expression folder implementations
1014//===----------------------------------------------------------------------===//
1015
1017 unsigned Depth) {
  // Purpose: turn the pointer-typed expression Op into an integer-typed
  // expression. Explicit cast nodes are only ever created around SCEVUnknown
  // leaves; for anything more complex the cast is sunk down to those leaves
  // by the local rewriter defined further below.
  //
  // Returns:
  //  - Op itself when Op is not pointer-typed;
  //  - an already-uniqued node when one exists for (scPtrToInt, Op);
  //  - getCouldNotCompute() for non-integral pointer types, or when the
  //    width check against IntPtrTy below fails;
  //  - a zero constant of IntPtrTy for a null-pointer SCEVUnknown;
  //  - otherwise a new SCEVPtrToIntExpr (SCEVUnknown operand) or the rewritten
  //    integer-typed expression.
  //
  // Depth must be 0 or 1 (asserted). The value 1 is passed by the self-call
  // made from visitUnknown() in the rewriter below.
1018 assert(Depth <= 1 &&
1019 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1020
1021 // We could be called with an integer-typed operand during SCEV rewrites.
1022 // Such an operand needs no pointer-to-integer conversion; return it as-is.
1023 if (!Op->getType()->isPointerTy())
1024 return Op;
1025
1026 // What would be an ID for such a SCEV cast expression?
  // NOTE(review): the declaration of ID is not present in this listing.
1028 ID.AddInteger(scPtrToInt);
1029 ID.AddPointer(Op);
1030
1031 void *IP = nullptr;
1032
1033 // Is there already an expression for such a cast?
1034 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1035 return S;
1036
1037 // It isn't legal for optimizations to construct new ptrtoint expressions
1038 // for non-integral pointers.
1039 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1040 return getCouldNotCompute();
1041
1042 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1043
1044 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1045 // is sufficiently wide to represent all possible pointer values.
1046 // We could theoretically teach SCEV to truncate wider pointers, but
1047 // that isn't implemented for now.
  // NOTE(review): the first line of this condition is missing from the
  // listing; only its right-hand operand (the bit size of IntPtrTy) is shown.
1049 getDataLayout().getTypeSizeInBits(IntPtrTy))
1050 return getCouldNotCompute();
1051
1052 // If not, is this expression something we can't reduce any further?
1053 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1054 // Perform some basic constant folding. If the operand of the ptr2int cast
1055 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1056 // left as-is), but produce a zero constant.
1057 // NOTE: We could handle a more general case, but lack motivational cases.
1058 if (isa<ConstantPointerNull>(U->getValue()))
1059 return getZero(IntPtrTy);
1060
1061 // Create an explicit cast node.
1062 // We can reuse the existing insert position since if we get here,
1063 // we won't have made any changes which would invalidate it.
1064 SCEV *S = new (SCEVAllocator)
1065 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1066 UniqueSCEVs.InsertNode(S, IP);
1067 registerUser(S, Op);
1068 return S;
1069 }
1070
  // Only the SCEVUnknown path above may be reached with Depth == 1.
1071 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1072 "non-SCEVUnknown's.");
1073
1074 // Otherwise, we've got some expression that is more complex than just a
1075 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1076 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1077 // only, and the expressions must otherwise be integer-typed.
1078 // So sink the cast down to the SCEVUnknown's.
1079
1080 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1081 /// which computes a pointer-typed value, and rewrites the whole expression
1082 /// tree so that *all* the computations are done on integers, and the only
1083 /// pointer-typed operands in the expression are SCEVUnknown.
1084 class SCEVPtrToIntSinkingRewriter
1085 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
  // NOTE(review): the line introducing the `Base` name used in visit() is
  // missing from this listing.
1087
1088 public:
1089 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1090
  // Convenience entry point: build a rewriter and run it over Scev.
1091 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1092 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1093 return Rewriter.visit(Scev);
1094 }
1095
1096 const SCEV *visit(const SCEV *S) {
1097 Type *STy = S->getType();
1098 // If the expression is not pointer-typed, just keep it as-is.
1099 if (!STy->isPointerTy())
1100 return S;
1101 // Else, recursively sink the cast down into it.
1102 return Base::visit(S);
1103 }
1104
  // Rewrite every operand; rebuild the add (keeping its no-wrap flags) only
  // if at least one operand actually changed.
1105 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
  // NOTE(review): the declaration of Operands is missing from this listing.
1107 bool Changed = false;
1108 for (const auto *Op : Expr->operands()) {
1109 Operands.push_back(visit(Op));
1110 Changed |= Op != Operands.back();
1111 }
1112 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1113 }
1114
  // Same scheme as visitAddExpr, for multiplications.
1115 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
  // NOTE(review): the declaration of Operands is missing from this listing.
1117 bool Changed = false;
1118 for (const auto *Op : Expr->operands()) {
1119 Operands.push_back(visit(Op));
1120 Changed |= Op != Operands.back();
1121 }
1122 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1123 }
1124
  // Leaf case: this is the single permitted self-recursion (Depth == 1).
1125 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1126 assert(Expr->getType()->isPointerTy() &&
1127 "Should only reach pointer-typed SCEVUnknown's.");
1128 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1129 }
1130 };
1131
1132 // And actually perform the cast sinking.
1133 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1134 assert(IntOp->getType()->isIntegerTy() &&
1135 "We must have succeeded in sinking the cast, "
1136 "and ending up with an integer-typed expression!");
1137 return IntOp;
1138}
1139
// Body of the pointer-to-integer conversion that targets a caller-chosen
// integer type Ty (the signature line is not present in this listing).
// Op is first converted with getLosslessPtrToIntExpr(); a SCEVCouldNotCompute
// result is returned unchanged, anything else is truncated or zero-extended
// to Ty. Ty must be an integer type (asserted).
1141 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1142
1143 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1144 if (isa<SCEVCouldNotCompute>(IntOp))
1145 return IntOp;
1146
1147 return getTruncateOrZeroExtend(IntOp, Ty);
1148}
1149
1151 unsigned Depth) {
  // Build (or look up) the truncation of Op to the narrower type Ty.
  //
  // Preconditions (asserted): Op is strictly wider than Ty, Ty is SCEVable,
  // and Op is not pointer-typed.
  //
  // The folds below are tried in order: cache hit, constant operand,
  // trunc/sext/zext operand, add/mul operand (distributed when that leaves
  // fewer than two new truncates), add-recurrence operand, and finally
  // known-zero low bits. If none applies, an explicit SCEVTruncateExpr node is
  // created and uniqued. Depth is incremented on each recursive cast query;
  // once it exceeds MaxCastDepth the node is created without further folding.
1152 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1153 "This is not a truncating conversion!");
1154 assert(isSCEVable(Ty) &&
1155 "This is not a conversion to a SCEVable type!");
1156 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1157 Ty = getEffectiveSCEVType(Ty);
1158
  // NOTE(review): the declaration of ID is not present in this listing.
1160 ID.AddInteger(scTruncate);
1161 ID.AddPointer(Op);
1162 ID.AddPointer(Ty);
1163 void *IP = nullptr;
1164 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1165
1166 // Fold if the operand is constant.
1167 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1168 return getConstant(
1169 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1170
1171 // trunc(trunc(x)) --> trunc(x)
1172 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1173 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1174
1175 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1176 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1177 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1178
1179 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1180 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1181 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1182
  // Depth limit reached: stop folding and materialize the truncate node.
1183 if (Depth > MaxCastDepth) {
1184 SCEV *S =
1185 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1186 UniqueSCEVs.InsertNode(S, IP);
1187 registerUser(S, Op);
1188 return S;
1189 }
1190
1191 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1192 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1193 // if after transforming we have at most one truncate, not counting truncates
1194 // that replace other casts.
1195 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1196 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
  // NOTE(review): the declaration of Operands is missing from this listing.
1198 unsigned numTruncs = 0;
  // The loop stops early as soon as a second new truncate has been counted.
1199 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1200 ++i) {
1201 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1202 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1203 isa<SCEVTruncateExpr>(S))
1204 numTruncs++;
1205 Operands.push_back(S);
1206 }
1207 if (numTruncs < 2) {
1208 if (isa<SCEVAddExpr>(Op))
1209 return getAddExpr(Operands);
1210 if (isa<SCEVMulExpr>(Op))
1211 return getMulExpr(Operands);
1212 llvm_unreachable("Unexpected SCEV type for Op.");
1213 }
1214 // Although we checked at the beginning that ID is not in the cache, the
1215 // recursive calls above may have inserted a node with this ID in the
1216 // meantime. So if we find it, just return it.
  // This lookup also recomputes the insert position IP used further below.
1217 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1218 return S;
1219 }
1220
1221 // If the input value is a chrec scev, truncate the chrec's operands.
1222 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
  // NOTE(review): the declaration of Operands is missing from this listing.
  // The rebuilt recurrence is created with FlagAnyWrap, i.e. AddRec's own
  // no-wrap flags are not passed on.
1224 for (const SCEV *Op : AddRec->operands())
1225 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1226 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1227 }
1228
1229 // Return zero if truncating to known zeros.
1230 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1231 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1232 return getZero(Ty);
1233
1234 // The cast wasn't folded; create an explicit cast node. We can reuse
1235 // the existing insert position since if we get here, we won't have
1236 // made any changes which would invalidate it.
1237 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1238 Op, Ty);
1239 UniqueSCEVs.InsertNode(S, IP);
1240 registerUser(S, Op);
1241 return S;
1242}
1243
1244// Get the limit of a recurrence such that incrementing by Step cannot cause
1245// signed overflow as long as the value of the recurrence within the
1246// loop does not exceed this limit before incrementing.
//
// On success *Pred receives the predicate to compare the recurrence against
// the returned limit with: ICMP_SLT when Step is known positive, ICMP_SGT when
// Step is known negative. If the sign of Step is not known, nullptr is
// returned and *Pred is left untouched.
//
// NOTE(review): the first line of each `return` expression is missing from
// this listing; only the trailing operand (the signed range max/min of Step)
// is visible, so the exact limit formula cannot be confirmed from here.
1247static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1248 ICmpInst::Predicate *Pred,
1249 ScalarEvolution *SE) {
1250 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1251 if (SE->isKnownPositive(Step)) {
1252 *Pred = ICmpInst::ICMP_SLT;
1254 SE->getSignedRangeMax(Step));
1255 }
1256 if (SE->isKnownNegative(Step)) {
1257 *Pred = ICmpInst::ICMP_SGT;
1259 SE->getSignedRangeMin(Step));
1260 }
1261 return nullptr;
1262}
1263
1264// Get the limit of a recurrence such that incrementing by Step cannot cause
1265// unsigned overflow as long as the value of the recurrence within the loop does
1266// not exceed this limit before incrementing.
//
// *Pred is always set to ICMP_ULT; unlike the signed variant there is no
// visible path that returns without setting it.
//
// NOTE(review): the signature line and the first line of the `return`
// expression are missing from this listing; only the trailing operand (the
// unsigned range max of Step) is visible.
1268 ICmpInst::Predicate *Pred,
1269 ScalarEvolution *SE) {
1270 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1271 *Pred = ICmpInst::ICMP_ULT;
1272
1274 SE->getUnsignedRangeMax(Step));
1275}
1276
1277namespace {
1278
// Common base of the ExtendOpTraits specializations. It only supplies the
// pointer-to-member type used to call an extend-expression builder on a
// ScalarEvolution: (const SCEV *, Type *, unsigned) -> const SCEV *.
1279struct ExtendOpTraitsBase {
1280 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1281 unsigned);
1282};
1283
1284// Used to make code generic over signed and unsigned overflow.
// The primary template is intentionally empty; only the two explicit
// specializations below provide the members listed in its body.
1285template <typename ExtendOp> struct ExtendOpTraits {
1286 // Members present:
1287 //
1288 // static const SCEV::NoWrapFlags WrapType;
1289 //
1290 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1291 //
1292 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1293 // ICmpInst::Predicate *Pred,
1294 // ScalarEvolution *SE);
1295};
1296
// Sign-extension flavour: no-signed-wrap flag, signed overflow limit.
1297template <>
1298struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1299 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1300
1301 static const GetExtendExprTy GetExtendExpr;
1302
1303 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1304 ICmpInst::Predicate *Pred,
1305 ScalarEvolution *SE) {
1306 return getSignedOverflowLimitForStep(Step, Pred, SE);
1307 }
1308};
1309
// NOTE(review): presumably the out-of-class definition of GetExtendExpr for
// the specialization above; its second line (the initializer) is missing from
// this listing.
1310const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1312
// Zero-extension flavour: no-unsigned-wrap flag, unsigned overflow limit.
1313template <>
1314struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1315 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1316
1317 static const GetExtendExprTy GetExtendExpr;
1318
1319 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1320 ICmpInst::Predicate *Pred,
1321 ScalarEvolution *SE) {
1322 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1323 }
1324};
1325
// NOTE(review): presumably the out-of-class definition of GetExtendExpr for
// the zero-extension specialization; its second line is missing from this
// listing.
1326const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1328
1329} // end anonymous namespace
1330
1331// The recurrence AR has been shown to have no signed/unsigned wrap or something
1332// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1333// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1334// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1335// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1336// expression "Step + sext/zext(PreIncAR)" is congruent with
1337// "sext/zext(PostIncAR)".
//
// Returns the "PreStart" expression (AR's start with one occurrence of Step
// removed from its add operands) when one of the three checks in the body
// shows that PreStart + Step does not wrap in the sense of ExtendOpTy's
// WrapType. Returns nullptr when the start is not an add expression, when
// Step is not one of its operands, or when none of the checks succeeds.
1338template <typename ExtendOpTy>
1339static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1340 ScalarEvolution *SE, unsigned Depth) {
1341 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1342 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1343
1344 const Loop *L = AR->getLoop();
1345 const SCEV *Start = AR->getStart();
1346 const SCEV *Step = AR->getStepRecurrence(*SE);
1347
1348 // Check for a simple looking step prior to loop entry.
1349 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1350 if (!SA)
1351 return nullptr;
1352
1353 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1354 // subtraction is expensive. For this purpose, perform a quick and dirty
1355 // difference, by checking for Step in the operand list. Note, that
1356 // SA might have repeated ops, like %a + %a + ..., so only remove one.
  // NOTE(review): the declaration of DiffOps is missing from this listing;
  // presumably it starts out as a copy of SA's operands.
1358 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1359 if (*It == Step) {
1360 DiffOps.erase(It);
1361 break;
1362 }
1363
  // Nothing was erased, i.e. Step is not an operand of the start expression.
1364 if (DiffOps.size() == SA->getNumOperands())
1365 return nullptr;
1366
1367 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1368 // `Step`:
1369
1370 // 1. NSW/NUW flags on the step increment.
  // NOTE(review): the initializer of PreStartFlags is missing from this
  // listing.
1371 auto PreStartFlags =
1373 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1374 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1375 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1376
1377 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1378 // "S+X does not sign/unsign-overflow".
1379 //
1380
1381 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1382 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1383 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1384 return PreStart;
1385
1386 // 2. Direct overflow check on the step operation's expression.
  // Compare ext(Start) with ext(PreStart) + ext(Step) at twice the bit width.
1387 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1388 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1389 const SCEV *OperandExtendedStart =
1390 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1391 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1392 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1393 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1394 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1395 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1396 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1397 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1398 }
1399 return PreStart;
1400 }
1401
1402 // 3. Loop precondition.
  // NOTE(review): the declaration of Pred is missing from this listing.
1404 const SCEV *OverflowLimit =
1405 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1406
1407 if (OverflowLimit &&
1408 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1409 return PreStart;
1410
1411 return nullptr;
1412}
1413
1414// Get the normalized zero or sign extended expression for this AddRec's Start.
//
// If getPreStartForExtend() finds a suitable PreStart for AR, the result is
// ext(Step) + ext(PreStart), each extended to Ty separately. Otherwise the
// result is simply ext(AR->getStart()). "ext" is the extend builder selected
// by ExtendOpTy; Depth is forwarded unchanged to every extend call.
1415template <typename ExtendOpTy>
1416static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1417 ScalarEvolution *SE,
1418 unsigned Depth) {
1419 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1420
1421 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
  // No usable PreStart: fall back to extending the start as a whole.
1422 if (!PreStart)
1423 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1424
1425 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1426 Depth),
1427 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1428}
1429
1430// Try to prove away overflow by looking at "nearby" add recurrences. A
1431// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1432// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1433//
1434// Formally:
1435//
1436// {S,+,X} == {S-T,+,X} + T
1437// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1438//
1439// If ({S-T,+,X} + T) does not overflow ... (1)
1440//
1441// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1442//
1443// If {S-T,+,X} does not overflow ... (2)
1444//
1445// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1446// == {Ext(S-T)+Ext(T),+,Ext(X)}
1447//
1448// If (S-T)+T does not overflow ... (3)
1449//
1450// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1451// == {Ext(S),+,Ext(X)} == LHS
1452//
1453// Thus, if (1), (2) and (3) are true for some T, then
1454// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1455//
1456// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1457// does not overflow" restricted to the 0th iteration. Therefore we only need
1458// to check for (1) and (2).
1459//
1460// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1461// is `Delta` (defined below).
//
// Returns true only if, for one of the candidate deltas, the recurrence
// {Start - Delta,+,Step} over L already exists in UniqueSCEVs with the
// WrapType flag set and the overflow-limit predicate is known to hold for it.
// Returns false otherwise, including whenever Start is not a constant.
1462template <typename ExtendOpTy>
1463bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1464 const SCEV *Step,
1465 const Loop *L) {
1466 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1467
1468 // We restrict `Start` to a constant to prevent SCEV from spending too much
1469 // time here. It is correct (but more expensive) to continue with a
1470 // non-constant `Start` and do a general SCEV subtraction to compute
1471 // `PreStart` below.
1472 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1473 if (!StartC)
1474 return false;
1475
1476 APInt StartAI = StartC->getAPInt();
1477
  // NOTE(review): Delta is declared `unsigned` but initialized from -2 and -1,
  // so those two candidates are converted to large unsigned values before
  // being used in `StartAI - Delta` and `getConstant(..., Delta)`. Confirm
  // this behaves as intended for types wider than `unsigned`.
1478 for (unsigned Delta : {-2, -1, 1, 2}) {
1479 const SCEV *PreStart = getConstant(StartAI - Delta);
1480
  // Build the lookup key for {PreStart,+,Step}<L>.
  // NOTE(review): the declaration of ID is missing from this listing.
1482 ID.AddInteger(scAddRecExpr);
1483 ID.AddPointer(PreStart);
1484 ID.AddPointer(Step);
1485 ID.AddPointer(L);
1486 void *IP = nullptr;
  // Lookup only: nothing is inserted here, so IP is not used afterwards.
1487 const auto *PreAR =
1488 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1489
1490 // Give up if we don't already have the add recurrence we need because
1491 // actually constructing an add recurrence is relatively expensive.
1492 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1493 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
  // NOTE(review): the declaration of Pred is missing from this listing.
1495 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1496 DeltaS, &Pred, this);
1497 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1498 return true;
1499 }
1500 }
1501
1502 return false;
1503}
1504
1505// Finds an integer D for an expression (C + x + y + ...) such that the top
1506// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1507// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1508// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1509// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// Result, with TZ being the smallest getMinTrailingZeros() value over
// operands 1..N-1 of WholeAddExpr (capped at C's bit width):
//   TZ == 0             -> 0 (nothing can be extracted)
//   0 < TZ < bit width  -> the low TZ bits of C, zero-extended to C's width
//   TZ == bit width     -> C itself
1511 const SCEVConstant *ConstantTerm,
1512 const SCEVAddExpr *WholeAddExpr) {
1513 const APInt &C = ConstantTerm->getAPInt();
1514 const unsigned BitWidth = C.getBitWidth();
1515 // Find number of trailing zeros of (x + y + ...) w/o the C first:
  // Operand 0 is skipped (the loop starts at I = 1), and the scan stops early
  // once TZ has dropped to 0.
1516 uint32_t TZ = BitWidth;
1517 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1518 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1519 if (TZ) {
1520 // Set D to be as many least significant bits of C as possible while still
1521 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1522 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1523 }
1524 return APInt(BitWidth, 0);
1525}
1526
1527// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1528// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1529// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1530// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// Result, with TZ = getMinTrailingZeros(Step):
//   TZ == 0             -> 0 (nothing can be extracted)
//   0 < TZ < bit width  -> the low TZ bits of ConstantStart, zero-extended
//   TZ >= bit width     -> ConstantStart itself
1532 const APInt &ConstantStart,
1533 const SCEV *Step) {
1534 const unsigned BitWidth = ConstantStart.getBitWidth();
1535 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1536 if (TZ)
1537 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1538 : ConstantStart;
1539 return APInt(BitWidth, 0);
1540}
1541
1543 const ScalarEvolution::FoldID &ID, const SCEV *S,
1546 &FoldCacheUser) {
  // Record ID -> S in FoldCache and keep the reverse index FoldCacheUser
  // (SCEV -> list of IDs that map to it) in sync.
  // If ID was already cached for some other SCEV, ID is first removed from
  // that SCEV's user list and the cache entry is overwritten with S.
  // NOTE(review): the function-name line and the FoldCache/FoldCacheUser
  // parameter types are missing from this listing.
1547 auto I = FoldCache.insert({ID, S});
  // I.second is false when an entry for ID already existed.
1548 if (!I.second) {
1549 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1550 // entry.
1551 auto &UserIDs = FoldCacheUser[I.first->second];
1552 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
  // Unordered removal: swap the matching ID to the back, then pop it.
  // Note that the loop index below shadows the outer `I` inside the loop.
1553 for (unsigned I = 0; I != UserIDs.size(); ++I)
1554 if (UserIDs[I] == ID) {
1555 std::swap(UserIDs[I], UserIDs.back());
1556 break;
1557 }
1558 UserIDs.pop_back();
1559 I.first->second = S;
1560 }
1561 FoldCacheUser[S].push_back(ID);
1562}
1563
// Caching front end for zero-extending Op to the wider type Ty (the
// signature line is missing from this listing).
//
// After validating the arguments, the (scZeroExtend, Op, Ty) key is looked up
// in FoldCache and a hit is returned directly. On a miss the work is done by
// getZeroExtendExprImpl(); its result is stored in FoldCache only when it is
// NOT itself a SCEVZeroExtendExpr node.
1564const SCEV *
1566 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1567 "This is not an extending conversion!");
1568 assert(isSCEVable(Ty) &&
1569 "This is not a conversion to a SCEVable type!");
1570 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1571 Ty = getEffectiveSCEVType(Ty);
1572
1573 FoldID ID(scZeroExtend, Op, Ty);
1574 auto Iter = FoldCache.find(ID);
1575 if (Iter != FoldCache.end())
1576 return Iter->second;
1577
1578 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1579 if (!isa<SCEVZeroExtendExpr>(S))
1580 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1581 return S;
1582}
1583
1585 unsigned Depth) {
1586 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1587 "This is not an extending conversion!");
1588 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1589 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1590
1591 // Fold if the operand is constant.
1592 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1593 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1594
1595 // zext(zext(x)) --> zext(x)
1596 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1597 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1598
1599 // Before doing any expensive analysis, check to see if we've already
1600 // computed a SCEV for this Op and Ty.
1602 ID.AddInteger(scZeroExtend);
1603 ID.AddPointer(Op);
1604 ID.AddPointer(Ty);
1605 void *IP = nullptr;
1606 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1607 if (Depth > MaxCastDepth) {
1608 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1609 Op, Ty);
1610 UniqueSCEVs.InsertNode(S, IP);
1611 registerUser(S, Op);
1612 return S;
1613 }
1614
1615 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1616 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1617 // It's possible the bits taken off by the truncate were all zero bits. If
1618 // so, we should be able to simplify this further.
1619 const SCEV *X = ST->getOperand();
1621 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1622 unsigned NewBits = getTypeSizeInBits(Ty);
1623 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1624 CR.zextOrTrunc(NewBits)))
1625 return getTruncateOrZeroExtend(X, Ty, Depth);
1626 }
1627
1628 // If the input value is a chrec scev, and we can prove that the value
1629 // did not overflow the old, smaller, value, we can zero extend all of the
1630 // operands (often constants). This allows analysis of something like
1631 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1632 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1633 if (AR->isAffine()) {
1634 const SCEV *Start = AR->getStart();
1635 const SCEV *Step = AR->getStepRecurrence(*this);
1636 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1637 const Loop *L = AR->getLoop();
1638
1639 // If we have special knowledge that this addrec won't overflow,
1640 // we don't need to do any further analysis.
1641 if (AR->hasNoUnsignedWrap()) {
1642 Start =
1643 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1644 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1645 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1646 }
1647
1648 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1649 // Note that this serves two purposes: It filters out loops that are
1650 // simply not analyzable, and it covers the case where this code is
1651 // being called from within backedge-taken count analysis, such that
1652 // attempting to ask for the backedge-taken count would likely result
1653 // in infinite recursion. In the later case, the analysis code will
1654 // cope with a conservative value, and it will take care to purge
1655 // that value once it has finished.
1656 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1657 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1658 // Manually compute the final value for AR, checking for overflow.
1659
1660 // Check whether the backedge-taken count can be losslessly casted to
1661 // the addrec's type. The count is always unsigned.
1662 const SCEV *CastedMaxBECount =
1663 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1664 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1665 CastedMaxBECount, MaxBECount->getType(), Depth);
1666 if (MaxBECount == RecastedMaxBECount) {
1667 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1668 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1669 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1671 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1673 Depth + 1),
1674 WideTy, Depth + 1);
1675 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1676 const SCEV *WideMaxBECount =
1677 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1678 const SCEV *OperandExtendedAdd =
1679 getAddExpr(WideStart,
1680 getMulExpr(WideMaxBECount,
1681 getZeroExtendExpr(Step, WideTy, Depth + 1),
1684 if (ZAdd == OperandExtendedAdd) {
1685 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1686 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1687 // Return the expression with the addrec on the outside.
1688 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1689 Depth + 1);
1690 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1691 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1692 }
1693 // Similar to above, only this time treat the step value as signed.
1694 // This covers loops that count down.
1695 OperandExtendedAdd =
1696 getAddExpr(WideStart,
1697 getMulExpr(WideMaxBECount,
1698 getSignExtendExpr(Step, WideTy, Depth + 1),
1701 if (ZAdd == OperandExtendedAdd) {
1702 // Cache knowledge of AR NW, which is propagated to this AddRec.
1703 // Negative step causes unsigned wrap, but it still can't self-wrap.
1704 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1705 // Return the expression with the addrec on the outside.
1706 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1707 Depth + 1);
1708 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1709 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1710 }
1711 }
1712 }
1713
1714 // Normally, in the cases we can prove no-overflow via a
1715 // backedge guarding condition, we can also compute a backedge
1716 // taken count for the loop. The exceptions are assumptions and
1717 // guards present in the loop -- SCEV is not great at exploiting
1718 // these to compute max backedge taken counts, but can still use
1719 // these to prove lack of overflow. Use this fact to avoid
1720 // doing extra work that may not pay off.
1721 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1722 !AC.assumptions().empty()) {
1723
1724 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1725 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1726 if (AR->hasNoUnsignedWrap()) {
1727 // Same as nuw case above - duplicated here to avoid a compile time
1728 // issue. It's not clear that the order of checks does matter, but
1729 // it's one of two issue possible causes for a change which was
1730 // reverted. Be conservative for the moment.
1731 Start =
1732 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1733 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1734 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1735 }
1736
1737 // For a negative step, we can extend the operands iff doing so only
1738 // traverses values in the range zext([0,UINT_MAX]).
1739 if (isKnownNegative(Step)) {
1741 getSignedRangeMin(Step));
1744 // Cache knowledge of AR NW, which is propagated to this
1745 // AddRec. Negative step causes unsigned wrap, but it
1746 // still can't self-wrap.
1747 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1748 // Return the expression with the addrec on the outside.
1749 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1750 Depth + 1);
1751 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1752 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1753 }
1754 }
1755 }
1756
1757 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1758 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1759 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1760 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1761 const APInt &C = SC->getAPInt();
1762 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1763 if (D != 0) {
1764 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1765 const SCEV *SResidual =
1766 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1767 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1768 return getAddExpr(SZExtD, SZExtR,
1770 Depth + 1);
1771 }
1772 }
1773
1774 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1775 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1776 Start =
1777 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1778 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1779 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1780 }
1781 }
1782
1783 // zext(A % B) --> zext(A) % zext(B)
1784 {
1785 const SCEV *LHS;
1786 const SCEV *RHS;
1787 if (matchURem(Op, LHS, RHS))
1788 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1789 getZeroExtendExpr(RHS, Ty, Depth + 1));
1790 }
1791
1792 // zext(A / B) --> zext(A) / zext(B).
1793 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1794 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1795 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1796
1797 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1798 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1799 if (SA->hasNoUnsignedWrap()) {
1800 // If the addition does not unsign overflow then we can, by definition,
1801 // commute the zero extension with the addition operation.
1803 for (const auto *Op : SA->operands())
1804 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1805 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1806 }
1807
1808 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1809 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1810 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1811 //
1812 // Often address arithmetics contain expressions like
1813 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1814 // This transformation is useful while proving that such expressions are
1815 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1816 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1817 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1818 if (D != 0) {
1819 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1820 const SCEV *SResidual =
1822 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1823 return getAddExpr(SZExtD, SZExtR,
1825 Depth + 1);
1826 }
1827 }
1828 }
1829
1830 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1831 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1832 if (SM->hasNoUnsignedWrap()) {
1833 // If the multiply does not unsign overflow then we can, by definition,
1834 // commute the zero extension with the multiply operation.
1836 for (const auto *Op : SM->operands())
1837 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1838 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1839 }
1840
1841 // zext(2^K * (trunc X to iN)) to iM ->
1842 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1843 //
1844 // Proof:
1845 //
1846 // zext(2^K * (trunc X to iN)) to iM
1847 // = zext((trunc X to iN) << K) to iM
1848 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1849 // (because shl removes the top K bits)
1850 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1851 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1852 //
1853 if (SM->getNumOperands() == 2)
1854 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1855 if (MulLHS->getAPInt().isPowerOf2())
1856 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1857 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1858 MulLHS->getAPInt().logBase2();
1859 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1860 return getMulExpr(
1861 getZeroExtendExpr(MulLHS, Ty),
1863 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1864 SCEV::FlagNUW, Depth + 1);
1865 }
1866 }
1867
1868 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1869 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1870 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1871 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1873 for (auto *Operand : MinMax->operands())
1874 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1875 if (isa<SCEVUMinExpr>(MinMax))
1876 return getUMinExpr(Operands);
1877 return getUMaxExpr(Operands);
1878 }
1879
1880 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1881 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1882 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1884 for (auto *Operand : MinMax->operands())
1885 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1886 return getUMinExpr(Operands, /*Sequential*/ true);
1887 }
1888
1889 // The cast wasn't folded; create an explicit cast node.
1890 // Recompute the insert position, as it may have been invalidated.
1891 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1892 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1893 Op, Ty);
1894 UniqueSCEVs.InsertNode(S, IP);
1895 registerUser(S, Op);
1896 return S;
1897}
1898
1899const SCEV *
1901 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1902 "This is not an extending conversion!");
1903 assert(isSCEVable(Ty) &&
1904 "This is not a conversion to a SCEVable type!");
1905 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1906 Ty = getEffectiveSCEVType(Ty);
1907
1908 FoldID ID(scSignExtend, Op, Ty);
1909 auto Iter = FoldCache.find(ID);
1910 if (Iter != FoldCache.end())
1911 return Iter->second;
1912
1913 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1914 if (!isa<SCEVSignExtendExpr>(S))
1915 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1916 return S;
1917}
1918
1920 unsigned Depth) {
1921 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1922 "This is not an extending conversion!");
1923 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1924 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1925 Ty = getEffectiveSCEVType(Ty);
1926
1927 // Fold if the operand is constant.
1928 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1929 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1930
1931 // sext(sext(x)) --> sext(x)
1932 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1933 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1934
1935 // sext(zext(x)) --> zext(x)
1936 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1937 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1938
1939 // Before doing any expensive analysis, check to see if we've already
1940 // computed a SCEV for this Op and Ty.
1942 ID.AddInteger(scSignExtend);
1943 ID.AddPointer(Op);
1944 ID.AddPointer(Ty);
1945 void *IP = nullptr;
1946 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1947 // Limit recursion depth.
1948 if (Depth > MaxCastDepth) {
1949 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1950 Op, Ty);
1951 UniqueSCEVs.InsertNode(S, IP);
1952 registerUser(S, Op);
1953 return S;
1954 }
1955
1956 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1957 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1958 // It's possible the bits taken off by the truncate were all sign bits. If
1959 // so, we should be able to simplify this further.
1960 const SCEV *X = ST->getOperand();
1962 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1963 unsigned NewBits = getTypeSizeInBits(Ty);
1964 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1965 CR.sextOrTrunc(NewBits)))
1966 return getTruncateOrSignExtend(X, Ty, Depth);
1967 }
1968
1969 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1970 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1971 if (SA->hasNoSignedWrap()) {
1972 // If the addition does not sign overflow then we can, by definition,
1973 // commute the sign extension with the addition operation.
1975 for (const auto *Op : SA->operands())
1976 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1977 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1978 }
1979
1980 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1981 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1982 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1983 //
1984 // For instance, this will bring two seemingly different expressions:
1985 // 1 + sext(5 + 20 * %x + 24 * %y) and
1986 // sext(6 + 20 * %x + 24 * %y)
1987 // to the same form:
1988 // 2 + sext(4 + 20 * %x + 24 * %y)
1989 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1990 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1991 if (D != 0) {
1992 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1993 const SCEV *SResidual =
1995 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1996 return getAddExpr(SSExtD, SSExtR,
1998 Depth + 1);
1999 }
2000 }
2001 }
2002 // If the input value is a chrec scev, and we can prove that the value
2003 // did not overflow the old, smaller, value, we can sign extend all of the
2004 // operands (often constants). This allows analysis of something like
2005 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2006 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
2007 if (AR->isAffine()) {
2008 const SCEV *Start = AR->getStart();
2009 const SCEV *Step = AR->getStepRecurrence(*this);
2010 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2011 const Loop *L = AR->getLoop();
2012
2013 // If we have special knowledge that this addrec won't overflow,
2014 // we don't need to do any further analysis.
2015 if (AR->hasNoSignedWrap()) {
2016 Start =
2017 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2018 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2019 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2020 }
2021
2022 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2023 // Note that this serves two purposes: It filters out loops that are
2024 // simply not analyzable, and it covers the case where this code is
2025 // being called from within backedge-taken count analysis, such that
2026 // attempting to ask for the backedge-taken count would likely result
2027 // in infinite recursion. In the later case, the analysis code will
2028 // cope with a conservative value, and it will take care to purge
2029 // that value once it has finished.
2030 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2031 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2032 // Manually compute the final value for AR, checking for
2033 // overflow.
2034
2035 // Check whether the backedge-taken count can be losslessly casted to
2036 // the addrec's type. The count is always unsigned.
2037 const SCEV *CastedMaxBECount =
2038 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2039 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2040 CastedMaxBECount, MaxBECount->getType(), Depth);
2041 if (MaxBECount == RecastedMaxBECount) {
2042 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2043 // Check whether Start+Step*MaxBECount has no signed overflow.
2044 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2046 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2048 Depth + 1),
2049 WideTy, Depth + 1);
2050 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2051 const SCEV *WideMaxBECount =
2052 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2053 const SCEV *OperandExtendedAdd =
2054 getAddExpr(WideStart,
2055 getMulExpr(WideMaxBECount,
2056 getSignExtendExpr(Step, WideTy, Depth + 1),
2059 if (SAdd == OperandExtendedAdd) {
2060 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2061 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2062 // Return the expression with the addrec on the outside.
2063 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2064 Depth + 1);
2065 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2066 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2067 }
2068 // Similar to above, only this time treat the step value as unsigned.
2069 // This covers loops that count up with an unsigned step.
2070 OperandExtendedAdd =
2071 getAddExpr(WideStart,
2072 getMulExpr(WideMaxBECount,
2073 getZeroExtendExpr(Step, WideTy, Depth + 1),
2076 if (SAdd == OperandExtendedAdd) {
2077 // If AR wraps around then
2078 //
2079 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2080 // => SAdd != OperandExtendedAdd
2081 //
2082 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2083 // (SAdd == OperandExtendedAdd => AR is NW)
2084
2085 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2086
2087 // Return the expression with the addrec on the outside.
2088 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2089 Depth + 1);
2090 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2091 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2092 }
2093 }
2094 }
2095
2096 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2097 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2098 if (AR->hasNoSignedWrap()) {
2099 // Same as nsw case above - duplicated here to avoid a compile time
2100 // issue. It's not clear that the order of checks does matter, but
2101 // it's one of two issue possible causes for a change which was
2102 // reverted. Be conservative for the moment.
2103 Start =
2104 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2105 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2106 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2107 }
2108
2109 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2110 // if D + (C - D + Step * n) could be proven to not signed wrap
2111 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2112 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2113 const APInt &C = SC->getAPInt();
2114 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2115 if (D != 0) {
2116 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2117 const SCEV *SResidual =
2118 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2119 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2120 return getAddExpr(SSExtD, SSExtR,
2122 Depth + 1);
2123 }
2124 }
2125
2126 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2127 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2128 Start =
2129 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2130 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2131 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2132 }
2133 }
2134
2135 // If the input value is provably positive and we could not simplify
2136 // away the sext build a zext instead.
2138 return getZeroExtendExpr(Op, Ty, Depth + 1);
2139
2140 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2141 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2142 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2143 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2145 for (auto *Operand : MinMax->operands())
2146 Operands.push_back(getSignExtendExpr(Operand, Ty));
2147 if (isa<SCEVSMinExpr>(MinMax))
2148 return getSMinExpr(Operands);
2149 return getSMaxExpr(Operands);
2150 }
2151
2152 // The cast wasn't folded; create an explicit cast node.
2153 // Recompute the insert position, as it may have been invalidated.
2154 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2155 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2156 Op, Ty);
2157 UniqueSCEVs.InsertNode(S, IP);
2158 registerUser(S, { Op });
2159 return S;
2160}
2161
2163 Type *Ty) {
2164 switch (Kind) {
2165 case scTruncate:
2166 return getTruncateExpr(Op, Ty);
2167 case scZeroExtend:
2168 return getZeroExtendExpr(Op, Ty);
2169 case scSignExtend:
2170 return getSignExtendExpr(Op, Ty);
2171 case scPtrToInt:
2172 return getPtrToIntExpr(Op, Ty);
2173 default:
2174 llvm_unreachable("Not a SCEV cast expression!");
2175 }
2176}
2177
2178/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2179/// unspecified bits out to the given type.
2181 Type *Ty) {
2182 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2183 "This is not an extending conversion!");
2184 assert(isSCEVable(Ty) &&
2185 "This is not a conversion to a SCEVable type!");
2186 Ty = getEffectiveSCEVType(Ty);
2187
2188 // Sign-extend negative constants.
2189 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2190 if (SC->getAPInt().isNegative())
2191 return getSignExtendExpr(Op, Ty);
2192
2193 // Peel off a truncate cast.
2194 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2195 const SCEV *NewOp = T->getOperand();
2196 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2197 return getAnyExtendExpr(NewOp, Ty);
2198 return getTruncateOrNoop(NewOp, Ty);
2199 }
2200
2201 // Next try a zext cast. If the cast is folded, use it.
2202 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2203 if (!isa<SCEVZeroExtendExpr>(ZExt))
2204 return ZExt;
2205
2206 // Next try a sext cast. If the cast is folded, use it.
2207 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2208 if (!isa<SCEVSignExtendExpr>(SExt))
2209 return SExt;
2210
2211 // Force the cast to be folded into the operands of an addrec.
2212 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2214 for (const SCEV *Op : AR->operands())
2215 Ops.push_back(getAnyExtendExpr(Op, Ty));
2216 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2217 }
2218
2219 // If the expression is obviously signed, use the sext cast value.
2220 if (isa<SCEVSMaxExpr>(Op))
2221 return SExt;
2222
2223 // Absent any other information, use the zext cast value.
2224 return ZExt;
2225}
2226
2227/// Process the given Ops list, which is a list of operands to be added under
2228/// the given scale, update the given map. This is a helper function for
2229/// getAddRecExpr. As an example of what it does, given a sequence of operands
2230/// that would form an add expression like this:
2231///
2232/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2233///
2234/// where A and B are constants, update the map with these values:
2235///
2236/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2237///
2238/// and add 13 + A*B*29 to AccumulatedConstant.
2239/// This will allow getAddRecExpr to produce this:
2240///
2241/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2242///
2243/// This form often exposes folding opportunities that are hidden in
2244/// the original operand list.
2245///
2246/// Return true iff it appears that any interesting folding opportunities
2247/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2248/// the common case where no interesting opportunities are present, and
2249/// is also used as a check to avoid infinite recursion.
2250static bool
2253 APInt &AccumulatedConstant,
2254 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2255 ScalarEvolution &SE) {
2256 bool Interesting = false;
2257
2258 // Iterate over the add operands. They are sorted, with constants first.
2259 unsigned i = 0;
2260 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2261 ++i;
2262 // Pull a buried constant out to the outside.
2263 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2264 Interesting = true;
2265 AccumulatedConstant += Scale * C->getAPInt();
2266 }
2267
2268 // Next comes everything else. We're especially interested in multiplies
2269 // here, but they're in the middle, so just visit the rest with one loop.
2270 for (; i != Ops.size(); ++i) {
2271 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2272 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2273 APInt NewScale =
2274 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2275 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2276 // A multiplication of a constant with another add; recurse.
2277 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2278 Interesting |=
2279 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2280 Add->operands(), NewScale, SE);
2281 } else {
2282 // A multiplication of a constant with some other value. Update
2283 // the map.
2284 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2285 const SCEV *Key = SE.getMulExpr(MulOps);
2286 auto Pair = M.insert({Key, NewScale});
2287 if (Pair.second) {
2288 NewOps.push_back(Pair.first->first);
2289 } else {
2290 Pair.first->second += NewScale;
2291 // The map already had an entry for this value, which may indicate
2292 // a folding opportunity.
2293 Interesting = true;
2294 }
2295 }
2296 } else {
2297 // An ordinary operand. Update the map.
2298 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2299 M.insert({Ops[i], Scale});
2300 if (Pair.second) {
2301 NewOps.push_back(Pair.first->first);
2302 } else {
2303 Pair.first->second += Scale;
2304 // The map already had an entry for this value, which may indicate
2305 // a folding opportunity.
2306 Interesting = true;
2307 }
2308 }
2309 }
2310
2311 return Interesting;
2312}
2313
2315 const SCEV *LHS, const SCEV *RHS,
2316 const Instruction *CtxI) {
2317 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2318 SCEV::NoWrapFlags, unsigned);
2319 switch (BinOp) {
2320 default:
2321 llvm_unreachable("Unsupported binary op");
2322 case Instruction::Add:
2324 break;
2325 case Instruction::Sub:
2327 break;
2328 case Instruction::Mul:
2330 break;
2331 }
2332
2333 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2336
2337 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2338 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2339 auto *WideTy =
2340 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2341
2342 const SCEV *A = (this->*Extension)(
2343 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2344 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2345 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2346 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2347 if (A == B)
2348 return true;
2349 // Can we use context to prove the fact we need?
2350 if (!CtxI)
2351 return false;
2352 // TODO: Support mul.
2353 if (BinOp == Instruction::Mul)
2354 return false;
2355 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2356 // TODO: Lift this limitation.
2357 if (!RHSC)
2358 return false;
2359 APInt C = RHSC->getAPInt();
2360 unsigned NumBits = C.getBitWidth();
2361 bool IsSub = (BinOp == Instruction::Sub);
2362 bool IsNegativeConst = (Signed && C.isNegative());
2363 // Compute the direction and magnitude by which we need to check overflow.
2364 bool OverflowDown = IsSub ^ IsNegativeConst;
2365 APInt Magnitude = C;
2366 if (IsNegativeConst) {
2367 if (C == APInt::getSignedMinValue(NumBits))
2368 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2369 // want to deal with that.
2370 return false;
2371 Magnitude = -C;
2372 }
2373
2375 if (OverflowDown) {
2376 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2377 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2378 : APInt::getMinValue(NumBits);
2379 APInt Limit = Min + Magnitude;
2380 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2381 } else {
2382 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2383 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2384 : APInt::getMaxValue(NumBits);
2385 APInt Limit = Max - Magnitude;
2386 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2387 }
2388}
2389
2390std::optional<SCEV::NoWrapFlags>
2392 const OverflowingBinaryOperator *OBO) {
2393 // It cannot be done any better.
2394 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2395 return std::nullopt;
2396
2398
2399 if (OBO->hasNoUnsignedWrap())
2401 if (OBO->hasNoSignedWrap())
2403
2404 bool Deduced = false;
2405
2406 if (OBO->getOpcode() != Instruction::Add &&
2407 OBO->getOpcode() != Instruction::Sub &&
2408 OBO->getOpcode() != Instruction::Mul)
2409 return std::nullopt;
2410
2411 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2412 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2413
2414 const Instruction *CtxI =
2415 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2416 if (!OBO->hasNoUnsignedWrap() &&
2418 /* Signed */ false, LHS, RHS, CtxI)) {
2420 Deduced = true;
2421 }
2422
2423 if (!OBO->hasNoSignedWrap() &&
2425 /* Signed */ true, LHS, RHS, CtxI)) {
2427 Deduced = true;
2428 }
2429
2430 if (Deduced)
2431 return Flags;
2432 return std::nullopt;
2433}
2434
2435// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2436// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2437// can't-overflow flags for the operation if possible.
2438static SCEV::NoWrapFlags
2440 const ArrayRef<const SCEV *> Ops,
2441 SCEV::NoWrapFlags Flags) {
2442 using namespace std::placeholders;
2443
2444 using OBO = OverflowingBinaryOperator;
2445
2446 bool CanAnalyze =
2448 (void)CanAnalyze;
2449 assert(CanAnalyze && "don't call from other places!");
2450
2451 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2452 SCEV::NoWrapFlags SignOrUnsignWrap =
2453 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2456 auto IsKnownNonNegative = [&](const SCEV *S) {
2457 return SE->isKnownNonNegative(S);
2458 };
2459
2460 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2461 Flags =
2462 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2463
2464 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2465
2466 if (SignOrUnsignWrap != SignOrUnsignMask &&
2467 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2468 isa<SCEVConstant>(Ops[0])) {
2469
2470 auto Opcode = [&] {
2471 switch (Type) {
2472 case scAddExpr:
2473 return Instruction::Add;
2474 case scMulExpr:
2475 return Instruction::Mul;
2476 default:
2477 llvm_unreachable("Unexpected SCEV op.");
2478 }
2479 }();
2480
2481 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2482
2483 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2484 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2486 Opcode, C, OBO::NoSignedWrap);
2487 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2489 }
2490
2491 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2492 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2494 Opcode, C, OBO::NoUnsignedWrap);
2495 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2497 }
2498 }
2499
2500 // <0,+,nonnegative><nw> is also nuw
2501 // TODO: Add corresponding nsw case
2503 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2504 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2506
2507 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2509 Ops.size() == 2) {
2510 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2511 if (UDiv->getOperand(1) == Ops[1])
2513 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2514 if (UDiv->getOperand(1) == Ops[0])
2516 }
2517
2518 return Flags;
2519}
2520
2522 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2523}
2524
2525/// Get a canonical add expression, or something simpler if possible.
2527 SCEV::NoWrapFlags OrigFlags,
2528 unsigned Depth) {
2529 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2530 "only nuw or nsw allowed");
2531 assert(!Ops.empty() && "Cannot get empty add!");
2532 if (Ops.size() == 1) return Ops[0];
2533#ifndef NDEBUG
2534 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2535 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2536 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2537 "SCEVAddExpr operand types don't match!");
2538 unsigned NumPtrs = count_if(
2539 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2540 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2541#endif
2542
2543 const SCEV *Folded = constantFoldAndGroupOps(
2544 *this, LI, DT, Ops,
2545 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2546 [](const APInt &C) { return C.isZero(); }, // identity
2547 [](const APInt &C) { return false; }); // absorber
2548 if (Folded)
2549 return Folded;
2550
2551 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2552
2553 // Delay expensive flag strengthening until necessary.
2554 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2555 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2556 };
2557
2558 // Limit recursion calls depth.
2560 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2561
2562 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2563 // Don't strengthen flags if we have no new information.
2564 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2565 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2566 Add->setNoWrapFlags(ComputeFlags(Ops));
2567 return S;
2568 }
2569
2570 // Okay, check to see if the same value occurs in the operand list more than
2571 // once. If so, merge them together into an multiply expression. Since we
2572 // sorted the list, these values are required to be adjacent.
2573 Type *Ty = Ops[0]->getType();
2574 bool FoundMatch = false;
2575 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2576 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2577 // Scan ahead to count how many equal operands there are.
2578 unsigned Count = 2;
2579 while (i+Count != e && Ops[i+Count] == Ops[i])
2580 ++Count;
2581 // Merge the values into a multiply.
2582 const SCEV *Scale = getConstant(Ty, Count);
2583 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2584 if (Ops.size() == Count)
2585 return Mul;
2586 Ops[i] = Mul;
2587 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2588 --i; e -= Count - 1;
2589 FoundMatch = true;
2590 }
2591 if (FoundMatch)
2592 return getAddExpr(Ops, OrigFlags, Depth + 1);
2593
2594 // Check for truncates. If all the operands are truncated from the same
2595 // type, see if factoring out the truncate would permit the result to be
2596 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2597 // if the contents of the resulting outer trunc fold to something simple.
2598 auto FindTruncSrcType = [&]() -> Type * {
2599 // We're ultimately looking to fold an addrec of truncs and muls of only
2600 // constants and truncs, so if we find any other types of SCEV
2601 // as operands of the addrec then we bail and return nullptr here.
2602 // Otherwise, we return the type of the operand of a trunc that we find.
2603 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2604 return T->getOperand()->getType();
2605 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2606 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2607 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2608 return T->getOperand()->getType();
2609 }
2610 return nullptr;
2611 };
2612 if (auto *SrcType = FindTruncSrcType()) {
2614 bool Ok = true;
2615 // Check all the operands to see if they can be represented in the
2616 // source type of the truncate.
2617 for (const SCEV *Op : Ops) {
2618 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2619 if (T->getOperand()->getType() != SrcType) {
2620 Ok = false;
2621 break;
2622 }
2623 LargeOps.push_back(T->getOperand());
2624 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2625 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2626 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2627 SmallVector<const SCEV *, 8> LargeMulOps;
2628 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2629 if (const SCEVTruncateExpr *T =
2630 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2631 if (T->getOperand()->getType() != SrcType) {
2632 Ok = false;
2633 break;
2634 }
2635 LargeMulOps.push_back(T->getOperand());
2636 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2637 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2638 } else {
2639 Ok = false;
2640 break;
2641 }
2642 }
2643 if (Ok)
2644 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2645 } else {
2646 Ok = false;
2647 break;
2648 }
2649 }
2650 if (Ok) {
2651 // Evaluate the expression in the larger type.
2652 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2653 // If it folds to something simple, use it. Otherwise, don't.
2654 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2655 return getTruncateExpr(Fold, Ty);
2656 }
2657 }
2658
2659 if (Ops.size() == 2) {
2660 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2661 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2662 // C1).
2663 const SCEV *A = Ops[0];
2664 const SCEV *B = Ops[1];
2665 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2666 auto *C = dyn_cast<SCEVConstant>(A);
2667 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2668 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2669 auto C2 = C->getAPInt();
2670 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2671
2672 APInt ConstAdd = C1 + C2;
2673 auto AddFlags = AddExpr->getNoWrapFlags();
2674 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 ConstAdd.ule(C1)) {
2677 PreservedFlags =
2679 }
2680
2681 // Adding a constant with the same sign and small magnitude is NSW, if the
2682 // original AddExpr was NSW.
2684 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2685 ConstAdd.abs().ule(C1.abs())) {
2686 PreservedFlags =
2688 }
2689
2690 if (PreservedFlags != SCEV::FlagAnyWrap) {
2691 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2692 NewOps[0] = getConstant(ConstAdd);
2693 return getAddExpr(NewOps, PreservedFlags);
2694 }
2695 }
2696 }
2697
2698 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2699 if (Ops.size() == 2) {
2700 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2701 if (Mul && Mul->getNumOperands() == 2 &&
2702 Mul->getOperand(0)->isAllOnesValue()) {
2703 const SCEV *X;
2704 const SCEV *Y;
2705 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2706 return getMulExpr(Y, getUDivExpr(X, Y));
2707 }
2708 }
2709 }
2710
2711 // Skip past any other cast SCEVs.
2712 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2713 ++Idx;
2714
2715 // If there are add operands they would be next.
2716 if (Idx < Ops.size()) {
2717 bool DeletedAdd = false;
2718 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2719 // common NUW flag for expression after inlining. Other flags cannot be
2720 // preserved, because they may depend on the original order of operations.
2721 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2722 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2723 if (Ops.size() > AddOpsInlineThreshold ||
2724 Add->getNumOperands() > AddOpsInlineThreshold)
2725 break;
2726 // If we have an add, expand the add operands onto the end of the operands
2727 // list.
2728 Ops.erase(Ops.begin()+Idx);
2729 append_range(Ops, Add->operands());
2730 DeletedAdd = true;
2731 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2732 }
2733
2734 // If we deleted at least one add, we added operands to the end of the list,
2735 // and they are not necessarily sorted. Recurse to resort and resimplify
2736 // any operands we just acquired.
2737 if (DeletedAdd)
2738 return getAddExpr(Ops, CommonFlags, Depth + 1);
2739 }
2740
2741 // Skip over the add expression until we get to a multiply.
2742 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2743 ++Idx;
2744
2745 // Check to see if there are any folding opportunities present with
2746 // operands multiplied by constant values.
2747 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 APInt AccumulatedConstant(BitWidth, 0);
2752 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2753 Ops, APInt(BitWidth, 1), *this)) {
2754 struct APIntCompare {
2755 bool operator()(const APInt &LHS, const APInt &RHS) const {
2756 return LHS.ult(RHS);
2757 }
2758 };
2759
2760 // Some interesting folding opportunity is present, so its worthwhile to
2761 // re-generate the operands list. Group the operands by constant scale,
2762 // to avoid multiplying by the same constant scale multiple times.
2763 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2764 for (const SCEV *NewOp : NewOps)
2765 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2766 // Re-generate the operands list.
2767 Ops.clear();
2768 if (AccumulatedConstant != 0)
2769 Ops.push_back(getConstant(AccumulatedConstant));
2770 for (auto &MulOp : MulOpLists) {
2771 if (MulOp.first == 1) {
2772 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2773 } else if (MulOp.first != 0) {
2775 getConstant(MulOp.first),
2776 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2777 SCEV::FlagAnyWrap, Depth + 1));
2778 }
2779 }
2780 if (Ops.empty())
2781 return getZero(Ty);
2782 if (Ops.size() == 1)
2783 return Ops[0];
2784 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2785 }
2786 }
2787
2788 // If we are adding something to a multiply expression, make sure the
2789 // something is not already an operand of the multiply. If so, merge it into
2790 // the multiply.
2791 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2792 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2793 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2794 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2795 if (isa<SCEVConstant>(MulOpSCEV))
2796 continue;
2797 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2798 if (MulOpSCEV == Ops[AddOp]) {
2799 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2800 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2801 if (Mul->getNumOperands() != 2) {
2802 // If the multiply has more than two operands, we must get the
2803 // Y*Z term.
2805 Mul->operands().take_front(MulOp));
2806 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2807 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2808 }
2809 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2810 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2811 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 if (Ops.size() == 2) return OuterMul;
2814 if (AddOp < Idx) {
2815 Ops.erase(Ops.begin()+AddOp);
2816 Ops.erase(Ops.begin()+Idx-1);
2817 } else {
2818 Ops.erase(Ops.begin()+Idx);
2819 Ops.erase(Ops.begin()+AddOp-1);
2820 }
2821 Ops.push_back(OuterMul);
2822 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2823 }
2824
2825 // Check this multiply against other multiplies being added together.
2826 for (unsigned OtherMulIdx = Idx+1;
2827 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2828 ++OtherMulIdx) {
2829 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2830 // If MulOp occurs in OtherMul, we can fold the two multiplies
2831 // together.
2832 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2833 OMulOp != e; ++OMulOp)
2834 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2835 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2836 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2837 if (Mul->getNumOperands() != 2) {
2839 Mul->operands().take_front(MulOp));
2840 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2841 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2842 }
2843 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2844 if (OtherMul->getNumOperands() != 2) {
2846 OtherMul->operands().take_front(OMulOp));
2847 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2848 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2849 }
2850 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2851 const SCEV *InnerMulSum =
2852 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2853 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 if (Ops.size() == 2) return OuterMul;
2856 Ops.erase(Ops.begin()+Idx);
2857 Ops.erase(Ops.begin()+OtherMulIdx-1);
2858 Ops.push_back(OuterMul);
2859 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2860 }
2861 }
2862 }
2863 }
2864
2865 // If there are any add recurrences in the operands list, see if any other
2866 // added values are loop invariant. If so, we can fold them into the
2867 // recurrence.
2868 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2869 ++Idx;
2870
2871 // Scan over all recurrences, trying to fold loop invariants into them.
2872 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2873 // Scan all of the other operands to this add and add them to the vector if
2874 // they are loop invariant w.r.t. the recurrence.
2876 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2877 const Loop *AddRecLoop = AddRec->getLoop();
2878 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2879 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2880 LIOps.push_back(Ops[i]);
2881 Ops.erase(Ops.begin()+i);
2882 --i; --e;
2883 }
2884
2885 // If we found some loop invariants, fold them into the recurrence.
2886 if (!LIOps.empty()) {
2887 // Compute nowrap flags for the addition of the loop-invariant ops and
2888 // the addrec. Temporarily push it as an operand for that purpose. These
2889 // flags are valid in the scope of the addrec only.
2890 LIOps.push_back(AddRec);
2891 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2892 LIOps.pop_back();
2893
2894 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2895 LIOps.push_back(AddRec->getStart());
2896
2897 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2898
2899 // It is not in general safe to propagate flags valid on an add within
2900 // the addrec scope to one outside it. We must prove that the inner
2901 // scope is guaranteed to execute if the outer one does to be able to
2902 // safely propagate. We know the program is undefined if poison is
2903 // produced on the inner scoped addrec. We also know that *for this use*
2904 // the outer scoped add can't overflow (because of the flags we just
2905 // computed for the inner scoped add) without the program being undefined.
2906 // Proving that entry to the outer scope necessitates entry to the inner
2907 // scope, thus proves the program undefined if the flags would be violated
2908 // in the outer scope.
2909 SCEV::NoWrapFlags AddFlags = Flags;
2910 if (AddFlags != SCEV::FlagAnyWrap) {
2911 auto *DefI = getDefiningScopeBound(LIOps);
2912 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2913 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2914 AddFlags = SCEV::FlagAnyWrap;
2915 }
2916 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2917
2918 // Build the new addrec. Propagate the NUW and NSW flags if both the
2919 // outer add and the inner addrec are guaranteed to have no overflow.
2920 // Always propagate NW.
2921 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2922 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2923
2924 // If all of the other operands were loop invariant, we are done.
2925 if (Ops.size() == 1) return NewRec;
2926
2927 // Otherwise, add the folded AddRec by the non-invariant parts.
2928 for (unsigned i = 0;; ++i)
2929 if (Ops[i] == AddRec) {
2930 Ops[i] = NewRec;
2931 break;
2932 }
2933 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2934 }
2935
2936 // Okay, if there weren't any loop invariants to be folded, check to see if
2937 // there are multiple AddRec's with the same loop induction variable being
2938 // added together. If so, we can fold them.
2939 for (unsigned OtherIdx = Idx+1;
2940 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2941 ++OtherIdx) {
2942 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2943 // so that the 1st found AddRecExpr is dominated by all others.
2944 assert(DT.dominates(
2945 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2946 AddRec->getLoop()->getHeader()) &&
2947 "AddRecExprs are not sorted in reverse dominance order?");
2948 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2949 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2950 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2951 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 ++OtherIdx) {
2953 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2954 if (OtherAddRec->getLoop() == AddRecLoop) {
2955 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2956 i != e; ++i) {
2957 if (i >= AddRecOps.size()) {
2958 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2959 break;
2960 }
2962 AddRecOps[i], OtherAddRec->getOperand(i)};
2963 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2964 }
2965 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2966 }
2967 }
2968 // Step size has changed, so we cannot guarantee no self-wraparound.
2969 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2970 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 }
2973
2974 // Otherwise couldn't fold anything into this recurrence. Move onto the
2975 // next one.
2976 }
2977
2978 // Okay, it looks like we really DO need an add expr. Check to see if we
2979 // already have one, otherwise create a new one.
2980 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2981}
2982
2983const SCEV *
2984ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2985 SCEV::NoWrapFlags Flags) {
2987 ID.AddInteger(scAddExpr);
2988 for (const SCEV *Op : Ops)
2989 ID.AddPointer(Op);
2990 void *IP = nullptr;
2991 SCEVAddExpr *S =
2992 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2993 if (!S) {
2994 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2995 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2996 S = new (SCEVAllocator)
2997 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2998 UniqueSCEVs.InsertNode(S, IP);
2999 registerUser(S, Ops);
3000 }
3001 S->setNoWrapFlags(Flags);
3002 return S;
3003}
3004
3005const SCEV *
3006ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3007 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 ID.AddInteger(scAddRecExpr);
3010 for (const SCEV *Op : Ops)
3011 ID.AddPointer(Op);
3012 ID.AddPointer(L);
3013 void *IP = nullptr;
3014 SCEVAddRecExpr *S =
3015 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3016 if (!S) {
3017 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3018 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3019 S = new (SCEVAllocator)
3020 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3021 UniqueSCEVs.InsertNode(S, IP);
3022 LoopUsers[L].push_back(S);
3023 registerUser(S, Ops);
3024 }
3025 setNoWrapFlags(S, Flags);
3026 return S;
3027}
3028
3029const SCEV *
3030ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3031 SCEV::NoWrapFlags Flags) {
3033 ID.AddInteger(scMulExpr);
3034 for (const SCEV *Op : Ops)
3035 ID.AddPointer(Op);
3036 void *IP = nullptr;
3037 SCEVMulExpr *S =
3038 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3039 if (!S) {
3040 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3041 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3042 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3043 O, Ops.size());
3044 UniqueSCEVs.InsertNode(S, IP);
3045 registerUser(S, Ops);
3046 }
3047 S->setNoWrapFlags(Flags);
3048 return S;
3049}
3050
/// Multiply \p i by \p j modulo 2^64. If the mathematical product does not
/// fit in 64 bits, \p Overflow is set to true; it is never cleared here, so
/// callers can accumulate the flag across several calls.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  // For j <= 1 the product cannot overflow (and j == 0 must not be divided
  // by); otherwise the wrapped product fails to divide back to i.
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient.  If an
/// intermediate computation overflows, Overflow will be set and the return will
/// be garbage. Overflow is not cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At iteration i we multiply by the i-th factor of the numerator,
  // n-(i-1), and divide by i. The running value after step i equals
  // "n choose i", so this division always produces an integral result, and
  // dividing early helps reduce the chance of overflow in the intermediate
  // computations.  However, we can still overflow even when the final result
  // would fit.

  // Choosing more elements than are available is impossible. This must be
  // checked before the trivial cases so that Choose(0, k) is 0 for k > 0.
  if (k > n) return 0;
  if (k == 0 || k == n) return 1;

  // Use the symmetry C(n, k) == C(n, n-k) to shorten the loop.
  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}
3082
3083/// Determine if any of the operands in this SCEV are a constant or if
3084/// any of the add or multiply expressions in this SCEV contain a constant.
3085static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3086 struct FindConstantInAddMulChain {
3087 bool FoundConstant = false;
3088
3089 bool follow(const SCEV *S) {
3090 FoundConstant |= isa<SCEVConstant>(S);
3091 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3092 }
3093
3094 bool isDone() const {
3095 return FoundConstant;
3096 }
3097 };
3098
3099 FindConstantInAddMulChain F;
3101 ST.visitAll(StartExpr);
3102 return F.FoundConstant;
3103}
3104
3105/// Get a canonical multiply expression, or something simpler if possible.
3107 SCEV::NoWrapFlags OrigFlags,
3108 unsigned Depth) {
3109 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3110 "only nuw or nsw allowed");
3111 assert(!Ops.empty() && "Cannot get empty mul!");
3112 if (Ops.size() == 1) return Ops[0];
3113#ifndef NDEBUG
3114 Type *ETy = Ops[0]->getType();
3115 assert(!ETy->isPointerTy());
3116 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3117 assert(Ops[i]->getType() == ETy &&
3118 "SCEVMulExpr operand types don't match!");
3119#endif
3120
3121 const SCEV *Folded = constantFoldAndGroupOps(
3122 *this, LI, DT, Ops,
3123 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3124 [](const APInt &C) { return C.isOne(); }, // identity
3125 [](const APInt &C) { return C.isZero(); }); // absorber
3126 if (Folded)
3127 return Folded;
3128
3129 // Delay expensive flag strengthening until necessary.
3130 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3131 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3132 };
3133
3134 // Limit recursion calls depth.
3136 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3137
3138 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3139 // Don't strengthen flags if we have no new information.
3140 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3141 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3142 Mul->setNoWrapFlags(ComputeFlags(Ops));
3143 return S;
3144 }
3145
3146 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3147 if (Ops.size() == 2) {
3148 // C1*(C2+V) -> C1*C2 + C1*V
3149 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3150 // If any of Add's ops are Adds or Muls with a constant, apply this
3151 // transformation as well.
3152 //
3153 // TODO: There are some cases where this transformation is not
3154 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3155 // this transformation should be narrowed down.
3156 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3157 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3159 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3161 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3162 }
3163
3164 if (Ops[0]->isAllOnesValue()) {
3165 // If we have a mul by -1 of an add, try distributing the -1 among the
3166 // add operands.
3167 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3169 bool AnyFolded = false;
3170 for (const SCEV *AddOp : Add->operands()) {
3171 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3172 Depth + 1);
3173 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3174 NewOps.push_back(Mul);
3175 }
3176 if (AnyFolded)
3177 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3178 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3179 // Negation preserves a recurrence's no self-wrap property.
3181 for (const SCEV *AddRecOp : AddRec->operands())
3182 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3183 Depth + 1));
3184 // Let M be the minimum representable signed value. AddRec with nsw
3185 // multiplied by -1 can have signed overflow if and only if it takes a
3186 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3187 // maximum signed value. In all other cases signed overflow is
3188 // impossible.
3189 auto FlagsMask = SCEV::FlagNW;
3190 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3191 auto MinInt =
3192 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3193 if (getSignedRangeMin(AddRec) != MinInt)
3194 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3195 }
3196 return getAddRecExpr(Operands, AddRec->getLoop(),
3197 AddRec->getNoWrapFlags(FlagsMask));
3198 }
3199 }
3200 }
3201 }
3202
3203 // Skip over the add expression until we get to a multiply.
3204 unsigned Idx = 0;
3205 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3206 ++Idx;
3207
3208 // If there are mul operands inline them all into this expression.
3209 if (Idx < Ops.size()) {
3210 bool DeletedMul = false;
3211 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3212 if (Ops.size() > MulOpsInlineThreshold)
3213 break;
3214 // If we have an mul, expand the mul operands onto the end of the
3215 // operands list.
3216 Ops.erase(Ops.begin()+Idx);
3217 append_range(Ops, Mul->operands());
3218 DeletedMul = true;
3219 }
3220
3221 // If we deleted at least one mul, we added operands to the end of the
3222 // list, and they are not necessarily sorted. Recurse to resort and
3223 // resimplify any operands we just acquired.
3224 if (DeletedMul)
3225 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3226 }
3227
3228 // If there are any add recurrences in the operands list, see if any other
3229 // added values are loop invariant. If so, we can fold them into the
3230 // recurrence.
3231 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3232 ++Idx;
3233
3234 // Scan over all recurrences, trying to fold loop invariants into them.
3235 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3236 // Scan all of the other operands to this mul and add them to the vector
3237 // if they are loop invariant w.r.t. the recurrence.
3239 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3240 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3241 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3242 LIOps.push_back(Ops[i]);
3243 Ops.erase(Ops.begin()+i);
3244 --i; --e;
3245 }
3246
3247 // If we found some loop invariants, fold them into the recurrence.
3248 if (!LIOps.empty()) {
3249 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3251 NewOps.reserve(AddRec->getNumOperands());
3252 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3253
3254 // If both the mul and addrec are nuw, we can preserve nuw.
3255 // If both the mul and addrec are nsw, we can only preserve nsw if either
3256 // a) they are also nuw, or
3257 // b) all multiplications of addrec operands with scale are nsw.
3258 SCEV::NoWrapFlags Flags =
3259 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3260
3261 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3262 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3263 SCEV::FlagAnyWrap, Depth + 1));
3264
3265 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3267 Instruction::Mul, getSignedRange(Scale),
3269 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3270 Flags = clearFlags(Flags, SCEV::FlagNSW);
3271 }
3272 }
3273
3274 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3275
3276 // If all of the other operands were loop invariant, we are done.
3277 if (Ops.size() == 1) return NewRec;
3278
3279 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3280 for (unsigned i = 0;; ++i)
3281 if (Ops[i] == AddRec) {
3282 Ops[i] = NewRec;
3283 break;
3284 }
3285 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3286 }
3287
3288 // Okay, if there weren't any loop invariants to be folded, check to see
3289 // if there are multiple AddRec's with the same loop induction variable
3290 // being multiplied together. If so, we can fold them.
3291
3292 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3293 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3294 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3295 // ]]],+,...up to x=2n}.
3296 // Note that the arguments to choose() are always integers with values
3297 // known at compile time, never SCEV objects.
3298 //
3299 // The implementation avoids pointless extra computations when the two
3300 // addrec's are of different length (mathematically, it's equivalent to
3301 // an infinite stream of zeros on the right).
3302 bool OpsModified = false;
3303 for (unsigned OtherIdx = Idx+1;
3304 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3305 ++OtherIdx) {
3306 const SCEVAddRecExpr *OtherAddRec =
3307 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3308 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3309 continue;
3310
3311 // Limit max number of arguments to avoid creation of unreasonably big
3312 // SCEVAddRecs with very complex operands.
3313 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3314 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3315 continue;
3316
3317 bool Overflow = false;
3318 Type *Ty = AddRec->getType();
3319 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3321 for (int x = 0, xe = AddRec->getNumOperands() +
3322 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3323 SmallVector <const SCEV *, 7> SumOps;
3324 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3325 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3326 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3327 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3328 z < ze && !Overflow; ++z) {
3329 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3330 uint64_t Coeff;
3331 if (LargerThan64Bits)
3332 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3333 else
3334 Coeff = Coeff1*Coeff2;
3335 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3336 const SCEV *Term1 = AddRec->getOperand(y-z);
3337 const SCEV *Term2 = OtherAddRec->getOperand(z);
3338 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3339 SCEV::FlagAnyWrap, Depth + 1));
3340 }
3341 }
3342 if (SumOps.empty())
3343 SumOps.push_back(getZero(Ty));
3344 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3345 }
3346 if (!Overflow) {
3347 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3349 if (Ops.size() == 2) return NewAddRec;
3350 Ops[Idx] = NewAddRec;
3351 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3352 OpsModified = true;
3353 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3354 if (!AddRec)
3355 break;
3356 }
3357 }
3358 if (OpsModified)
3359 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3360
3361 // Otherwise couldn't fold anything into this recurrence. Move onto the
3362 // next one.
3363 }
3364
3365 // Okay, it looks like we really DO need an mul expr. Check to see if we
3366 // already have one, otherwise create a new one.
3367 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3368}
3369
3370/// Represents an unsigned remainder expression based on unsigned division.
3372 const SCEV *RHS) {
3375 "SCEVURemExpr operand types don't match!");
3376
3377 // Short-circuit easy cases
3378 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3379 // If constant is one, the result is trivial
3380 if (RHSC->getValue()->isOne())
3381 return getZero(LHS->getType()); // X urem 1 --> 0
3382
3383 // If constant is a power of two, fold into a zext(trunc(LHS)).
3384 if (RHSC->getAPInt().isPowerOf2()) {
3385 Type *FullTy = LHS->getType();
3386 Type *TruncTy =
3387 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3388 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3389 }
3390 }
3391
3392 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3393 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3394 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3395 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3396}
3397
3398/// Get a canonical unsigned division expression, or something simpler if
3399/// possible.
3401 const SCEV *RHS) {
3402 assert(!LHS->getType()->isPointerTy() &&
3403 "SCEVUDivExpr operand can't be pointer!");
3404 assert(LHS->getType() == RHS->getType() &&
3405 "SCEVUDivExpr operand types don't match!");
3406
3408 ID.AddInteger(scUDivExpr);
3409 ID.AddPointer(LHS);
3410 ID.AddPointer(RHS);
3411 void *IP = nullptr;
3412 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3413 return S;
3414
3415 // 0 udiv Y == 0
3416 if (match(LHS, m_scev_Zero()))
3417 return LHS;
3418
3419 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3420 if (RHSC->getValue()->isOne())
3421 return LHS; // X udiv 1 --> x
3422 // If the denominator is zero, the result of the udiv is undefined. Don't
3423 // try to analyze it, because the resolution chosen here may differ from
3424 // the resolution chosen in other parts of the compiler.
3425 if (!RHSC->getValue()->isZero()) {
3426 // Determine if the division can be folded into the operands of
3427 // its operands.
3428 // TODO: Generalize this to non-constants by using known-bits information.
3429 Type *Ty = LHS->getType();
3430 unsigned LZ = RHSC->getAPInt().countl_zero();
3431 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3432 // For non-power-of-two values, effectively round the value up to the
3433 // nearest power of two.
3434 if (!RHSC->getAPInt().isPowerOf2())
3435 ++MaxShiftAmt;
3436 IntegerType *ExtTy =
3437 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3438 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3439 if (const SCEVConstant *Step =
3440 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3441 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3442 const APInt &StepInt = Step->getAPInt();
3443 const APInt &DivInt = RHSC->getAPInt();
3444 if (!StepInt.urem(DivInt) &&
3445 getZeroExtendExpr(AR, ExtTy) ==
3446 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3447 getZeroExtendExpr(Step, ExtTy),
3448 AR->getLoop(), SCEV::FlagAnyWrap)) {
3450 for (const SCEV *Op : AR->operands())
3451 Operands.push_back(getUDivExpr(Op, RHS));
3452 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3453 }
3454 /// Get a canonical UDivExpr for a recurrence.
3455 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3456 // We can currently only fold X%N if X is constant.
3457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3458 if (StartC && !DivInt.urem(StepInt) &&
3459 getZeroExtendExpr(AR, ExtTy) ==
3460 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3461 getZeroExtendExpr(Step, ExtTy),
3462 AR->getLoop(), SCEV::FlagAnyWrap)) {
3463 const APInt &StartInt = StartC->getAPInt();
3464 const APInt &StartRem = StartInt.urem(StepInt);
3465 if (StartRem != 0) {
3466 const SCEV *NewLHS =
3467 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3468 AR->getLoop(), SCEV::FlagNW);
3469 if (LHS != NewLHS) {
3470 LHS = NewLHS;
3471
3472 // Reset the ID to include the new LHS, and check if it is
3473 // already cached.
3474 ID.clear();
3475 ID.AddInteger(scUDivExpr);
3476 ID.AddPointer(LHS);
3477 ID.AddPointer(RHS);
3478 IP = nullptr;
3479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3480 return S;
3481 }
3482 }
3483 }
3484 }
3485 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3486 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3488 for (const SCEV *Op : M->operands())
3489 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3490 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3491 // Find an operand that's safely divisible.
3492 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3493 const SCEV *Op = M->getOperand(i);
3494 const SCEV *Div = getUDivExpr(Op, RHSC);
3495 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3496 Operands = SmallVector<const SCEV *, 4>(M->operands());
3497 Operands[i] = Div;
3498 return getMulExpr(Operands);
3499 }
3500 }
3501 }
3502
3503 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3504 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3505 if (auto *DivisorConstant =
3506 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3507 bool Overflow = false;
3508 APInt NewRHS =
3509 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3510 if (Overflow) {
3511 return getConstant(RHSC->getType(), 0, false);
3512 }
3513 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3514 }
3515 }
3516
3517 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3518 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3520 for (const SCEV *Op : A->operands())
3521 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3522 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3523 Operands.clear();
3524 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3525 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3526 if (isa<SCEVUDivExpr>(Op) ||
3527 getMulExpr(Op, RHS) != A->getOperand(i))
3528 break;
3529 Operands.push_back(Op);
3530 }
3531 if (Operands.size() == A->getNumOperands())
3532 return getAddExpr(Operands);
3533 }
3534 }
3535
3536 // Fold if both operands are constant.
3537 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3538 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3539 }
3540 }
3541
3542 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3543 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3544 AE && AE->getNumOperands() == 2) {
3545 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3546 const APInt &NegC = VC->getAPInt();
3547 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3548 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3549 if (MME && MME->getNumOperands() == 2 &&
3550 isa<SCEVConstant>(MME->getOperand(0)) &&
3551 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3552 MME->getOperand(1) == RHS)
3553 return getZero(LHS->getType());
3554 }
3555 }
3556 }
3557
3558 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3559 // changes). Make sure we get a new one.
3560 IP = nullptr;
3561 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3562 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3563 LHS, RHS);
3564 UniqueSCEVs.InsertNode(S, IP);
3565 registerUser(S, {LHS, RHS});
3566 return S;
3567}
3568
3569APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3570 APInt A = C1->getAPInt().abs();
3571 APInt B = C2->getAPInt().abs();
3572 uint32_t ABW = A.getBitWidth();
3573 uint32_t BBW = B.getBitWidth();
3574
3575 if (ABW > BBW)
3576 B = B.zext(ABW);
3577 else if (ABW < BBW)
3578 A = A.zext(BBW);
3579
3580 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3581}
3582
3583/// Get a canonical unsigned division expression, or something simpler if
3584/// possible. There is no representation for an exact udiv in SCEV IR, but we
3585/// can attempt to remove factors from the LHS and RHS. We can't do this when
3586/// it's not exact because the udiv may be clearing bits.
3588 const SCEV *RHS) {
3589 // TODO: we could try to find factors in all sorts of things, but for now we
3590 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3591 // end of this file for inspiration.
3592
3593 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3594 if (!Mul || !Mul->hasNoUnsignedWrap())
3595 return getUDivExpr(LHS, RHS);
3596
3597 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3598 // If the mulexpr multiplies by a constant, then that constant must be the
3599 // first element of the mulexpr.
3600 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3601 if (LHSCst == RHSCst) {
3603 return getMulExpr(Operands);
3604 }
3605
3606 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3607 // that there's a factor provided by one of the other terms. We need to
3608 // check.
3609 APInt Factor = gcd(LHSCst, RHSCst);
3610 if (!Factor.isIntN(1)) {
3611 LHSCst =
3612 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3613 RHSCst =
3614 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3616 Operands.push_back(LHSCst);
3617 append_range(Operands, Mul->operands().drop_front());
3619 RHS = RHSCst;
3620 Mul = dyn_cast<SCEVMulExpr>(LHS);
3621 if (!Mul)
3622 return getUDivExactExpr(LHS, RHS);
3623 }
3624 }
3625 }
3626
3627 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3628 if (Mul->getOperand(i) == RHS) {
3630 append_range(Operands, Mul->operands().take_front(i));
3631 append_range(Operands, Mul->operands().drop_front(i + 1));
3632 return getMulExpr(Operands);
3633 }
3634 }
3635
3636 return getUDivExpr(LHS, RHS);
3637}
3638
3639/// Get an add recurrence expression for the specified loop. Simplify the
3640/// expression as much as possible.
///
/// This is the (Start, Step) convenience form: it assembles an operand list
/// and forwards to the getAddRecExpr overload that takes the list.
3641const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3642 const Loop *L,
3643 SCEV::NoWrapFlags Flags) {
// NOTE(review): the declaration of `Operands` (listing line 3644) is missing
// from this excerpt; restore it from upstream before building.
3645 Operands.push_back(Start);
// If Step is itself an add recurrence on the same loop L, append all of its
// operands after Start so a single recurrence is built. In that case Flags
// is masked with SCEV::FlagNW before being passed on.
3646 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3647 if (StepChrec->getLoop() == L) {
3648 append_range(Operands, StepChrec->operands());
3649 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3650 }
3651
// Otherwise the operand list is simply {Start, Step} with Flags unchanged.
3652 Operands.push_back(Step);
3653 return getAddRecExpr(Operands, L, Flags);
3654}
3655
3656/// Get an add recurrence expression for the specified loop. Simplify the
3657/// expression as much as possible.
3658const SCEV *
3660 const Loop *L, SCEV::NoWrapFlags Flags) {
3661 if (Operands.size() == 1) return Operands[0];
3662#ifndef NDEBUG
3664 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3665 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3666 "SCEVAddRecExpr operand types don't match!");
3667 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3668 }
3669 for (const SCEV *Op : Operands)
3671 "SCEVAddRecExpr operand is not available at loop entry!");
3672#endif
3673
3674 if (Operands.back()->isZero()) {
3675 Operands.pop_back();
3676 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3677 }
3678
3679 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3680 // use that information to infer NUW and NSW flags. However, computing a
3681 // BE count requires calling getAddRecExpr, so we may not yet have a
3682 // meaningful BE count at this point (and if we don't, we'd be stuck
3683 // with a SCEVCouldNotCompute as the cached BE count).
3684
3685 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3686
3687 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3688 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3689 const Loop *NestedLoop = NestedAR->getLoop();
3690 if (L->contains(NestedLoop)
3691 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3692 : (!NestedLoop->contains(L) &&
3693 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3694 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3695 Operands[0] = NestedAR->getStart();
3696 // AddRecs require their operands be loop-invariant with respect to their
3697 // loops. Don't perform this transformation if it would break this
3698 // requirement.
3699 bool AllInvariant = all_of(
3700 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3701
3702 if (AllInvariant) {
3703 // Create a recurrence for the outer loop with the same step size.
3704 //
3705 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3706 // inner recurrence has the same property.
3707 SCEV::NoWrapFlags OuterFlags =
3708 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3709
3710 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3711 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3712 return isLoopInvariant(Op, NestedLoop);
3713 });
3714
3715 if (AllInvariant) {
3716 // Ok, both add recurrences are valid after the transformation.
3717 //
3718 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3719 // the outer recurrence has the same property.
3720 SCEV::NoWrapFlags InnerFlags =
3721 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3722 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3723 }
3724 }
3725 // Reset Operands to its original state.
3726 Operands[0] = NestedAR;
3727 }
3728 }
3729
3730 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3731 // already have one, otherwise create a new one.
3732 return getOrCreateAddRecExpr(Operands, L, Flags);
3733}
3734
3735const SCEV *
3737 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3738 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3739 // getSCEV(Base)->getType() has the same address space as Base->getType()
3740 // because SCEV::getType() preserves the address space.
3741 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3742 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3743 if (NW != GEPNoWrapFlags::none()) {
3744 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3745 // but to do that, we have to ensure that said flag is valid in the entire
3746 // defined scope of the SCEV.
3747 // TODO: non-instructions have global scope. We might be able to prove
3748 // some global scope cases
3749 auto *GEPI = dyn_cast<Instruction>(GEP);
3750 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3751 NW = GEPNoWrapFlags::none();
3752 }
3753
3755 if (NW.hasNoUnsignedSignedWrap())
3756 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3757 if (NW.hasNoUnsignedWrap())
3758 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3759
3760 Type *CurTy = GEP->getType();
3761 bool FirstIter = true;
3763 for (const SCEV *IndexExpr : IndexExprs) {
3764 // Compute the (potentially symbolic) offset in bytes for this index.
3765 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3766 // For a struct, add the member offset.
3767 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3768 unsigned FieldNo = Index->getZExtValue();
3769 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3770 Offsets.push_back(FieldOffset);
3771
3772 // Update CurTy to the type of the field at Index.
3773 CurTy = STy->getTypeAtIndex(Index);
3774 } else {
3775 // Update CurTy to its element type.
3776 if (FirstIter) {
3777 assert(isa<PointerType>(CurTy) &&
3778 "The first index of a GEP indexes a pointer");
3779 CurTy = GEP->getSourceElementType();
3780 FirstIter = false;
3781 } else {
3783 }
3784 // For an array, add the element offset, explicitly scaled.
3785 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3786 // Getelementptr indices are signed.
3787 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3788
3789 // Multiply the index by the element size to compute the element offset.
3790 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3791 Offsets.push_back(LocalOffset);
3792 }
3793 }
3794
3795 // Handle degenerate case of GEP without offsets.
3796 if (Offsets.empty())
3797 return BaseExpr;
3798
3799 // Add the offsets together, assuming nsw if inbounds.
3800 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3801 // Add the base address and the offset. We cannot use the nsw flag, as the
3802 // base address is unsigned. However, if we know that the offset is
3803 // non-negative, we can use nuw.
3804 bool NUW = NW.hasNoUnsignedWrap() ||
3807 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3808 assert(BaseExpr->getType() == GEPExpr->getType() &&
3809 "GEP should not change type mid-flight.");
3810 return GEPExpr;
3811}
3812
/// Look up UniqueSCEVs for an already-created expression.
///
/// The lookup key is built from \p SCEVType followed by the pointer of each
/// operand in Ops, in order. The result of FindNodeOrInsertPos is returned
/// directly; the insertion position it reports is discarded.
// NOTE(review): listing lines 3814-3815 (the rest of the parameter list,
// which declares `Ops`, and the declaration of `ID`) are missing from this
// excerpt; restore them from upstream before building.
3813SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3816 ID.AddInteger(SCEVType);
3817 for (const SCEV *Op : Ops)
3818 ID.AddPointer(Op);
3819 void *IP = nullptr;
3820 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3821}
3822
/// Build an absolute-value expression for \p Op as the signed maximum of
/// \p Op and getNegativeSCEV(Op, Flags).
// NOTE(review): `Flags` is defined on listing line 3824, which is missing from
// this excerpt; presumably it is derived from IsNSW -- confirm against
// upstream.
3823const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3825 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3826}
3827
3830 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3831 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3832 if (Ops.size() == 1) return Ops[0];
3833#ifndef NDEBUG
3834 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3835 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3836 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3837 "Operand types don't match!");
3838 assert(Ops[0]->getType()->isPointerTy() ==
3839 Ops[i]->getType()->isPointerTy() &&
3840 "min/max should be consistently pointerish");
3841 }
3842#endif
3843
3844 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3845 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3846
3847 const SCEV *Folded = constantFoldAndGroupOps(
3848 *this, LI, DT, Ops,
3849 [&](const APInt &C1, const APInt &C2) {
3850 switch (Kind) {
3851 case scSMaxExpr:
3852 return APIntOps::smax(C1, C2);
3853 case scSMinExpr:
3854 return APIntOps::smin(C1, C2);
3855 case scUMaxExpr:
3856 return APIntOps::umax(C1, C2);
3857 case scUMinExpr:
3858 return APIntOps::umin(C1, C2);
3859 default:
3860 llvm_unreachable("Unknown SCEV min/max opcode");
3861 }
3862 },
3863 [&](const APInt &C) {
3864 // identity
3865 if (IsMax)
3866 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3867 else
3868 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3869 },
3870 [&](const APInt &C) {
3871 // absorber
3872 if (IsMax)
3873 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3874 else
3875 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3876 });
3877 if (Folded)
3878 return Folded;
3879
3880 // Check if we have created the same expression before.
3881 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3882 return S;
3883 }
3884
3885 // Find the first operation of the same kind
3886 unsigned Idx = 0;
3887 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3888 ++Idx;
3889
3890 // Check to see if one of the operands is of the same kind. If so, expand its
3891 // operands onto our operand list, and recurse to simplify.
3892 if (Idx < Ops.size()) {
3893 bool DeletedAny = false;
3894 while (Ops[Idx]->getSCEVType() == Kind) {
3895 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3896 Ops.erase(Ops.begin()+Idx);
3897 append_range(Ops, SMME->operands());
3898 DeletedAny = true;
3899 }
3900
3901 if (DeletedAny)
3902 return getMinMaxExpr(Kind, Ops);
3903 }
3904
3905 // Okay, check to see if the same value occurs in the operand list twice. If
3906 // so, delete one. Since we sorted the list, these values are required to
3907 // be adjacent.
3912 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3913 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3914 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3915 if (Ops[i] == Ops[i + 1] ||
3916 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3917 // X op Y op Y --> X op Y
3918 // X op Y --> X, if we know X, Y are ordered appropriately
3919 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3920 --i;
3921 --e;
3922 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3923 Ops[i + 1])) {
3924 // X op Y --> Y, if we know X, Y are ordered appropriately
3925 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3926 --i;
3927 --e;
3928 }
3929 }
3930
3931 if (Ops.size() == 1) return Ops[0];
3932
3933 assert(!Ops.empty() && "Reduced smax down to nothing!");
3934
3935 // Okay, it looks like we really DO need an expr. Check to see if we
3936 // already have one, otherwise create a new one.
3938 ID.AddInteger(Kind);
3939 for (const SCEV *Op : Ops)
3940 ID.AddPointer(Op);
3941 void *IP = nullptr;
3942 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3943 if (ExistingSCEV)
3944 return ExistingSCEV;
3945 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3946 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3947 SCEV *S = new (SCEVAllocator)
3948 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3949
3950 UniqueSCEVs.InsertNode(S, IP);
3951 registerUser(S, Ops);
3952 return S;
3953}
3954
3955namespace {
3956
/// Visitor that removes repeated operands while walking a sequential min/max
/// expression. Each visit method returns the expression to keep (possibly a
/// rebuilt one), or std::nullopt when the visited expression should be
/// dropped: either it was already seen, or all of its operands were dropped.
// NOTE(review): several listing lines are missing from this excerpt (3961,
// 3966, 3983, 4012, 4014). They presumably declare `Base`, `SeenOps`, the
// local `NewOps`, the tail of the public visit() signature, and the local
// `Ops` -- restore them from upstream before building.
3957class SCEVSequentialMinMaxDeduplicatingVisitor final
3958 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3959 std::optional<const SCEV *>> {
3960 using RetVal = std::optional<const SCEV *>;
3962
3963 ScalarEvolution &SE;
3964 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3965 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3967
// True when Kind is the root kind or its non-sequential counterpart.
3968 bool canRecurseInto(SCEVTypes Kind) const {
3969 // We can only recurse into the SCEV expression of the same effective type
3970 // as the type of our root SCEV expression.
3971 return RootKind == Kind || NonSequentialRootKind == Kind;
3972 };
3973
// Shared handler for every min/max flavour. Expressions of an unrelated kind
// are returned untouched. Otherwise the operands are visited; if nothing
// changed S is returned as-is, if every operand was dropped the result is
// std::nullopt, and otherwise the expression is rebuilt from the surviving
// operands using the same kind (sequential or not) as S.
3974 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3975 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3976 "Only for min/max expressions.");
3977 SCEVTypes Kind = S->getSCEVType();
3978
3979 if (!canRecurseInto(Kind))
3980 return S;
3981
3982 auto *NAry = cast<SCEVNAryExpr>(S);
3984 bool Changed = visit(Kind, NAry->operands(), NewOps);
3985
3986 if (!Changed)
3987 return S;
3988 if (NewOps.empty())
3989 return std::nullopt;
3990
3991 return isa<SCEVSequentialMinMaxExpr>(S)
3992 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3993 : SE.getMinMaxExpr(Kind, NewOps);
3994 }
3995
// Per-expression entry point: records S in SeenOps and returns std::nullopt
// if it was already present; otherwise dispatches to the kind-specific
// visit method through the base visitor.
3996 RetVal visit(const SCEV *S) {
3997 // Has the whole operand been seen already?
3998 if (!SeenOps.insert(S).second)
3999 return std::nullopt;
4000 return Base::visit(S);
4001 }
4002
4003public:
// Construct a visitor for the given root kind; the matching non-sequential
// kind is looked up once here.
4004 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4005 SCEVTypes RootKind)
4006 : SE(SE), RootKind(RootKind),
4007 NonSequentialRootKind(
4008 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4009 RootKind)) {}
4010
// Visit every operand in OrigOps in order, collecting the survivors.
// An operand counts as changed when its visit result is std::nullopt or a
// different expression. NewOps is only overwritten when something changed.
// Returns whether anything changed.
4011 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4013 bool Changed = false;
4015 Ops.reserve(OrigOps.size());
4016
4017 for (const SCEV *Op : OrigOps) {
4018 RetVal NewOp = visit(Op);
4019 if (NewOp != Op)
4020 Changed = true;
4021 if (NewOp)
4022 Ops.emplace_back(*NewOp);
4023 }
4024
4025 if (Changed)
4026 NewOps = std::move(Ops);
4027 return Changed;
4028 }
4029
// Everything that is not a min/max expression is returned unchanged.
4030 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4031
4032 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4033
4034 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4035
4036 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4037
4038 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4039
4040 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4041
4042 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4043
4044 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4045
4046 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4047
4048 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4049
// All min/max flavours share visitAnyMinMaxExpr.
4050 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4051 return visitAnyMinMaxExpr(Expr);
4052 }
4053
4054 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4055 return visitAnyMinMaxExpr(Expr);
4056 }
4057
4058 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4059 return visitAnyMinMaxExpr(Expr);
4060 }
4061
4062 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4063 return visitAnyMinMaxExpr(Expr);
4064 }
4065
4066 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4067 return visitAnyMinMaxExpr(Expr);
4068 }
4069
4070 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4071
4072 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4073};
4074
4075} // namespace
4076
4078 switch (Kind) {
4079 case scConstant:
4080 case scVScale:
4081 case scTruncate:
4082 case scZeroExtend:
4083 case scSignExtend:
4084 case scPtrToInt:
4085 case scAddExpr:
4086 case scMulExpr:
4087 case scUDivExpr:
4088 case scAddRecExpr:
4089 case scUMaxExpr:
4090 case scSMaxExpr:
4091 case scUMinExpr:
4092 case scSMinExpr:
4093 case scUnknown:
4094 // If any operand is poison, the whole expression is poison.
4095 return true;
4097 // FIXME: if the *first* operand is poison, the whole expression is poison.
4098 return false; // Pessimistically, say that it does not propagate poison.
4099 case scCouldNotCompute:
4100 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4101 }
4102 llvm_unreachable("Unknown SCEV kind!");
4103}
4104
4105namespace {
4106// The only way poison may be introduced in a SCEV expression is from a
4107// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4108// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4109// introduce poison -- they encode guaranteed, non-speculated knowledge.
4110//
4111// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4112// with the notable exception of umin_seq, where only poison from the first
4113// operand is (unconditionally) propagated.
//
// SCEVPoisonCollector is a traversal callback (it is passed to visitAll) that
// records in MaybePoison every SCEVUnknown whose underlying value is not
// guaranteed to be non-poison.
// NOTE(review): listing lines 4116 (presumably the MaybePoison member) and
// 4122 (the second half of the condition in follow()) are missing from this
// excerpt; restore them from upstream before building.
4114struct SCEVPoisonCollector {
// When false, follow() may refuse some expressions (see the truncated
// condition below); when true, that check is skipped entirely.
4115 bool LookThroughMaybePoisonBlocking;
4117 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4118 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4119
// Called for each visited expression. Records possibly-poison SCEVUnknowns.
// The boolean result presumably tells the traversal whether to continue
// into S's operands -- confirm against the visitAll contract.
4120 bool follow(const SCEV *S) {
4121 if (!LookThroughMaybePoisonBlocking &&
4123 return false;
4124
4125 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4126 if (!isGuaranteedNotToBePoison(SU->getValue()))
4127 MaybePoison.insert(SU);
4128 }
4129 return true;
4130 }
// Always false: this collector never reports that it is finished.
4131 bool isDone() const { return false; }
4132};
4133} // namespace
4134
4135/// Return true if V is poison given that AssumedPoison is already poison.
4136static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4137 // First collect all SCEVs that might result in AssumedPoison to be poison.
4138 // We need to look through potentially poison-blocking operations here,
4139 // because we want to find all SCEVs that *might* result in poison, not only
4140 // those that are *required* to.
4141 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4142 visitAll(AssumedPoison, PC1);
4143
4144 // AssumedPoison is never poison. As the assumption is false, the implication
4145 // is true. Don't bother walking the other SCEV in this case.
4146 if (PC1.MaybePoison.empty())
4147 return true;
4148
4149 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4150 // as well. We cannot look through potentially poison-blocking operations
4151 // here, as their arguments only *may* make the result poison.
4152 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4153 visitAll(S, PC2);
4154
4155 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4156 // it will also make S poison by being part of PC2.MaybePoison.
4157 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4158}
4159
4161 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4162 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4163 visitAll(S, PC);
4164 for (const SCEVUnknown *SU : PC.MaybePoison)
4165 Result.insert(SU->getValue());
4166}
4167
4169 const SCEV *S, Instruction *I,
4170 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4171 // If the instruction cannot be poison, it's always safe to reuse.
4173 return true;
4174
4175 // Otherwise, it is possible that I is more poisonous than S. Collect the
4176 // poison-contributors of S, and then check whether I has any additional
4177 // poison-contributors. Poison that is contributed through poison-generating
4178 // flags is handled by dropping those flags instead.
4180 getPoisonGeneratingValues(PoisonVals, S);
4181
4182 SmallVector<Value *> Worklist;
4184 Worklist.push_back(I);
4185 while (!Worklist.empty()) {
4186 Value *V = Worklist.pop_back_val();
4187 if (!Visited.insert(V).second)
4188 continue;
4189
4190 // Avoid walking large instruction graphs.
4191 if (Visited.size() > 16)
4192 return false;
4193
4194 // Either the value can't be poison, or the S would also be poison if it
4195 // is.
4196 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4197 continue;
4198
4199 auto *I = dyn_cast<Instruction>(V);
4200 if (!I)
4201 return false;
4202
4203 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4204 // can't replace an arbitrary add with disjoint or, even if we drop the
4205 // flag. We would need to convert the or into an add.
4206 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4207 if (PDI->isDisjoint())
4208 return false;
4209
4210 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4211 // because SCEV currently assumes it can't be poison. Remove this special
4212 // case once we properly model when vscale can be poison.
4213 if (auto *II = dyn_cast<IntrinsicInst>(I);
4214 II && II->getIntrinsicID() == Intrinsic::vscale)
4215 continue;
4216
4217 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4218 return false;
4219
4220 // If the instruction can't create poison, we can recurse to its operands.
4221 if (I->hasPoisonGeneratingAnnotations())
4222 DropPoisonGeneratingInsts.push_back(I);
4223
4224 for (Value *Op : I->operands())
4225 Worklist.push_back(Op);
4226 }
4227 return true;
4228}
4229
4230const SCEV *
4233 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4234 "Not a SCEVSequentialMinMaxExpr!");
4235 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4236 if (Ops.size() == 1)
4237 return Ops[0];
4238#ifndef NDEBUG
4239 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4240 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4241 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4242 "Operand types don't match!");
4243 assert(Ops[0]->getType()->isPointerTy() ==
4244 Ops[i]->getType()->isPointerTy() &&
4245 "min/max should be consistently pointerish");
4246 }
4247#endif
4248
4249 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4250 // so we can *NOT* do any kind of sorting of the expressions!
4251
4252 // Check if we have created the same expression before.
4253 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4254 return S;
4255
4256 // FIXME: there are *some* simplifications that we can do here.
4257
4258 // Keep only the first instance of an operand.
4259 {
4260 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4261 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4262 if (Changed)
4263 return getSequentialMinMaxExpr(Kind, Ops);
4264 }
4265
4266 // Check to see if one of the operands is of the same kind. If so, expand its
4267 // operands onto our operand list, and recurse to simplify.
4268 {
4269 unsigned Idx = 0;
4270 bool DeletedAny = false;
4271 while (Idx < Ops.size()) {
4272 if (Ops[Idx]->getSCEVType() != Kind) {
4273 ++Idx;
4274 continue;
4275 }
4276 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4277 Ops.erase(Ops.begin() + Idx);
4278 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4279 SMME->operands().end());
4280 DeletedAny = true;
4281 }
4282
4283 if (DeletedAny)
4284 return getSequentialMinMaxExpr(Kind, Ops);
4285 }
4286
4287 const SCEV *SaturationPoint;
4289 switch (Kind) {
4291 SaturationPoint = getZero(Ops[0]->getType());
4292 Pred = ICmpInst::ICMP_ULE;
4293 break;
4294 default:
4295 llvm_unreachable("Not a sequential min/max type.");
4296 }
4297
4298 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4299 if (!isGuaranteedNotToCauseUB(Ops[i]))
4300 continue;
4301 // We can replace %x umin_seq %y with %x umin %y if either:
4302 // * %y being poison implies %x is also poison.
4303 // * %x cannot be the saturating value (e.g. zero for umin).
4304 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4305 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4306 SaturationPoint)) {
4307 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4308 Ops[i - 1] = getMinMaxExpr(
4310 SeqOps);
4311 Ops.erase(Ops.begin() + i);
4312 return getSequentialMinMaxExpr(Kind, Ops);
4313 }
4314 // Fold %x umin_seq %y to %x if %x ule %y.
4315 // TODO: We might be able to prove the predicate for a later operand.
4316 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4317 Ops.erase(Ops.begin() + i);
4318 return getSequentialMinMaxExpr(Kind, Ops);
4319 }
4320 }
4321
4322 // Okay, it looks like we really DO need an expr. Check to see if we
4323 // already have one, otherwise create a new one.
4325 ID.AddInteger(Kind);
4326 for (const SCEV *Op : Ops)
4327 ID.AddPointer(Op);
4328 void *IP = nullptr;
4329 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4330 if (ExistingSCEV)
4331 return ExistingSCEV;
4332
4333 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4334 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4335 SCEV *S = new (SCEVAllocator)
4336 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4337
4338 UniqueSCEVs.InsertNode(S, IP);
4339 registerUser(S, Ops);
4340 return S;
4341}
4342
/// Two-operand convenience form of getSMaxExpr; forwards to the overload that
/// takes an operand list.
// NOTE(review): `Ops` is declared on listing line 4344, which is missing from
// this excerpt; presumably it holds {LHS, RHS} -- confirm against upstream.
4343const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4345 return getSMaxExpr(Ops);
4346}
4347
4349 return getMinMaxExpr(scSMaxExpr, Ops);
4350}
4351
/// Two-operand convenience form of getUMaxExpr; forwards to the overload that
/// takes an operand list.
// NOTE(review): `Ops` is declared on listing line 4353, which is missing from
// this excerpt; presumably it holds {LHS, RHS} -- confirm against upstream.
4352const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4354 return getUMaxExpr(Ops);
4355}
4356
4358 return getMinMaxExpr(scUMaxExpr, Ops);
4359}
4360
4362 const SCEV *RHS) {
4364 return getSMinExpr(Ops);
4365}
4366
4368 return getMinMaxExpr(scSMinExpr, Ops);
4369}
4370
/// Two-operand convenience form of getUMinExpr; forwards to the overload that
/// takes an operand list, passing \p Sequential through unchanged.
// NOTE(review): `Ops` is declared on listing line 4373, which is missing from
// this excerpt; presumably it holds {LHS, RHS} -- confirm against upstream.
4371const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4372 bool Sequential) {
4374 return getUMinExpr(Ops, Sequential);
4375}
4376
4378 bool Sequential) {
4379 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4380 : getMinMaxExpr(scUMinExpr, Ops);
4381}
4382
4383const SCEV *
4385 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4386 if (Size.isScalable())
4387 Res = getMulExpr(Res, getVScale(IntTy));
4388 return Res;
4389}
4390
4392 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4393}
4394
4396 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4397}
4398
4400 StructType *STy,
4401 unsigned FieldNo) {
4402 // We can bypass creating a target-independent constant expression and then
4403 // folding it back into a ConstantInt. This is just a compile-time
4404 // optimization.
4405 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4406 assert(!SL->getSizeInBits().isScalable() &&
4407 "Cannot get offset for structure containing scalable vector types");
4408 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4409}
4410
4412 // Don't attempt to do anything other than create a SCEVUnknown object
4413 // here. createSCEV only calls getUnknown after checking for all other
4414 // interesting possibilities, and any other code that calls getUnknown
4415 // is doing so in order to hide a value from SCEV canonicalization.
4416
4418 ID.AddInteger(scUnknown);
4419 ID.AddPointer(V);
4420 void *IP = nullptr;
4421 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4422 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4423 "Stale SCEVUnknown in uniquing map!");
4424 return S;
4425 }
4426 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4427 FirstUnknown);
4428 FirstUnknown = cast<SCEVUnknown>(S);
4429 UniqueSCEVs.InsertNode(S, IP);
4430 return S;
4431}
4432
4433//===----------------------------------------------------------------------===//
4434// Basic SCEV Analysis and PHI Idiom Recognition Code
4435//
4436
4437/// Test if values of the given type are analyzable within the SCEV
4438/// framework. This primarily includes integer types, and it can optionally
4439/// include pointer types if the ScalarEvolution class has access to
4440/// target-specific information.
4442 // Integers and pointers are always SCEVable.
4443 return Ty->isIntOrPtrTy();
4444}
4445
4446/// Return the size in bits of the specified type, for which isSCEVable must
4447/// return true.
4449 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4450 if (Ty->isPointerTy())
4452 return getDataLayout().getTypeSizeInBits(Ty);
4453}
4454
4455/// Return a type with the same bitwidth as the given type and which represents
4456/// how SCEV will treat the given type, for which isSCEVable must return
4457/// true. For pointer types, this is the pointer index sized integer type.
4459 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4460
4461 if (Ty->isIntegerTy())
4462 return Ty;
4463
4464 // The only other support type is pointer.
4465 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4466 return getDataLayout().getIndexType(Ty);
4467}
4468
4470 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4471}
4472
4474 const SCEV *B) {
4475 /// For a valid use point to exist, the defining scope of one operand
4476 /// must dominate the other.
4477 bool PreciseA, PreciseB;
4478 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4479 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4480 if (!PreciseA || !PreciseB)
4481 // Can't tell.
4482 return false;
4483 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4484 DT.dominates(ScopeB, ScopeA);
4485}
4486
4488 return CouldNotCompute.get();
4489}
4490
4491bool ScalarEvolution::checkValidity(const SCEV *S) const {
4492 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4493 auto *SU = dyn_cast<SCEVUnknown>(S);
4494 return SU && SU->getValue() == nullptr;
4495 });
4496
4497 return !ContainsNulls;
4498}
4499
4501 HasRecMapType::iterator I = HasRecMap.find(S);
4502 if (I != HasRecMap.end())
4503 return I->second;
4504
4505 bool FoundAddRec =
4506 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4507 HasRecMap.insert({S, FoundAddRec});
4508 return FoundAddRec;
4509}
4510
4511/// Return the ValueOffsetPair set for \p S. \p S can be represented
4512/// by the value and offset from any ValueOffsetPair in the set.
4513ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4514 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4515 if (SI == ExprValueMap.end())
4516 return {};
4517 return SI->second.getArrayRef();
4518}
4519
4520/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4521/// cannot be used separately. eraseValueFromMap should be used to remove
4522/// V from ValueExprMap and ExprValueMap at the same time.
4523void ScalarEvolution::eraseValueFromMap(Value *V) {
4524 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4525 if (I != ValueExprMap.end()) {
4526 auto EVIt = ExprValueMap.find(I->second);
4527 bool Removed = EVIt->second.remove(V);
4528 (void) Removed;
4529 assert(Removed && "Value not in ExprValueMap?");
4530 ValueExprMap.erase(I);
4531 }
4532}
4533
4534void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4535 // A recursive query may have already computed the SCEV. It should be
4536 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4537 // inferred nowrap flags.
4538 auto It = ValueExprMap.find_as(V);
4539 if (It == ValueExprMap.end()) {
4540 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4541 ExprValueMap[S].insert(V);
4542 }
4543}
4544
4545/// Return an existing SCEV if it exists, otherwise analyze the expression and
4546/// create a new one.
4548 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4549
4550 if (const SCEV *S = getExistingSCEV(V))
4551 return S;
4552 return createSCEVIter(V);
4553}
4554
4556 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4557
4558 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4559 if (I != ValueExprMap.end()) {
4560 const SCEV *S = I->second;
4561 assert(checkValidity(S) &&
4562 "existing SCEV has not been properly invalidated");
4563 return S;
4564 }
4565 return nullptr;
4566}
4567
4568/// Return a SCEV corresponding to -V = -1*V
4570 SCEV::NoWrapFlags Flags) {
4571 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4572 return getConstant(
4573 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4574
4575 Type *Ty = V->getType();
4576 Ty = getEffectiveSCEVType(Ty);
4577 return getMulExpr(V, getMinusOne(Ty), Flags);
4578}
4579
4580/// If Expr computes ~A, return A else return nullptr
4581static const SCEV *MatchNotExpr(const SCEV *Expr) {
4582 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4583 if (!Add || Add->getNumOperands() != 2 ||
4584 !Add->getOperand(0)->isAllOnesValue())
4585 return nullptr;
4586
4587 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4588 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4589 !AddRHS->getOperand(0)->isAllOnesValue())
4590 return nullptr;
4591
4592 return AddRHS->getOperand(1);
4593}
4594
4595/// Return a SCEV corresponding to ~V = -1-V
4597 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4598
4599 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4600 return getConstant(
4601 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4602
4603 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4604 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4605 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4606 SmallVector<const SCEV *, 2> MatchedOperands;
4607 for (const SCEV *Operand : MME->operands()) {
4608 const SCEV *Matched = MatchNotExpr(Operand);
4609 if (!Matched)
4610 return (const SCEV *)nullptr;
4611 MatchedOperands.push_back(Matched);
4612 }
4613 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4614 MatchedOperands);
4615 };
4616 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4617 return Replaced;
4618 }
4619
4620 Type *Ty = V->getType();
4621 Ty = getEffectiveSCEVType(Ty);
4622 return getMinusSCEV(getMinusOne(Ty), V);
4623}
4624
4626 assert(P->getType()->isPointerTy());
4627
4628 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4629 // The base of an AddRec is the first operand.
4630 SmallVector<const SCEV *> Ops{AddRec->operands()};
4631 Ops[0] = removePointerBase(Ops[0]);
4632 // Don't try to transfer nowrap flags for now. We could in some cases
4633 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4634 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4635 }
4636 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4637 // The base of an Add is the pointer operand.
4638 SmallVector<const SCEV *> Ops{Add->operands()};
4639 const SCEV **PtrOp = nullptr;
4640 for (const SCEV *&AddOp : Ops) {
4641 if (AddOp->getType()->isPointerTy()) {
4642 assert(!PtrOp && "Cannot have multiple pointer ops");
4643 PtrOp = &AddOp;
4644 }
4645 }
4646 *PtrOp = removePointerBase(*PtrOp);
4647 // Don't try to transfer nowrap flags for now. We could in some cases
4648 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4649 return getAddExpr(Ops);
4650 }
4651 // Any other expression must be a pointer base.
4652 return getZero(P->getType());
4653}
4654
4655const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4656 SCEV::NoWrapFlags Flags,
4657 unsigned Depth) {
4658 // Fast path: X - X --> 0.
4659 if (LHS == RHS)
4660 return getZero(LHS->getType());
4661
4662 // If we subtract two pointers with different pointer bases, bail.
4663 // Eventually, we're going to add an assertion to getMulExpr that we
4664 // can't multiply by a pointer.
4665 if (RHS->getType()->isPointerTy()) {
4666 if (!LHS->getType()->isPointerTy() ||
4668 return getCouldNotCompute();
4671 }
4672
4673 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4674 // makes it so that we cannot make much use of NUW.
4675 auto AddFlags = SCEV::FlagAnyWrap;
4676 const bool RHSIsNotMinSigned =
4678 if (hasFlags(Flags, SCEV::FlagNSW)) {
4679 // Let M be the minimum representable signed value. Then (-1)*RHS
4680 // signed-wraps if and only if RHS is M. That can happen even for
4681 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4682 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4683 // (-1)*RHS, we need to prove that RHS != M.
4684 //
4685 // If LHS is non-negative and we know that LHS - RHS does not
4686 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4687 // either by proving that RHS > M or that LHS >= 0.
4688 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4689 AddFlags = SCEV::FlagNSW;
4690 }
4691 }
4692
4693 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4694 // RHS is NSW and LHS >= 0.
4695 //
4696 // The difficulty here is that the NSW flag may have been proven
4697 // relative to a loop that is to be found in a recurrence in LHS and
4698 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4699 // larger scope than intended.
4700 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4701
4702 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4703}
4704
4706 unsigned Depth) {
4707 Type *SrcTy = V->getType();
4708 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4709 "Cannot truncate or zero extend with non-integer arguments!");
4710 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4711 return V; // No conversion
4712 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4713 return getTruncateExpr(V, Ty, Depth);
4714 return getZeroExtendExpr(V, Ty, Depth);
4715}
4716
4718 unsigned Depth) {
4719 Type *SrcTy = V->getType();
4720 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4721 "Cannot truncate or zero extend with non-integer arguments!");
4722 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4723 return V; // No conversion
4724 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4725 return getTruncateExpr(V, Ty, Depth);
4726 return getSignExtendExpr(V, Ty, Depth);
4727}
4728
4729const SCEV *
4731 Type *SrcTy = V->getType();
4732 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4733 "Cannot noop or zero extend with non-integer arguments!");
4735 "getNoopOrZeroExtend cannot truncate!");
4736 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4737 return V; // No conversion
4738 return getZeroExtendExpr(V, Ty);
4739}
4740
4741const SCEV *
4743 Type *SrcTy = V->getType();
4744 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4745 "Cannot noop or sign extend with non-integer arguments!");
4747 "getNoopOrSignExtend cannot truncate!");
4748 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4749 return V; // No conversion
4750 return getSignExtendExpr(V, Ty);
4751}
4752
4753const SCEV *
4755 Type *SrcTy = V->getType();
4756 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4757 "Cannot noop or any extend with non-integer arguments!");
4759 "getNoopOrAnyExtend cannot truncate!");
4760 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4761 return V; // No conversion
4762 return getAnyExtendExpr(V, Ty);
4763}
4764
4765const SCEV *
4767 Type *SrcTy = V->getType();
4768 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4769 "Cannot truncate or noop with non-integer arguments!");
4771 "getTruncateOrNoop cannot extend!");
4772 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4773 return V; // No conversion
4774 return getTruncateExpr(V, Ty);
4775}
4776
4778 const SCEV *RHS) {
4779 const SCEV *PromotedLHS = LHS;
4780 const SCEV *PromotedRHS = RHS;
4781
4783 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4784 else
4785 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4786
4787 return getUMaxExpr(PromotedLHS, PromotedRHS);
4788}
4789
4791 const SCEV *RHS,
4792 bool Sequential) {
4794 return getUMinFromMismatchedTypes(Ops, Sequential);
4795}
4796
4797const SCEV *
4799 bool Sequential) {
4800 assert(!Ops.empty() && "At least one operand must be!");
4801 // Trivial case.
4802 if (Ops.size() == 1)
4803 return Ops[0];
4804
4805 // Find the max type first.
4806 Type *MaxType = nullptr;
4807 for (const auto *S : Ops)
4808 if (MaxType)
4809 MaxType = getWiderType(MaxType, S->getType());
4810 else
4811 MaxType = S->getType();
4812 assert(MaxType && "Failed to find maximum type!");
4813
4814 // Extend all ops to max type.
4815 SmallVector<const SCEV *, 2> PromotedOps;
4816 for (const auto *S : Ops)
4817 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4818
4819 // Generate umin.
4820 return getUMinExpr(PromotedOps, Sequential);
4821}
4822
4824 // A pointer operand may evaluate to a nonpointer expression, such as null.
4825 if (!V->getType()->isPointerTy())
4826 return V;
4827
4828 while (true) {
4829 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4830 V = AddRec->getStart();
4831 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4832 const SCEV *PtrOp = nullptr;
4833 for (const SCEV *AddOp : Add->operands()) {
4834 if (AddOp->getType()->isPointerTy()) {
4835 assert(!PtrOp && "Cannot have multiple pointer ops");
4836 PtrOp = AddOp;
4837 }
4838 }
4839 assert(PtrOp && "Must have pointer op");
4840 V = PtrOp;
4841 } else // Not something we can look further into.
4842 return V;
4843 }
4844}
4845
4846/// Push users of the given Instruction onto the given Worklist.
4850 // Push the def-use children onto the Worklist stack.
4851 for (User *U : I->users()) {
4852 auto *UserInsn = cast<Instruction>(U);
4853 if (Visited.insert(UserInsn).second)
4854 Worklist.push_back(UserInsn);
4855 }
4856}
4857
4858namespace {
4859
4860/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4861/// expression in case its Loop is L. If it is not L then
4862/// if IgnoreOtherLoops is true then use AddRec itself
4863/// otherwise rewrite cannot be done.
4864/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4865class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4866public:
4867 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4868 bool IgnoreOtherLoops = true) {
4869 SCEVInitRewriter Rewriter(L, SE);
4870 const SCEV *Result = Rewriter.visit(S);
4871 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4872 return SE.getCouldNotCompute();
4873 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4874 ? SE.getCouldNotCompute()
4875 : Result;
4876 }
4877
4878 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4879 if (!SE.isLoopInvariant(Expr, L))
4880 SeenLoopVariantSCEVUnknown = true;
4881 return Expr;
4882 }
4883
4884 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4885 // Only re-write AddRecExprs for this loop.
4886 if (Expr->getLoop() == L)
4887 return Expr->getStart();
4888 SeenOtherLoops = true;
4889 return Expr;
4890 }
4891
4892 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4893
4894 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4895
4896private:
4897 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4898 : SCEVRewriteVisitor(SE), L(L) {}
4899
4900 const Loop *L;
4901 bool SeenLoopVariantSCEVUnknown = false;
4902 bool SeenOtherLoops = false;
4903};
4904
4905/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4906/// increment expression in case its Loop is L. If it is not L then
4907/// use AddRec itself.
4908/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4909class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4910public:
4911 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4912 SCEVPostIncRewriter Rewriter(L, SE);
4913 const SCEV *Result = Rewriter.visit(S);
4914 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4915 ? SE.getCouldNotCompute()
4916 : Result;
4917 }
4918
4919 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4920 if (!SE.isLoopInvariant(Expr, L))
4921 SeenLoopVariantSCEVUnknown = true;
4922 return Expr;
4923 }
4924
4925 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4926 // Only re-write AddRecExprs for this loop.
4927 if (Expr->getLoop() == L)
4928 return Expr->getPostIncExpr(SE);
4929 SeenOtherLoops = true;
4930 return Expr;
4931 }
4932
4933 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4934
4935 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4936
4937private:
4938 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4939 : SCEVRewriteVisitor(SE), L(L) {}
4940
4941 const Loop *L;
4942 bool SeenLoopVariantSCEVUnknown = false;
4943 bool SeenOtherLoops = false;
4944};
4945
4946/// This class evaluates the compare condition by matching it against the
4947/// condition of loop latch. If there is a match we assume a true value
4948/// for the condition while building SCEV nodes.
4949class SCEVBackedgeConditionFolder
4950 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4951public:
4952 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4953 ScalarEvolution &SE) {
4954 bool IsPosBECond = false;
4955 Value *BECond = nullptr;
4956 if (BasicBlock *Latch = L->getLoopLatch()) {
4957 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4958 if (BI && BI->isConditional()) {
4959 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4960 "Both outgoing branches should not target same header!");
4961 BECond = BI->getCondition();
4962 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4963 } else {
4964 return S;
4965 }
4966 }
4967 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4968 return Rewriter.visit(S);
4969 }
4970
4971 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4972 const SCEV *Result = Expr;
4973 bool InvariantF = SE.isLoopInvariant(Expr, L);
4974
4975 if (!InvariantF) {
4976 Instruction *I = cast<Instruction>(Expr->getValue());
4977 switch (I->getOpcode()) {
4978 case Instruction::Select: {
4979 SelectInst *SI = cast<SelectInst>(I);
4980 std::optional<const SCEV *> Res =
4981 compareWithBackedgeCondition(SI->getCondition());
4982 if (Res) {
4983 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
4984 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4985 }
4986 break;
4987 }
4988 default: {
4989 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4990 if (Res)
4991 Result = *Res;
4992 break;
4993 }
4994 }
4995 }
4996 return Result;
4997 }
4998
4999private:
5000 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5001 bool IsPosBECond, ScalarEvolution &SE)
5002 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5003 IsPositiveBECond(IsPosBECond) {}
5004
5005 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5006
5007 const Loop *L;
5008 /// Loop back condition.
5009 Value *BackedgeCond = nullptr;
5010 /// Set to true if loop back is on positive branch condition.
5011 bool IsPositiveBECond;
5012};
5013
5014std::optional<const SCEV *>
5015SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5016
5017 // If value matches the backedge condition for loop latch,
5018 // then return a constant evolution node based on loopback
5019 // branch taken.
5020 if (BackedgeCond == IC)
5021 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5023 return std::nullopt;
5024}
5025
5026class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5027public:
5028 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5029 ScalarEvolution &SE) {
5030 SCEVShiftRewriter Rewriter(L, SE);
5031 const SCEV *Result = Rewriter.visit(S);
5032 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5033 }
5034
5035 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5036 // Only allow AddRecExprs for this loop.
5037 if (!SE.isLoopInvariant(Expr, L))
5038 Valid = false;
5039 return Expr;
5040 }
5041
5042 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5043 if (Expr->getLoop() == L && Expr->isAffine())
5044 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5045 Valid = false;
5046 return Expr;
5047 }
5048
5049 bool isValid() { return Valid; }
5050
5051private:
5052 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5053 : SCEVRewriteVisitor(SE), L(L) {}
5054
5055 const Loop *L;
5056 bool Valid = true;
5057};
5058
5059} // end anonymous namespace
5060
5062ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5063 if (!AR->isAffine())
5064 return SCEV::FlagAnyWrap;
5065
5066 using OBO = OverflowingBinaryOperator;
5067
5069
5070 if (!AR->hasNoSelfWrap()) {
5071 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5072 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5073 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5074 const APInt &BECountAP = BECountMax->getAPInt();
5075 unsigned NoOverflowBitWidth =
5076 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5077 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5079 }
5080 }
5081
5082 if (!AR->hasNoSignedWrap()) {
5083 ConstantRange AddRecRange = getSignedRange(AR);
5084 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5085
5087 Instruction::Add, IncRange, OBO::NoSignedWrap);
5088 if (NSWRegion.contains(AddRecRange))
5090 }
5091
5092 if (!AR->hasNoUnsignedWrap()) {
5093 ConstantRange AddRecRange = getUnsignedRange(AR);
5094 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5095
5097 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5098 if (NUWRegion.contains(AddRecRange))
5100 }
5101
5102 return Result;
5103}
5104
5106ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5108
5109 if (AR->hasNoSignedWrap())
5110 return Result;
5111
5112 if (!AR->isAffine())
5113 return Result;
5114
5115 // This function can be expensive, only try to prove NSW once per AddRec.
5116 if (!SignedWrapViaInductionTried.insert(AR).second)
5117 return Result;
5118
5119 const SCEV *Step = AR->getStepRecurrence(*this);
5120 const Loop *L = AR->getLoop();
5121
5122 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5123 // Note that this serves two purposes: It filters out loops that are
5124 // simply not analyzable, and it covers the case where this code is
5125 // being called from within backedge-taken count analysis, such that
5126 // attempting to ask for the backedge-taken count would likely result
5127 // in infinite recursion. In the later case, the analysis code will
5128 // cope with a conservative value, and it will take care to purge
5129 // that value once it has finished.
5130 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5131
5132 // Normally, in the cases we can prove no-overflow via a
5133 // backedge guarding condition, we can also compute a backedge
5134 // taken count for the loop. The exceptions are assumptions and
5135 // guards present in the loop -- SCEV is not great at exploiting
5136 // these to compute max backedge taken counts, but can still use
5137 // these to prove lack of overflow. Use this fact to avoid
5138 // doing extra work that may not pay off.
5139
5140 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5141 AC.assumptions().empty())
5142 return Result;
5143
5144 // If the backedge is guarded by a comparison with the pre-inc value the
5145 // addrec is safe. Also, if the entry is guarded by a comparison with the
5146 // start value and the backedge is guarded by a comparison with the post-inc
5147 // value, the addrec is safe.
5149 const SCEV *OverflowLimit =
5150 getSignedOverflowLimitForStep(Step, &Pred, this);
5151 if (OverflowLimit &&
5152 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5153 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5154 Result = setFlags(Result, SCEV::FlagNSW);
5155 }
5156 return Result;
5157}
5159ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5161
5162 if (AR->hasNoUnsignedWrap())
5163 return Result;
5164
5165 if (!AR->isAffine())
5166 return Result;
5167
5168 // This function can be expensive, only try to prove NUW once per AddRec.
5169 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5170 return Result;
5171
5172 const SCEV *Step = AR->getStepRecurrence(*this);
5173 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5174 const Loop *L = AR->getLoop();
5175
5176 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5177 // Note that this serves two purposes: It filters out loops that are
5178 // simply not analyzable, and it covers the case where this code is
5179 // being called from within backedge-taken count analysis, such that
5180 // attempting to ask for the backedge-taken count would likely result
5181 // in infinite recursion. In the later case, the analysis code will
5182 // cope with a conservative value, and it will take care to purge
5183 // that value once it has finished.
5184 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5185
5186 // Normally, in the cases we can prove no-overflow via a
5187 // backedge guarding condition, we can also compute a backedge
5188 // taken count for the loop. The exceptions are assumptions and
5189 // guards present in the loop -- SCEV is not great at exploiting
5190 // these to compute max backedge taken counts, but can still use
5191 // these to prove lack of overflow. Use this fact to avoid
5192 // doing extra work that may not pay off.
5193
5194 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5195 AC.assumptions().empty())
5196 return Result;
5197
5198 // If the backedge is guarded by a comparison with the pre-inc value the
5199 // addrec is safe. Also, if the entry is guarded by a comparison with the
5200 // start value and the backedge is guarded by a comparison with the post-inc
5201 // value, the addrec is safe.
5202 if (isKnownPositive(Step)) {
5204 getUnsignedRangeMax(Step));
5207 Result = setFlags(Result, SCEV::FlagNUW);
5208 }
5209 }
5210
5211 return Result;
5212}
5213
5214namespace {
5215
5216/// Represents an abstract binary operation. This may exist as a
5217/// normal instruction or constant expression, or may have been
5218/// derived from an expression tree.
5219struct BinaryOp {
5220 unsigned Opcode;
5221 Value *LHS;
5222 Value *RHS;
5223 bool IsNSW = false;
5224 bool IsNUW = false;
5225
5226 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5227 /// constant expression.
5228 Operator *Op = nullptr;
5229
5230 explicit BinaryOp(Operator *Op)
5231 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5232 Op(Op) {
5233 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5234 IsNSW = OBO->hasNoSignedWrap();
5235 IsNUW = OBO->hasNoUnsignedWrap();
5236 }
5237 }
5238
5239 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5240 bool IsNUW = false)
5241 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5242};
5243
5244} // end anonymous namespace
5245
5246/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5247static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5248 AssumptionCache &AC,
5249 const DominatorTree &DT,
5250 const Instruction *CxtI) {
5251 auto *Op = dyn_cast<Operator>(V);
5252 if (!Op)
5253 return std::nullopt;
5254
5255 // Implementation detail: all the cleverness here should happen without
5256 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5257 // SCEV expressions when possible, and we should not break that.
5258
5259 switch (Op->getOpcode()) {
5260 case Instruction::Add:
5261 case Instruction::Sub:
5262 case Instruction::Mul:
5263 case Instruction::UDiv:
5264 case Instruction::URem:
5265 case Instruction::And:
5266 case Instruction::AShr:
5267 case Instruction::Shl:
5268 return BinaryOp(Op);
5269
5270 case Instruction::Or: {
5271 // Convert or disjoint into add nuw nsw.
5272 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5273 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5274 /*IsNSW=*/true, /*IsNUW=*/true);
5275 return BinaryOp(Op);
5276 }
5277
5278 case Instruction::Xor:
5279 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5280 // If the RHS of the xor is a signmask, then this is just an add.
5281 // Instcombine turns add of signmask into xor as a strength reduction step.
5282 if (RHSC->getValue().isSignMask())
5283 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5284 // Binary `xor` is a bit-wise `add`.
5285 if (V->getType()->isIntegerTy(1))
5286 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5287 return BinaryOp(Op);
5288
5289 case Instruction::LShr:
5290 // Turn logical shift right of a constant into a unsigned divide.
5291 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5292 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5293
5294 // If the shift count is not less than the bitwidth, the result of
5295 // the shift is undefined. Don't try to analyze it, because the
5296 // resolution chosen here may differ from the resolution chosen in
5297 // other parts of the compiler.
5298 if (SA->getValue().ult(BitWidth)) {
5299 Constant *X =
5300 ConstantInt::get(SA->getContext(),
5301 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5302 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5303 }
5304 }
5305 return BinaryOp(Op);
5306
5307 case Instruction::ExtractValue: {
5308 auto *EVI = cast<ExtractValueInst>(Op);
5309 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5310 break;
5311
5312 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5313 if (!WO)
5314 break;
5315
5316 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5317 bool Signed = WO->isSigned();
5318 // TODO: Should add nuw/nsw flags for mul as well.
5319 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5320 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5321
5322 // Now that we know that all uses of the arithmetic-result component of
5323 // CI are guarded by the overflow check, we can go ahead and pretend
5324 // that the arithmetic is non-overflowing.
5325 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5326 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5327 }
5328
5329 default:
5330 break;
5331 }
5332
5333 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5334 // semantics as a Sub, return a binary sub expression.
5335 if (auto *II = dyn_cast<IntrinsicInst>(V))
5336 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5337 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5338
5339 return std::nullopt;
5340}
5341
5342/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5343/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5344/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5345/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5346/// follows one of the following patterns:
5347/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5348/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5349/// If the SCEV expression of \p Op conforms with one of the expected patterns
5350/// we return the type of the truncation operation, and indicate whether the
5351/// truncated type should be treated as signed/unsigned by setting
5352/// \p Signed to true/false, respectively.
5353static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5354 bool &Signed, ScalarEvolution &SE) {
5355 // The case where Op == SymbolicPHI (that is, with no type conversions on
5356 // the way) is handled by the regular add recurrence creating logic and
5357 // would have already been triggered in createAddRecForPHI. Reaching it here
5358 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5359 // because one of the other operands of the SCEVAddExpr updating this PHI is
5360 // not invariant).
5361 //
5362 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5363 // this case predicates that allow us to prove that Op == SymbolicPHI will
5364 // be added.
5365 if (Op == SymbolicPHI)
5366 return nullptr;
5367
5368 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5369 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5370 if (SourceBits != NewBits)
5371 return nullptr;
5372
5373 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5374 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5375 if (!SExt && !ZExt)
5376 return nullptr;
5377 const SCEVTruncateExpr *Trunc =
5378 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5379 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5380 if (!Trunc)
5381 return nullptr;
5382 const SCEV *X = Trunc->getOperand();
5383 if (X != SymbolicPHI)
5384 return nullptr;
5385 Signed = SExt != nullptr;
5386 return Trunc->getType();
5387}
5388
5389static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5390 if (!PN->getType()->isIntegerTy())
5391 return nullptr;
5392 const Loop *L = LI.getLoopFor(PN->getParent());
5393 if (!L || L->getHeader() != PN->getParent())
5394 return nullptr;
5395 return L;
5396}
5397
5398// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5399// computation that updates the phi follows the following pattern:
5400// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5401// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5402// If so, try to see if it can be rewritten as an AddRecExpr under some
5403// Predicates. If successful, return them as a pair. Also cache the results
5404// of the analysis.
5405//
5406// Example usage scenario:
5407// Say the Rewriter is called for the following SCEV:
5408// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5409// where:
5410// %X = phi i64 (%Start, %BEValue)
5411// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5412// and call this function with %SymbolicPHI = %X.
5413//
5414// The analysis will find that the value coming around the backedge has
5415// the following SCEV:
5416// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5417// Upon concluding that this matches the desired pattern, the function
5418// will return the pair {NewAddRec, SmallPredsVec} where:
5419// NewAddRec = {%Start,+,%Step}
5420// SmallPredsVec = {P1, P2, P3} as follows:
5421// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5422// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5423// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5424// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5425// under the predicates {P1,P2,P3}.
5426// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5427// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5428//
5429// TODO's:
5430//
5431// 1) Extend the Induction descriptor to also support inductions that involve
5432// casts: When needed (namely, when we are called in the context of the
5433// vectorizer induction analysis), a Set of cast instructions will be
5434// populated by this method, and provided back to isInductionPHI. This is
5435// needed to allow the vectorizer to properly record them to be ignored by
5436// the cost model and to avoid vectorizing them (otherwise these casts,
5437// which are redundant under the runtime overflow checks, will be
5438// vectorized, which can be costly).
5439//
5440// 2) Support additional induction/PHISCEV patterns: We also want to support
5441// inductions where the sext-trunc / zext-trunc operations (partly) occur
5442// after the induction update operation (the induction increment):
5443//
5444// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5445// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5446//
5447// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5448// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5449//
5450// 3) Outline common code with createAddRecFromPHI to avoid duplication.
// Tries to rewrite \p SymbolicPHI, the symbolic SCEV of an integer loop-header
// phi, into an add recurrence that holds under a set of runtime predicates.
// Returns {NewAR, Predicates} on success, after storing the same pair in
// PredicatedSCEVRewrites; returns std::nullopt when the phi does not match the
// phi-with-casts pattern, when the step is not loop invariant, or when one of
// the equality predicates is known to be false.
// NOTE(review): this listing omits source lines 5453, 5511 and 5585-5587;
// `Predicates`, `Ops` and `AddedFlags` are presumably declared there -- confirm
// against the upstream file.
5451std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5452ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5454
5455 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5456 // return an AddRec expression under some predicate.
5457
5458 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5459 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5460 assert(L && "Expecting an integer loop header phi");
5461
5462 // The loop may have multiple entrances or multiple exits; we can analyze
5463 // this phi as an addrec if it has a unique entry value and a unique
5464 // backedge value.
5465 Value *BEValueV = nullptr, *StartValueV = nullptr;
5466 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5467 Value *V = PN->getIncomingValue(i);
5468 if (L->contains(PN->getIncomingBlock(i))) {
5469 if (!BEValueV) {
5470 BEValueV = V;
5471 } else if (BEValueV != V) {
5472 BEValueV = nullptr;
5473 break;
5474 }
5475 } else if (!StartValueV) {
5476 StartValueV = V;
5477 } else if (StartValueV != V) {
5478 StartValueV = nullptr;
5479 break;
5480 }
5481 }
5482 if (!BEValueV || !StartValueV)
5483 return std::nullopt;
5484
5485 const SCEV *BEValue = getSCEV(BEValueV);
5486
5487 // If the value coming around the backedge is an add with the symbolic
5488 // value we just inserted, possibly with casts that we can ignore under
5489 // an appropriate runtime guard, then we found a simple induction variable!
5490 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5491 if (!Add)
5492 return std::nullopt;
5493
5494 // If there is a single occurrence of the symbolic value, possibly
5495 // casted, replace it with a recurrence.
5496 unsigned FoundIndex = Add->getNumOperands();
5497 Type *TruncTy = nullptr;
 // Signed is only written by isSimpleCastedPHI on a successful match, and it
 // is only read below after such a match was found.
5498 bool Signed;
5499 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5500 if ((TruncTy =
5501 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5502 if (FoundIndex == e) {
5503 FoundIndex = i;
5504 break;
5505 }
5506
5507 if (FoundIndex == Add->getNumOperands())
5508 return std::nullopt;
5509
5510 // Create an add with everything but the specified operand.
5512 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5513 if (i != FoundIndex)
5514 Ops.push_back(Add->getOperand(i));
5515 const SCEV *Accum = getAddExpr(Ops);
5516
5517 // The runtime checks will not be valid if the step amount is
5518 // varying inside the loop.
5519 if (!isLoopInvariant(Accum, L))
5520 return std::nullopt;
5521
5522 // *** Part2: Create the predicates
5523
5524 // Analysis was successful: we have a phi-with-cast pattern for which we
5525 // can return an AddRec expression under the following predicates:
5526 //
5527 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5528 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5529 // P2: An Equal predicate that guarantees that
5530 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5531 // P3: An Equal predicate that guarantees that
5532 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5533 //
5534 // As we next prove, the above predicates guarantee that:
5535 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5536 //
5537 //
5538 // More formally, we want to prove that:
5539 // Expr(i+1) = Start + (i+1) * Accum
5540 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5541 //
5542 // Given that:
5543 // 1) Expr(0) = Start
5544 // 2) Expr(1) = Start + Accum
5545 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5546 // 3) Induction hypothesis (step i):
5547 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5548 //
5549 // Proof:
5550 // Expr(i+1) =
5551 // = Start + (i+1)*Accum
5552 // = (Start + i*Accum) + Accum
5553 // = Expr(i) + Accum
5554 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5555 // :: from step i
5556 //
5557 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5558 //
5559 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5560 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5561 // + Accum :: from P3
5562 //
5563 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5564 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5565 //
5566 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5567 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5568 //
5569 // By induction, the same applies to all iterations 1<=i<n:
5570 //
5571
5572 // Create a truncated addrec for which we will add a no overflow check (P1).
5573 const SCEV *StartVal = getSCEV(StartValueV);
5574 const SCEV *PHISCEV =
5575 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5576 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5577
5578 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5579 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5580 // will be constant.
5581 //
5582 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5583 // add P1.
5584 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5588 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5589 Predicates.push_back(AddRecPred);
5590 }
5591
5592 // Create the Equal Predicates P2,P3:
5593
5594 // It is possible that the predicates P2 and/or P3 are computable at
5595 // compile time due to StartVal and/or Accum being constants.
5596 // If either one is, then we can check that now and escape if either P2
5597 // or P3 is false.
5598
5599 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5600 // for each of StartVal and Accum
5601 auto getExtendedExpr = [&](const SCEV *Expr,
5602 bool CreateSignExtend) -> const SCEV * {
5603 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5604 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5605 const SCEV *ExtendedExpr =
5606 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5607 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5608 return ExtendedExpr;
5609 };
5610
5611 // Given:
5612 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5613 // = getExtendedExpr(Expr)
5614 // Determine whether the predicate P: Expr == ExtendedExpr
5615 // is known to be false at compile time
5616 auto PredIsKnownFalse = [&](const SCEV *Expr,
5617 const SCEV *ExtendedExpr) -> bool {
5618 return Expr != ExtendedExpr &&
5619 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5620 };
5621
5622 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5623 if (PredIsKnownFalse(StartVal, StartExtended)) {
5624 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5625 return std::nullopt;
5626 }
5627
5628 // The Step is always Signed (because the overflow checks are either
5629 // NSSW or NUSW)
5630 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5631 if (PredIsKnownFalse(Accum, AccumExtended)) {
5632 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5633 return std::nullopt;
5634 }
5635
 // Records the predicate Expr == ExtendedExpr, unless the two SCEVs are the
 // same object or the equality is already known to hold.
5636 auto AppendPredicate = [&](const SCEV *Expr,
5637 const SCEV *ExtendedExpr) -> void {
5638 if (Expr != ExtendedExpr &&
5639 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5640 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5641 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5642 Predicates.push_back(Pred);
5643 }
5644 };
5645
5646 AppendPredicate(StartVal, StartExtended);
5647 AppendPredicate(Accum, AccumExtended);
5648
5649 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5650 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5651 // into NewAR if it will also add the runtime overflow checks specified in
5652 // Predicates.
5653 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5654
5655 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5656 std::make_pair(NewAR, Predicates);
5657 // Remember the result of the analysis for this SCEV at this location.
5658 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5659 return PredRewrite;
5660}
5661
// Cached front end for createAddRecFromPHIWithCastsImpl. Returns std::nullopt
// when the phi behind SymbolicPHI is not an integer loop-header phi or when
// the analysis fails (now or on an earlier, cached attempt); otherwise
// returns the {AddRec, Predicates} rewrite.
// NOTE(review): this listing omits source line 5663 (the line carrying the
// function name and parameter list) and line 5689 (presumably the declaration
// of `Predicates` used below) -- confirm against the upstream file.
5662std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5664 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5665 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5666 if (!L)
5667 return std::nullopt;
5668
5669 // Check to see if we already analyzed this PHI.
5670 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5671 if (I != PredicatedSCEVRewrites.end()) {
5672 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5673 I->second;
5674 // Analysis was done before and failed to create an AddRec:
 // (a failed attempt is cached below with SymbolicPHI itself as the first
 // member of the pair).
5675 if (Rewrite.first == SymbolicPHI)
5676 return std::nullopt;
5677 // Analysis was done before and succeeded to create an AddRec under
5678 // a predicate:
5679 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5680 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5681 return Rewrite;
5682 }
5683
5684 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5685 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5686
5687 // Record in the cache that the analysis failed
5688 if (!Rewrite) {
5690 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5691 return std::nullopt;
5692 }
5693
5694 return Rewrite;
5695}
5696
5697// FIXME: This utility is currently required because the Rewriter currently
5698// does not rewrite this expression:
5699// {0, +, (sext ix (trunc iy to ix) to iy)}
5700// into {0, +, %step},
5701// even when the following Equal predicate exists:
5702// "%step == (sext ix (trunc iy to ix) to iy)".
//
// Returns true when AR1 and AR2 are the same object, or when both their start
// values and their step recurrences are either pointer-identical or implied
// equal (in either operand order) by the predicates in Preds.
// NOTE(review): this listing omits source line 5703, which carries the return
// type and the qualified function name -- confirm against the upstream file.
5704 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5705 if (AR1 == AR2)
5706 return true;
5707
 // Two expressions count as equal when they are the same SCEV, or when Preds
 // implies an Equal predicate between them in at least one operand order.
5708 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709 if (Expr1 != Expr2 &&
5710 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5712 return false;
5713 return true;
5714 };
5715
5716 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5717 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5718 return false;
5719 return true;
5720}
5721
5722/// A helper function for createAddRecFromPHI to handle simple cases.
5723///
5724/// This function tries to find an AddRec expression for the simplest (yet most
5725/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5726/// If it fails, createAddRecFromPHI will use a more general, but slow,
5727/// technique for finding the AddRec expression.
///
/// \p PN must live in the header of its loop, and \p BEValueV / \p StartValueV
/// must be non-null (both are asserted). Returns nullptr when the backedge
/// value is not matched as an Add of \p PN and a loop-invariant operand;
/// otherwise returns the SCEV built for \p PN, which is also recorded through
/// insertValueToMap.
/// NOTE(review): this listing omits source line 5751 (presumably the
/// declaration of `Flags`) and line 5763 (part of the setNoWrapFlags call) --
/// confirm against the upstream file.
5728const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5729 Value *BEValueV,
5730 Value *StartValueV) {
5731 const Loop *L = LI.getLoopFor(PN->getParent());
5732 assert(L && L->getHeader() == PN->getParent());
5733 assert(BEValueV && StartValueV);
5734
5735 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5736 if (!BO)
5737 return nullptr;
5738
5739 if (BO->Opcode != Instruction::Add)
5740 return nullptr;
5741
 // The step is whichever operand of the add is not the phi itself, provided
 // that operand is invariant in L.
5742 const SCEV *Accum = nullptr;
5743 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5744 Accum = getSCEV(BO->RHS);
5745 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5746 Accum = getSCEV(BO->LHS);
5747
5748 if (!Accum)
5749 return nullptr;
5750
 // Carry over the nuw/nsw flags reported for the matched add.
5752 if (BO->IsNUW)
5753 Flags = setFlags(Flags, SCEV::FlagNUW);
5754 if (BO->IsNSW)
5755 Flags = setFlags(Flags, SCEV::FlagNSW);
5756
5757 const SCEV *StartVal = getSCEV(StartValueV);
5758 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5759 insertValueToMap(PN, PHISCEV);
5760
5761 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5762 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5764 proveNoWrapViaConstantRanges(AR)));
5765 }
5766
5767 // We can add Flags to the post-inc expression only if we
5768 // know that it is *undefined behavior* for BEValueV to
5769 // overflow.
5770 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5771 assert(isLoopInvariant(Accum, L) &&
5772 "Accum is defined outside L, but is not invariant?");
5773 if (isAddRecNeverPoison(BEInst, L))
5774 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5775 }
5776
5777 return PHISCEV;
5778}
5779
// Tries to express the loop-header phi \p PN as a recurrence. Returns nullptr
// when \p PN is not in the header of its loop, when it lacks a unique start
// value or a unique backedge value, or when none of the patterns below
// applies; in the last case the temporary symbolic entry for \p PN is removed
// from the value map again before returning.
// NOTE(review): this listing omits source lines 5842, 5854, 5874 and 5894;
// `Ops` and `Flags` are presumably declared on the first two, and the other
// two hold the second halves of a condition and of a call -- confirm against
// the upstream file.
5780const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5781 const Loop *L = LI.getLoopFor(PN->getParent());
5782 if (!L || L->getHeader() != PN->getParent())
5783 return nullptr;
5784
5785 // The loop may have multiple entrances or multiple exits; we can analyze
5786 // this phi as an addrec if it has a unique entry value and a unique
5787 // backedge value.
5788 Value *BEValueV = nullptr, *StartValueV = nullptr;
5789 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5790 Value *V = PN->getIncomingValue(i);
5791 if (L->contains(PN->getIncomingBlock(i))) {
5792 if (!BEValueV) {
5793 BEValueV = V;
5794 } else if (BEValueV != V) {
5795 BEValueV = nullptr;
5796 break;
5797 }
5798 } else if (!StartValueV) {
5799 StartValueV = V;
5800 } else if (StartValueV != V) {
5801 StartValueV = nullptr;
5802 break;
5803 }
5804 }
5805 if (!BEValueV || !StartValueV)
5806 return nullptr;
5807
5808 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5809 "PHI node already processed?");
5810
5811 // First, try to find AddRec expression without creating a fictitious symbolic
5812 // value for PN.
5813 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5814 return S;
5815
5816 // Handle PHI node value symbolically.
5817 const SCEV *SymbolicName = getUnknown(PN);
5818 insertValueToMap(PN, SymbolicName);
5819
5820 // Using this symbolic name for the PHI, analyze the value coming around
5821 // the back-edge.
5822 const SCEV *BEValue = getSCEV(BEValueV);
5823
5824 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5825 // has a special value for the first iteration of the loop.
5826
5827 // If the value coming around the backedge is an add with the symbolic
5828 // value we just inserted, then we found a simple induction variable!
5829 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5830 // If there is a single occurrence of the symbolic value, replace it
5831 // with a recurrence.
5832 unsigned FoundIndex = Add->getNumOperands();
5833 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5834 if (Add->getOperand(i) == SymbolicName)
5835 if (FoundIndex == e) {
5836 FoundIndex = i;
5837 break;
5838 }
5839
5840 if (FoundIndex != Add->getNumOperands()) {
5841 // Create an add with everything but the specified operand.
5843 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5844 if (i != FoundIndex)
5845 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5846 L, *this));
5847 const SCEV *Accum = getAddExpr(Ops);
5848
5849 // This is not a valid addrec if the step amount is varying each
5850 // loop iteration, but is not itself an addrec in this loop.
5851 if (isLoopInvariant(Accum, L) ||
5852 (isa<SCEVAddRecExpr>(Accum) &&
5853 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5855
 // Flags are only taken from the increment when the phi is its first
 // operand (LHS of the add, or operand 0 of the GEP).
5856 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5857 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5858 if (BO->IsNUW)
5859 Flags = setFlags(Flags, SCEV::FlagNUW);
5860 if (BO->IsNSW)
5861 Flags = setFlags(Flags, SCEV::FlagNSW);
5862 }
5863 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5864 if (GEP->getOperand(0) == PN) {
5865 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5866 // If the increment has any nowrap flags, then we know the address
5867 // space cannot be wrapped around.
5868 if (NW != GEPNoWrapFlags::none())
5869 Flags = setFlags(Flags, SCEV::FlagNW);
5870 // If the GEP is nuw or nusw with non-negative offset, we know that
5871 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5872 // offset is treated as signed, while the base is unsigned.
5873 if (NW.hasNoUnsignedWrap() ||
5875 Flags = setFlags(Flags, SCEV::FlagNUW);
5876 }
5877
5878 // We cannot transfer nuw and nsw flags from subtraction
5879 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5880 // for instance.
5881 }
5882
5883 const SCEV *StartVal = getSCEV(StartValueV);
5884 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5885
5886 // Okay, for the entire analysis of this edge we assumed the PHI
5887 // to be symbolic. We now need to go back and purge all of the
5888 // entries for the scalars that use the symbolic expression.
5889 forgetMemoizedResults(SymbolicName);
5890 insertValueToMap(PN, PHISCEV);
5891
5892 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5893 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5895 proveNoWrapViaConstantRanges(AR)));
5896 }
5897
5898 // We can add Flags to the post-inc expression only if we
5899 // know that it is *undefined behavior* for BEValueV to
5900 // overflow.
5901 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5902 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5903 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5904
5905 return PHISCEV;
5906 }
5907 }
5908 } else {
5909 // Otherwise, this could be a loop like this:
5910 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5911 // In this case, j = {1,+,1} and BEValue is j.
5912 // Because the other in-value of i (0) fits the evolution of BEValue
5913 // i really is an addrec evolution.
5914 //
5915 // We can generalize this saying that i is the shifted value of BEValue
5916 // by one iteration:
5917 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5918
5919 // Do not allow refinement in rewriting of BEValue.
5920 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5921 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5922 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5923 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5924 const SCEV *StartVal = getSCEV(StartValueV);
5925 if (Start == StartVal) {
5926 // Okay, for the entire analysis of this edge we assumed the PHI
5927 // to be symbolic. We now need to go back and purge all of the
5928 // entries for the scalars that use the symbolic expression.
5929 forgetMemoizedResults(SymbolicName);
5930 insertValueToMap(PN, Shifted);
5931 return Shifted;
5932 }
5933 }
5934 }
5935
5936 // Remove the temporary PHI node SCEV that has been inserted while intending
5937 // to create an AddRecExpr for this PHI node. We can not keep this temporary
5938 // as it will prevent later (possibly simpler) SCEV expressions to be added
5939 // to the ValueExprMap.
5940 eraseValueFromMap(PN);
5941
5942 return nullptr;
5943}
5944
5945// Try to match a control flow sequence that branches out at BI and merges back
5946// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5947// match.
//
// C is always set to BI's condition, even when the match fails. On success LHS
// is the operand of Merge whose use is dominated by the edge to BI's successor
// 0, and RHS the operand whose use is dominated by the edge to successor 1.
// NOTE(review): this listing omits source line 5948, which carries the return
// type, the function name and the leading parameters (DT, BI and Merge are
// used below) -- confirm against the upstream file.
5949 Value *&C, Value *&LHS, Value *&RHS) {
5950 C = BI->getCondition();
5951
5952 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5953 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5954
5955 if (!LeftEdge.isSingleEdge())
5956 return false;
5957
5958 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5959
5960 Use &LeftUse = Merge->getOperandUse(0);
5961 Use &RightUse = Merge->getOperandUse(1);
5962
 // Operand 0 arrives via the "true" edge and operand 1 via the "false" edge.
5963 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5964 LHS = LeftUse;
5965 RHS = RightUse;
5966 return true;
5967 }
5968
 // Same pattern with the two operands of Merge in the opposite order.
5969 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5970 LHS = RightUse;
5971 RHS = LeftUse;
5972 return true;
5973 }
5974
5975 return false;
5976}
5977
5978const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5979 auto IsReachable =
5980 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5981 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5982 // Try to match
5983 //
5984 // br %cond, label %left, label %right
5985 // left:
5986 // br label %merge
5987 // right:
5988 // br label %merge
5989 // merge:
5990 // V = phi [ %x, %left ], [ %y, %right ]
5991 //
5992 // as "select %cond, %x, %y"
5993
5994 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5995 assert(IDom && "At least the entry block should dominate PN");
5996
5997 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5998 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5999
6000 if (BI && BI->isConditional() &&
6001 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6002 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6003 properlyDominates(getSCEV(RHS), PN->getParent()))
6004 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6005 }
6006
6007 return nullptr;
6008}
6009
6010/// Returns SCEV for the first operand of a phi if all phi operands have
6011/// identical opcodes and operands
6012/// eg.
6013/// a: %add = %a + %b
6014/// br %c
6015/// b: %add1 = %a + %b
6016/// br %c
6017/// c: %phi = phi [%add, a], [%add1, b]
6018/// scev(%phi) => scev(%add)
///
/// Returns nullptr when some incoming value is not a BinaryOperator, when the
/// incoming instructions are not identical, when the phi has no incoming
/// values, or when the SCEVs of the incoming values differ.
/// NOTE(review): this listing omits source line 6041, the start of the
/// expression that initializes SCEVExprsIdentical -- confirm against the
/// upstream file.
6019const SCEV *
6020ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6021 BinaryOperator *CommonInst = nullptr;
6022 // Check if instructions are identical.
6023 for (Value *Incoming : PN->incoming_values()) {
6024 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6025 if (!IncomingInst)
6026 return nullptr;
6027 if (CommonInst) {
6028 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6029 return nullptr; // Not identical, give up
6030 } else {
6031 // Remember binary operator
6032 CommonInst = IncomingInst;
6033 }
6034 }
6035 if (!CommonInst)
6036 return nullptr;
6037
6038 // Check if SCEV exprs for instructions are identical.
6039 const SCEV *CommonSCEV = getSCEV(CommonInst);
6040 bool SCEVExprsIdentical =
6042 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6043 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6044}
6045
6046const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6047 if (const SCEV *S = createAddRecFromPHI(PN))
6048 return S;
6049
6050 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6051 // phi node for X.
6052 if (Value *V = simplifyInstruction(
6053 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6054 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6055 return getSCEV(V);
6056
6057 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6058 return S;
6059
6060 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6061 return S;
6062
6063 // If it's not a loop phi, we can't handle it yet.
6064 return getUnknown(PN);
6065}
6066
// Returns true when OperandToFind is encountered while walking Root with
// visitAll. The walk only descends into expressions whose kind is RootKind,
// the non-sequential counterpart of RootKind, or a zero-extension.
// NOTE(review): this listing omits source line 6086, part of the initializer
// of NonSequentialRootKind -- confirm against the upstream file.
6067bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6068 SCEVTypes RootKind) {
6069 struct FindClosure {
6070 const SCEV *OperandToFind;
6071 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6072 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6073
6074 bool Found = false;
6075
6076 bool canRecurseInto(SCEVTypes Kind) const {
6077 // We can only recurse into the SCEV expression of the same effective type
6078 // as the type of our root SCEV expression, and into zero-extensions.
6079 return RootKind == Kind || NonSequentialRootKind == Kind ||
6080 scZeroExtend == Kind;
6081 };
6082
6083 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6084 : OperandToFind(OperandToFind), RootKind(RootKind),
6085 NonSequentialRootKind(
6087 RootKind)) {}
6088
 // Records whether S is the operand being searched for, and reports whether
 // the walk should continue below S.
6089 bool follow(const SCEV *S) {
6090 Found = S == OperandToFind;
6091
6092 return !isDone() && canRecurseInto(S->getSCEVType());
6093 }
6094
6095 bool isDone() const { return Found; }
6096 };
6097
6098 FindClosure FC(OperandToFind, RootKind);
6099 visitAll(Root, FC);
6100 return FC.Found;
6101}
6102
// Tries to model "Cond ? TrueVal : FalseVal", where Cond is an integer
// compare, as a min/max style SCEV. Returns the resulting expression, or
// std::nullopt when none of the patterns recognized below applies.
// NOTE(review): this listing omits source lines 6127, 6144 and 6176
// (apparently the opening of the scope that declares `Signed`, a statement
// inside CoerceOperand, and the first half of the `if` condition in the
// ICMP_EQ case) -- confirm against the upstream file.
6103std::optional<const SCEV *>
6104ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6105 ICmpInst *Cond,
6106 Value *TrueVal,
6107 Value *FalseVal) {
6108 // Try to match some simple smax or umax patterns.
6109 auto *ICI = Cond;
6110
6111 Value *LHS = ICI->getOperand(0);
6112 Value *RHS = ICI->getOperand(1);
6113
6114 switch (ICI->getPredicate()) {
6115 case ICmpInst::ICMP_SLT:
6116 case ICmpInst::ICMP_SLE:
6117 case ICmpInst::ICMP_ULT:
6118 case ICmpInst::ICMP_ULE:
 // Swap the operands so that the "less than" forms share the "greater than"
 // handling below.
6119 std::swap(LHS, RHS);
6120 [[fallthrough]];
6121 case ICmpInst::ICMP_SGT:
6122 case ICmpInst::ICMP_SGE:
6123 case ICmpInst::ICMP_UGT:
6124 case ICmpInst::ICMP_UGE:
6125 // a > b ? a+x : b+x -> max(a, b)+x
6126 // a > b ? b+x : a+x -> min(a, b)+x
6128 bool Signed = ICI->isSigned();
6129 const SCEV *LA = getSCEV(TrueVal);
6130 const SCEV *RA = getSCEV(FalseVal);
6131 const SCEV *LS = getSCEV(LHS);
6132 const SCEV *RS = getSCEV(RHS);
6133 if (LA->getType()->isPointerTy()) {
6134 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6135 // Need to make sure we can't produce weird expressions involving
6136 // negated pointers.
6137 if (LA == LS && RA == RS)
6138 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6139 if (LA == RS && RA == LS)
6140 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6141 }
 // Brings a compare operand to type Ty, sign- or zero-extending according to
 // the signedness of the predicate.
6142 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6143 if (Op->getType()->isPointerTy()) {
6145 if (isa<SCEVCouldNotCompute>(Op))
6146 return Op;
6147 }
6148 if (Signed)
6149 Op = getNoopOrSignExtend(Op, Ty);
6150 else
6151 Op = getNoopOrZeroExtend(Op, Ty);
6152 return Op;
6153 };
6154 LS = CoerceOperand(LS);
6155 RS = CoerceOperand(RS);
6156 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6157 break;
 // If both select arms differ from their compare operand by the same amount,
 // that amount is the common addend x of the max pattern.
6158 const SCEV *LDiff = getMinusSCEV(LA, LS);
6159 const SCEV *RDiff = getMinusSCEV(RA, RS);
6160 if (LDiff == RDiff)
6161 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6162 LDiff);
 // Same test with the arms crossed, for the min pattern.
6163 LDiff = getMinusSCEV(LA, RS);
6164 RDiff = getMinusSCEV(RA, LS);
6165 if (LDiff == RDiff)
6166 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6167 LDiff);
6168 }
6169 break;
6170 case ICmpInst::ICMP_NE:
6171 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6172 std::swap(TrueVal, FalseVal);
6173 [[fallthrough]];
6174 case ICmpInst::ICMP_EQ:
6175 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6177 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6178 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6179 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6180 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6181 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6182 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6183 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6184 return getAddExpr(getUMaxExpr(X, C), Y);
6185 }
6186 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6187 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6188 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6189 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6190 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6191 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6192 const SCEV *X = getSCEV(LHS);
 // Look through any zero-extensions wrapped around x.
6193 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6194 X = ZExt->getOperand();
6195 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6196 const SCEV *FalseValExpr = getSCEV(FalseVal);
6197 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6198 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6199 /*Sequential=*/true);
6200 }
6201 }
6202 break;
6203 default:
6204 break;
6205 }
6206
6207 return std::nullopt;
6208}
6209
// Models an i1 select whose condition and arms are given as SCEVs, where at
// least one arm is a SCEVConstant, as C + umin_seq(cond', x - C); cond' is the
// negated condition when the constant is the true arm. Returns std::nullopt
// when neither arm is a constant.
// NOTE(review): this listing omits source line 6211, which carries the
// function name and the leading parameters (SE and CondExpr are used below)
// -- confirm against the upstream file.
6210static std::optional<const SCEV *>
6212 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6213 assert(CondExpr->getType()->isIntegerTy(1) &&
6214 TrueExpr->getType() == FalseExpr->getType() &&
6215 TrueExpr->getType()->isIntegerTy(1) &&
6216 "Unexpected operands of a select.");
6217
6218 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6219 // --> C + (umin_seq cond, x - C)
6220 //
6221 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6222 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6223 // --> C + (umin_seq ~cond, x - C)
6224
6225 // FIXME: while we can't legally model the case where both of the hands
6226 // are fully variable, we only require that the *difference* is constant.
6227 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6228 return std::nullopt;
6229
 // When both arms are constants, the true arm is the one treated as C.
6230 const SCEV *X, *C;
6231 if (isa<SCEVConstant>(TrueExpr)) {
6232 CondExpr = SE->getNotSCEV(CondExpr);
6233 X = FalseExpr;
6234 C = TrueExpr;
6235 } else {
6236 X = TrueExpr;
6237 C = FalseExpr;
6238 }
6239 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6240 /*Sequential=*/true));
6241}
6242
// Value-level entry point: returns std::nullopt unless at least one hand is a
// ConstantInt; otherwise converts the condition and both hands to SCEVs and
// forwards them to the four-argument createNodeForSelectViaUMinSeq call below
// (presumably the SCEV-operand helper defined just above -- confirm).
// The ConstantInt check runs before any getSCEV call, so no SCEVs are built
// for selects that cannot be modelled.
// NOTE(review): listing line 6244 is absent from this extract; it presumably
// carries the function name and the `SE`, `Cond` and `TrueVal` parameters.
6243static std::optional<const SCEV *>
6245 Value *FalseVal) {
6246 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6247 return std::nullopt;
6248
6249 const auto *SECond = SE->getSCEV(Cond);
6250 const auto *SETrue = SE->getSCEV(TrueVal);
6251 const auto *SEFalse = SE->getSCEV(FalseVal);
6252 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6253}
6254
// Try to express the select-like value V (condition Cond, hands TrueVal and
// FalseVal) through createNodeForSelectViaUMinSeq. Only i1-typed values are
// attempted; in every other case, and when the helper yields no result, V is
// wrapped with getUnknown(V).
6255const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6256 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6257 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6258 assert(TrueVal->getType() == FalseVal->getType() &&
6259 V->getType() == TrueVal->getType() &&
6260 "Types of select hands and of the result must match.");
6261
6262 // For now, only deal with i1-typed `select`s.
6263 if (!V->getType()->isIntegerTy(1))
6264 return getUnknown(V);
6265
6266 if (std::optional<const SCEV *> S =
6267 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6268 return *S;
6269
// The helper declined (it returned std::nullopt); fall back to an opaque node.
6270 return getUnknown(V);
6271}
6272
// Build a SCEV for a select-like value V with condition Cond. Strategies are
// tried in order:
//   1. a ConstantInt condition simply picks the SCEV of the chosen hand;
//   2. if V is an Instruction and Cond is an ICmpInst, the icmp-specific
//      folding in createNodeForSelectOrPHIInstWithICmpInstCond is attempted;
//   3. otherwise (or if step 2 yields nothing) the umin_seq-based modelling in
//      createNodeForSelectOrPHIViaUMinSeq is used.
6273const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6274 Value *TrueVal,
6275 Value *FalseVal) {
6276 // Handle "constant" branch or select. This can occur for instance when a
6277 // loop pass transforms an inner loop and moves on to process the outer loop.
6278 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6279 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6280
6281 if (auto *I = dyn_cast<Instruction>(V)) {
6282 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6283 if (std::optional<const SCEV *> S =
6284 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6285 TrueVal, FalseVal))
6286 return *S;
6287 }
6288 }
6289
6290 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6291}
6292
6293/// Expand GEP instructions into add and multiply operations. This allows them
6294/// to be analyzed by regular SCEV code.
///
/// Each GEP index is converted with getSCEV and the collected list is handed
/// to getGEPExpr together with the GEP itself.
6295const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6296 assert(GEP->getSourceElementType()->isSized() &&
6297 "GEP source element type must be sized");
6298
// NOTE(review): listing line 6299 is absent from this extract; it presumably
// declares the `IndexExprs` container filled below -- restore from upstream.
6300 for (Value *Index : GEP->indices())
6301 IndexExprs.push_back(getSCEV(Index));
6302 return getGEPExpr(GEP, IndexExprs);
6303}
6304
// Compute, without consulting the cache, a constant that S is known to be a
// multiple of. Depending on the expression kind the answer is either an exact
// multiple (constants, nuw products, GCD over operands) or a power of two
// derived from the number of known trailing zero bits.
// NOTE(review): listing lines 6306, 6309, 6316, 6342 and 6377 are absent from
// this extract (gaps in the embedded numbering). They presumably hold the
// `BitWidth` declaration, the true arm of the ternary in GetShiftedByZeros,
// the start of the `Res = ...` GCD statement, the `TZ` declaration in the
// scSignExtend case, and one more case label -- restore from upstream.
6305APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
// Maps a trailing-zero count to a power-of-two multiple; counts >= BitWidth
// take the other ternary arm (on the missing line 6309).
6307 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6308 return TrailingZeros >= BitWidth
6310 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6311 };
6312 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6313 // The result is GCD of all operands results.
6314 APInt Res = getConstantMultiple(N->getOperand(0));
// The loop stops early once Res reaches 1, since it cannot get any smaller.
6315 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6317 Res, getConstantMultiple(N->getOperand(I)));
6318 return Res;
6319 };
6320
6321 switch (S->getSCEVType()) {
6322 case scConstant:
6323 return cast<SCEVConstant>(S)->getAPInt();
6324 case scPtrToInt:
6325 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6326 case scUDivExpr:
6327 case scVScale:
// Nothing better than the trivial multiple 1 is reported for these kinds.
6328 return APInt(BitWidth, 1);
6329 case scTruncate: {
6330 // Only multiples that are a power of 2 will hold after truncation.
6331 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6332 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6333 return GetShiftedByZeros(TZ);
6334 }
6335 case scZeroExtend: {
6336 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6337 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6338 }
6339 case scSignExtend: {
6340 // Only multiples that are a power of 2 will hold after sext.
6341 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6343 return GetShiftedByZeros(TZ);
6344 }
6345 case scMulExpr: {
6346 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6347 if (M->hasNoUnsignedWrap()) {
6348 // The result is the product of all operand results.
6349 APInt Res = getConstantMultiple(M->getOperand(0));
6350 for (const SCEV *Operand : M->operands().drop_front())
6351 Res = Res * getConstantMultiple(Operand);
6352 return Res;
6353 }
6354
6355 // If there are no wrap guarantees, find the trailing zeros, which is the
6356 // sum of trailing zeros for all its operands.
6357 uint32_t TZ = 0;
6358 for (const SCEV *Operand : M->operands())
6359 TZ += getMinTrailingZeros(Operand);
6360 return GetShiftedByZeros(TZ);
6361 }
6362 case scAddExpr:
6363 case scAddRecExpr: {
6364 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6365 if (N->hasNoUnsignedWrap())
6366 return GetGCDMultiple(N);
6367 // Find the trailing bits, which is the minimum of its operands.
6368 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6369 for (const SCEV *Operand : N->operands().drop_front())
6370 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6371 return GetShiftedByZeros(TZ);
6372 }
6373 case scUMaxExpr:
6374 case scSMaxExpr:
6375 case scUMinExpr:
6376 case scSMinExpr:
6378 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6379 case scUnknown: {
6380 // Ask ValueTracking for known bits.
6381 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6382 unsigned Known =
6383 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6384 .countMinTrailingZeros();
6385 return GetShiftedByZeros(Known);
6386 }
6387 case scCouldNotCompute:
6388 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6389 }
6390 llvm_unreachable("Unknown SCEV kind!");
6391}
6392
// Memoizing wrapper around getConstantMultipleImpl: returns the entry from
// ConstantMultipleCache for S if present, otherwise computes the value, stores
// it in the cache and returns the stored copy.
// NOTE(review): listing line 6393 (the function signature) is absent from this
// extract; callers in this file spell the name `getConstantMultiple(S)`.
6394 auto I = ConstantMultipleCache.find(S);
6395 if (I != ConstantMultipleCache.end())
6396 return I->second;
6397
6398 APInt Result = getConstantMultipleImpl(S);
// The assert below requires that computing Result did not itself insert an
// entry for S.
6399 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6400 assert(InsertPair.second && "Should insert a new key");
6401 return InsertPair.first->second;
6402}
6403
// Same as getConstantMultiple(S), except that a result of 0 is replaced by 1
// (of the same bit width), so the returned multiple is never zero.
// NOTE(review): listing line 6404 (the function signature) is absent from this
// extract; the name is presumably getNonZeroConstantMultiple, which is what
// getRangeRef calls before taking a remainder -- confirm against upstream.
6405 APInt Multiple = getConstantMultiple(S);
6406 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6407}
6408
// Returns the number of trailing zero bits of getConstantMultiple(S), clamped
// to the bit size of S's type.
// NOTE(review): listing line 6409 (the function signature) is absent from this
// extract; the name is presumably getMinTrailingZeros, as used by callers in
// this file -- confirm against upstream.
6410 return std::min(getConstantMultiple(S).countTrailingZeros(),
6411 (unsigned)getTypeSizeInBits(S->getType()));
6412}
6413
6414/// Helper method to assign a range to V from metadata present in the IR.
///
/// Sources are consulted in this order: `!range` metadata on an instruction,
/// the range returned by CallBase::getRange() for calls, and the range
/// returned by Argument::getRange() for function arguments. Returns
/// std::nullopt when none of them provides a range.
6415static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6416 if (Instruction *I = dyn_cast<Instruction>(V)) {
6417 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6418 return getConstantRangeFromMetadata(*MD);
6419 if (const auto *CB = dyn_cast<CallBase>(V))
6420 if (std::optional<ConstantRange> Range = CB->getRange())
6421 return Range;
6422 }
6423 if (auto *A = dyn_cast<Argument>(V))
6424 if (std::optional<ConstantRange> Range = A->getRange())
6425 return Range;
6426
6427 return std::nullopt;
6428}
6429
// Adds Flags to AddRec's no-wrap flags if they are not all set already, and in
// that case drops the entries cached for AddRec in UnsignedRanges,
// SignedRanges and ConstantMultipleCache so they are recomputed with the
// stronger flags.
// NOTE(review): listing line 6430 (the start of the signature, including the
// function name and the `AddRec` parameter) is absent from this extract;
// presumably this is ScalarEvolution::setNoWrapFlags -- confirm upstream.
6431 SCEV::NoWrapFlags Flags) {
6432 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6433 AddRec->setNoWrapFlags(Flags);
6434 UnsignedRanges.erase(AddRec);
6435 SignedRanges.erase(AddRec);
6436 ConstantMultipleCache.erase(AddRec);
6437 }
6438}
6439
// Compute a range for a SCEVUnknown whose underlying value is a PHI forming a
// simple shift recurrence (ashr / lshr / shl with the PHI as first operand).
// The bound combines the known bits of the start and step values with the
// loop's small constant max trip count. Whenever a precondition fails, the
// full set of the value's bit width is returned.
6440ConstantRange ScalarEvolution::
6441getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6442 const DataLayout &DL = getDataLayout();
6443
6444 unsigned BitWidth = getTypeSizeInBits(U->getType());
6445 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6446
6447 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6448 // use information about the trip count to improve our available range. Note
6449 // that the trip count independent cases are already handled by known bits.
6450 // WARNING: The definition of recurrence used here is subtly different than
6451 // the one used by AddRec (and thus most of this file). Step is allowed to
6452 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6453 // and other addrecs in the same loop (for non-affine addrecs). The code
6454 // below intentionally handles the case where step is not loop invariant.
6455 auto *P = dyn_cast<PHINode>(U->getValue());
6456 if (!P)
6457 return FullSet;
6458
6459 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6460 // even the values that are not available in these blocks may come from them,
6461 // and this leads to false-positive recurrence test.
6462 for (auto *Pred : predecessors(P->getParent()))
6463 if (!DT.isReachableFromEntry(Pred))
6464 return FullSet;
6465
6466 BinaryOperator *BO;
6467 Value *Start, *Step;
6468 if (!matchSimpleRecurrence(P, BO, Start, Step))
6469 return FullSet;
6470
6471 // If we found a recurrence in reachable code, we must be in a loop. Note
6472 // that BO might be in some subloop of L, and that's completely okay.
6473 auto *L = LI.getLoopFor(P->getParent());
6474 assert(L && L->getHeader() == P->getParent());
6475 if (!L->contains(BO->getParent()))
6476 // NOTE: This bailout should be an assert instead. However, asserting
6477 // the condition here exposes a case where LoopFusion is querying SCEV
6478 // with malformed loop information during the midst of the transform.
6479 // There doesn't appear to be an obvious fix, so for the moment bailout
6480 // until the caller issue can be fixed. PR49566 tracks the bug.
6481 return FullSet;
6482
6483 // TODO: Extend to other opcodes such as mul, and div
6484 switch (BO->getOpcode()) {
6485 default:
6486 return FullSet;
6487 case Instruction::AShr:
6488 case Instruction::LShr:
6489 case Instruction::Shl:
6490 break;
6491 };
6492
6493 if (BO->getOperand(0) != P)
6494 // TODO: Handle the power function forms some day.
6495 return FullSet;
6496
// Give up when no usable trip count is available (TC == 0) or when it is at
// least as large as the bit width.
6497 unsigned TC = getSmallConstantMaxTripCount(L);
6498 if (!TC || TC >= BitWidth)
6499 return FullSet;
6500
6501 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6502 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6503 assert(KnownStart.getBitWidth() == BitWidth &&
6504 KnownStep.getBitWidth() == BitWidth);
6505
6506 // Compute total shift amount, being careful of overflow and bitwidths.
// TotalShift = (largest possible step) * (TC - 1), computed with an overflow
// check.
6507 auto MaxShiftAmt = KnownStep.getMaxValue();
6508 APInt TCAP(BitWidth, TC-1);
6509 bool Overflow = false;
6510 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6511 if (Overflow)
6512 return FullSet;
6513
6514 switch (BO->getOpcode()) {
6515 default:
6516 llvm_unreachable("filtered out above");
6517 case Instruction::AShr: {
6518 // For each ashr, three cases:
6519 // shift = 0 => unchanged value
6520 // saturation => 0 or -1
6521 // other => a value closer to zero (of the same sign)
6522 // Thus, the end value is closer to zero than the start.
6523 auto KnownEnd = KnownBits::ashr(KnownStart,
6524 KnownBits::makeConstant(TotalShift));
6525 if (KnownStart.isNonNegative())
6526 // Analogous to lshr (simply not yet canonicalized)
6527 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6528 KnownStart.getMaxValue() + 1);
6529 if (KnownStart.isNegative())
6530 // End >=u Start && End <=s Start
6531 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6532 KnownEnd.getMaxValue() + 1);
// Sign of the start value is unknown: fall through to the FullSet return.
6533 break;
6534 }
6535 case Instruction::LShr: {
6536 // For each lshr, three cases:
6537 // shift = 0 => unchanged value
6538 // saturation => 0
6539 // other => a smaller positive number
6540 // Thus, the low end of the unsigned range is the last value produced.
6541 auto KnownEnd = KnownBits::lshr(KnownStart,
6542 KnownBits::makeConstant(TotalShift));
6543 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6544 KnownStart.getMaxValue() + 1);
6545 }
6546 case Instruction::Shl: {
6547 // Iff no bits are shifted out, value increases on every shift.
6548 auto KnownEnd = KnownBits::shl(KnownStart,
6549 KnownBits::makeConstant(TotalShift));
// Only usable when the total shift is smaller than the number of known
// leading zeros of the start value.
6550 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6551 return ConstantRange(KnownStart.getMinValue(),
6552 KnownEnd.getMaxValue() + 1);
6553 break;
6554 }
6555 };
6556 return FullSet;
6557}
6558
// Iterative driver for range computation. It collects S and its transitive
// operands (including the incoming values of PHI-backed SCEVUnknowns) that do
// not yet have a cached range into a worklist, computes ranges for the
// collected items in reverse order through getRangeRef, and finally returns
// getRangeRef(S, SignHint, 0).
// NOTE(review): listing lines 6562, 6565, 6566 and 6594 are absent from this
// extract. They presumably hold the start of the `Cache` declaration, the
// `WorkList` and `Seen` declarations, and one more case label in the switch
// below -- restore from upstream.
6559const ConstantRange &
6560ScalarEvolution::getRangeRefIter(const SCEV *S,
6561 ScalarEvolution::RangeSignHint SignHint) {
6563 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6564 : SignedRanges;
6567
6568 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6569 // SCEVUnknown PHI node.
// Expressions already seen or already present in Cache are not queued again.
6570 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6571 if (!Seen.insert(Expr).second)
6572 return;
6573 if (Cache.contains(Expr))
6574 return;
6575 switch (Expr->getSCEVType()) {
6576 case scUnknown:
6577 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6578 break;
6579 [[fallthrough]];
6580 case scConstant:
6581 case scVScale:
6582 case scTruncate:
6583 case scZeroExtend:
6584 case scSignExtend:
6585 case scPtrToInt:
6586 case scAddExpr:
6587 case scMulExpr:
6588 case scUDivExpr:
6589 case scAddRecExpr:
6590 case scUMaxExpr:
6591 case scSMaxExpr:
6592 case scUMinExpr:
6593 case scSMinExpr:
6595 WorkList.push_back(Expr);
6596 break;
6597 case scCouldNotCompute:
6598 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6599 }
6600 };
6601 AddToWorklist(S);
6602
6603 // Build worklist by queuing operands of N-ary expressions and phi nodes.
// WorkList may grow while it is being scanned, hence the index-based loop
// that re-reads size() on every iteration.
6604 for (unsigned I = 0; I != WorkList.size(); ++I) {
6605 const SCEV *P = WorkList[I];
6606 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6607 // If it is not a `SCEVUnknown`, just recurse into operands.
6608 if (!UnknownS) {
6609 for (const SCEV *Op : P->operands())
6610 AddToWorklist(Op);
6611 continue;
6612 }
6613 // `SCEVUnknown`'s require special treatment.
// A PHI that is already in PendingPhiRangesIter is not expanded again.
6614 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6615 if (!PendingPhiRangesIter.insert(P).second)
6616 continue;
6617 for (auto &Op : reverse(P->operands()))
6618 AddToWorklist(getSCEV(Op));
6619 }
6620 }
6621
6622 if (!WorkList.empty()) {
6623 // Use getRangeRef to compute ranges for items in the worklist in reverse
6624 // order. This will force ranges for earlier operands to be computed before
6625 // their users in most cases.
// drop_begin skips the first worklist entry; S itself is handled by the final
// getRangeRef call at the end of the function.
6626 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6627 getRangeRef(P, SignHint);
6628
6629 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6630 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6631 PendingPhiRangesIter.erase(P);
6632 }
6633 }
6634
6635 return getRangeRef(S, SignHint, 0);
6636}
6637
6638/// Determine the range for a particular SCEV. If SignHint is
6639/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6640/// with a "cleaner" unsigned (resp. signed) representation.
///
/// Results are looked up in, and stored through setRange into, a per-hint
/// cache. The computation starts from a conservative range derived from the
/// known constant multiple / trailing zeros of S and narrows it by
/// intersecting with whatever each expression kind allows to be proven.
///
/// NOTE(review): this extract has gaps in its embedded numbering. Listing
/// lines 6643, 6646, 6648, 6651, 6660, 6674, 6678, 6681, 6773, 6774, 6778,
/// 6787, 6813, 6815, 6834, 6843, 6909 and 6963 are absent, so several
/// declarations and statement heads used below (e.g. `Cache`, `RangeType`,
/// `I`, `TZ`, `ID`, `Disallowed`) are not visible here -- restore from
/// upstream before editing the logic.
6641const ConstantRange &ScalarEvolution::getRangeRef(
6642 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6644 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6645 : SignedRanges;
6647 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6649
6650 // See if we've computed this range already.
6652 if (I != Cache.end())
6653 return I->second;
6654
6655 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6656 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6657
6658 // Switch to iteratively computing the range for S, if it is part of a deeply
6659 // nested expression.
6661 return getRangeRefIter(S, SignHint);
6662
6663 unsigned BitWidth = getTypeSizeInBits(S->getType());
6664 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6665 using OBO = OverflowingBinaryOperator;
6666
6667 // If the value has known zeros, the maximum value will have those known zeros
6668 // as well.
6669 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6670 APInt Multiple = getNonZeroConstantMultiple(S);
6671 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6672 if (!Remainder.isZero())
6673 ConservativeResult =
6675 APInt::getMaxValue(BitWidth) - Remainder + 1);
6676 }
6677 else {
6679 if (TZ != 0) {
6680 ConservativeResult = ConstantRange(
6682 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6683 }
6684 }
6685
6686 switch (S->getSCEVType()) {
6687 case scConstant:
6688 llvm_unreachable("Already handled above.");
6689 case scVScale:
6690 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6691 case scTruncate: {
6692 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6693 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6694 return setRange(
6695 Trunc, SignHint,
6696 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6697 }
6698 case scZeroExtend: {
6699 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6700 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6701 return setRange(
6702 ZExt, SignHint,
6703 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6704 }
6705 case scSignExtend: {
6706 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6707 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6708 return setRange(
6709 SExt, SignHint,
6710 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6711 }
6712 case scPtrToInt: {
// The operand's range is used as-is; it is not intersected with
// ConservativeResult.
6713 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6714 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6715 return setRange(PtrToInt, SignHint, X);
6716 }
6717 case scAddExpr: {
6718 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6719 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
// Forward the expression's nsw/nuw flags to the range addition.
6720 unsigned WrapType = OBO::AnyWrap;
6721 if (Add->hasNoSignedWrap())
6722 WrapType |= OBO::NoSignedWrap;
6723 if (Add->hasNoUnsignedWrap())
6724 WrapType |= OBO::NoUnsignedWrap;
6725 for (const SCEV *Op : drop_begin(Add->operands()))
6726 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6727 RangeType);
6728 return setRange(Add, SignHint,
6729 ConservativeResult.intersectWith(X, RangeType));
6730 }
6731 case scMulExpr: {
6732 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6733 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6734 for (const SCEV *Op : drop_begin(Mul->operands()))
6735 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6736 return setRange(Mul, SignHint,
6737 ConservativeResult.intersectWith(X, RangeType));
6738 }
6739 case scUDivExpr: {
6740 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6741 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6742 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6743 return setRange(UDiv, SignHint,
6744 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6745 }
6746 case scAddRecExpr: {
6747 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6748 // If there's no unsigned wrap, the value will never be less than its
6749 // initial value.
6750 if (AddRec->hasNoUnsignedWrap()) {
6751 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6752 if (!UnsignedMinValue.isZero())
6753 ConservativeResult = ConservativeResult.intersectWith(
6754 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6755 }
6756
6757 // If there's no signed wrap, and all the operands except initial value have
6758 // the same sign or zero, the value won't ever be:
6759 // 1: smaller than initial value if operands are non negative,
6760 // 2: bigger than initial value if operands are non positive.
6761 // For both cases, value can not cross signed min/max boundary.
6762 if (AddRec->hasNoSignedWrap()) {
6763 bool AllNonNeg = true;
6764 bool AllNonPos = true;
6765 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6766 if (!isKnownNonNegative(AddRec->getOperand(i)))
6767 AllNonNeg = false;
6768 if (!isKnownNonPositive(AddRec->getOperand(i)))
6769 AllNonPos = false;
6770 }
6771 if (AllNonNeg)
6772 ConservativeResult = ConservativeResult.intersectWith(
6775 RangeType);
6776 else if (AllNonPos)
6777 ConservativeResult = ConservativeResult.intersectWith(
6779 getSignedRangeMax(AddRec->getStart()) +
6780 1),
6781 RangeType);
6782 }
6783
6784 // TODO: non-affine addrec
6785 if (AddRec->isAffine()) {
6786 const SCEV *MaxBEScev =
6788 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6789 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6790
6791 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6792 // MaxBECount's active bits are all <= AddRec's bit width.
6793 if (MaxBECount.getBitWidth() > BitWidth &&
6794 MaxBECount.getActiveBits() <= BitWidth)
6795 MaxBECount = MaxBECount.trunc(BitWidth);
6796 else if (MaxBECount.getBitWidth() < BitWidth)
6797 MaxBECount = MaxBECount.zext(BitWidth);
6798
// If the widths still differ (the count was too wide to truncate), the
// affine-range refinements below are skipped.
6799 if (MaxBECount.getBitWidth() == BitWidth) {
6800 auto RangeFromAffine = getRangeForAffineAR(
6801 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6802 ConservativeResult =
6803 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6804
6805 auto RangeFromFactoring = getRangeViaFactoring(
6806 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6807 ConservativeResult =
6808 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6809 }
6810 }
6811
6812 // Now try symbolic BE count and more powerful methods.
6814 const SCEV *SymbolicMaxBECount =
6816 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6817 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6818 AddRec->hasNoSelfWrap()) {
6819 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6820 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6821 ConservativeResult =
6822 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6823 }
6824 }
6825 }
6826
6827 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6828 }
6829 case scUMaxExpr:
6830 case scSMaxExpr:
6831 case scUMinExpr:
6832 case scSMinExpr:
6833 case scSequentialUMinExpr: {
// Pick the ConstantRange intrinsic matching the min/max kind, then fold the
// operand ranges pairwise with it.
6835 switch (S->getSCEVType()) {
6836 case scUMaxExpr:
6837 ID = Intrinsic::umax;
6838 break;
6839 case scSMaxExpr:
6840 ID = Intrinsic::smax;
6841 break;
6842 case scUMinExpr:
6844 ID = Intrinsic::umin;
6845 break;
6846 case scSMinExpr:
6847 ID = Intrinsic::smin;
6848 break;
6849 default:
6850 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6851 }
6852
6853 const auto *NAry = cast<SCEVNAryExpr>(S);
6854 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6855 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6856 X = X.intrinsic(
6857 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6858 return setRange(S, SignHint,
6859 ConservativeResult.intersectWith(X, RangeType));
6860 }
6861 case scUnknown: {
6862 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6863 Value *V = U->getValue();
6864
6865 // Check if the IR explicitly contains !range metadata.
6866 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6867 if (MDRange)
6868 ConservativeResult =
6869 ConservativeResult.intersectWith(*MDRange, RangeType);
6870
6871 // Use facts about recurrences in the underlying IR. Note that add
6872 // recurrences are AddRecExprs and thus don't hit this path. This
6873 // primarily handles shift recurrences.
// NOTE(review): unlike the other intersections in this function, this call
// does not pass RangeType -- confirm that this is intentional.
6874 auto CR = getRangeForUnknownRecurrence(U);
6875 ConservativeResult = ConservativeResult.intersectWith(CR);
6876
6877 // See if ValueTracking can give us a useful range.
6878 const DataLayout &DL = getDataLayout();
6879 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6880 if (Known.getBitWidth() != BitWidth)
6881 Known = Known.zextOrTrunc(BitWidth);
6882
6883 // ValueTracking may be able to compute a tighter result for the number of
6884 // sign bits than for the value of those sign bits.
6885 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6886 if (U->getType()->isPointerTy()) {
6887 // If the pointer size is larger than the index size type, this can cause
6888 // NS to be larger than BitWidth. So compensate for this.
6889 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6890 int ptrIdxDiff = ptrSize - BitWidth;
6891 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6892 NS -= ptrIdxDiff;
6893 }
6894
6895 if (NS > 1) {
6896 // If we know any of the sign bits, we know all of the sign bits.
6897 if (!Known.Zero.getHiBits(NS).isZero())
6898 Known.Zero.setHighBits(NS);
6899 if (!Known.One.getHiBits(NS).isZero())
6900 Known.One.setHighBits(NS);
6901 }
6902
// Skip the known-bits range when min == max + 1, i.e. when the bounds below
// would wrap around to cover every value.
6903 if (Known.getMinValue() != Known.getMaxValue() + 1)
6904 ConservativeResult = ConservativeResult.intersectWith(
6905 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6906 RangeType);
6907 if (NS > 1)
6908 ConservativeResult = ConservativeResult.intersectWith(
6910 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6911 RangeType);
6912
6913 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6914 // Strengthen the range if the underlying IR value is a
6915 // global/alloca/heap allocation using the size of the object.
6916 bool CanBeNull, CanBeFreed;
6917 uint64_t DerefBytes =
6918 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6919 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6920 // The highest address the object can start is DerefBytes bytes before
6921 // the end (unsigned max value). If this value is not a multiple of the
6922 // alignment, the last possible start value is the next lowest multiple
6923 // of the alignment. Note: The computations below cannot overflow,
6924 // because if they would there's no possible start address for the
6925 // object.
6926 APInt MaxVal =
6927 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6928 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6929 uint64_t Rem = MaxVal.urem(Align);
6930 MaxVal -= APInt(BitWidth, Rem);
// A pointer known to be non-null cannot be lower than its alignment.
6931 APInt MinVal = APInt::getZero(BitWidth);
6932 if (llvm::isKnownNonZero(V, DL))
6933 MinVal = Align;
6934 ConservativeResult = ConservativeResult.intersectWith(
6935 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6936 }
6937 }
6938
6939 // A range of Phi is a subset of union of all ranges of its input.
6940 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6941 // Make sure that we do not run over cycled Phis.
6942 if (PendingPhiRanges.insert(Phi).second) {
6943 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6944
6945 for (const auto &Op : Phi->operands()) {
6946 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6947 RangeFromOps = RangeFromOps.unionWith(OpRange);
6948 // No point to continue if we already have a full set.
6949 if (RangeFromOps.isFullSet())
6950 break;
6951 }
6952 ConservativeResult =
6953 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6954 bool Erased = PendingPhiRanges.erase(Phi);
6955 assert(Erased && "Failed to erase Phi properly?");
6956 (void)Erased;
6957 }
6958 }
6959
6960 // vscale can't be equal to zero
6961 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6962 if (II->getIntrinsicID() == Intrinsic::vscale) {
6964 ConservativeResult = ConservativeResult.difference(Disallowed);
6965 }
6966
6967 return setRange(U, SignHint, std::move(ConservativeResult));
6968 }
6969 case scCouldNotCompute:
6970 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6971 }
6972
6973 return setRange(S, SignHint, std::move(ConservativeResult));
6974}
6975
6976// Given a StartRange, Step and MaxBECount for an expression compute a range of
6977// values that the expression can take. Initially, the expression has a value
6978// from StartRange and then is changed by Step up to MaxBECount times. Signed
6979// argument defines if we treat Step as signed or unsigned.
//
// The full range is returned whenever StartRange is already full, when
// Step * MaxBECount would exceed the unsigned span of the bit width, or when
// the moved boundary lands back inside StartRange.
// NOTE(review): listing line 6980 (the start of the signature, including the
// function name and the `Step` parameter) is absent from this extract;
// presumably this is the static helper getRangeForAffineARHelper, with `Step`
// taken by value since it is reassigned below -- confirm against upstream.
6981 const ConstantRange &StartRange,
6982 const APInt &MaxBECount,
6983 bool Signed) {
6984 unsigned BitWidth = Step.getBitWidth();
6985 assert(BitWidth == StartRange.getBitWidth() &&
6986 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6987 // If either Step or MaxBECount is 0, then the expression won't change, and we
6988 // just need to return the initial range.
6989 if (Step == 0 || MaxBECount == 0)
6990 return StartRange;
6991
6992 // If we don't know anything about the initial value (i.e. StartRange is
6993 // FullRange), then we don't know anything about the final range either.
6994 // Return FullRange.
6995 if (StartRange.isFullSet())
6996 return ConstantRange::getFull(BitWidth);
6997
6998 // If Step is signed and negative, then we use its absolute value, but we also
6999 // note that we're moving in the opposite direction.
7000 bool Descending = Signed && Step.isNegative();
7001
7002 if (Signed)
7003 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7004 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7005 // These equations hold true due to the well-defined wrap-around behavior of
7006 // APInt.
7007 Step = Step.abs();
7008
7009 // Check if Offset is more than full span of BitWidth. If it is, the
7010 // expression is guaranteed to overflow.
7011 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7012 return ConstantRange::getFull(BitWidth);
7013
7014 // Offset is by how much the expression can change. Checks above guarantee no
7015 // overflow here.
7016 APInt Offset = Step * MaxBECount;
7017
7018 // Minimum value of the final range will match the minimal value of StartRange
7019 // if the expression is increasing and will be decreased by Offset otherwise.
7020 // Maximum value of the final range will match the maximal value of StartRange
7021 // if the expression is decreasing and will be increased by Offset otherwise.
7022 APInt StartLower = StartRange.getLower();
7023 APInt StartUpper = StartRange.getUpper() - 1;
7024 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7025 : (StartUpper + std::move(Offset));
7026
7027 // It's possible that the new minimum/maximum value will fall into the initial
7028 // range (due to wrap around). This means that the expression can take any
7029 // value in this bitwidth, and we have to return full range.
7030 if (StartRange.contains(MovedBoundary))
7031 return ConstantRange::getFull(BitWidth);
7032
7033 APInt NewLower =
7034 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7035 APInt NewUpper =
7036 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7037 NewUpper += 1;
7038
7039 // No overflow detected. Ascending: return [StartLower, StartUpper + Offset +
// 1). Descending: return [StartLower - Offset, StartUpper + 1).
7040 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7041}
7042
// Compute a range for an affine recurrence that starts at Start, advances by
// Step each iteration and takes the backedge at most MaxBECount times. The
// step is considered once as a signed and once as an unsigned quantity, and
// the two results are combined (see the comments below).
// NOTE(review): listing lines 7057, 7059, 7064 and 7069 are absent from this
// extract. They presumably hold the heads of the two signed helper calls, the
// head of the unsigned helper call, and the final return statement -- restore
// from upstream.
7043ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7044 const SCEV *Step,
7045 const APInt &MaxBECount) {
7046 assert(getTypeSizeInBits(Start->getType()) ==
7047 getTypeSizeInBits(Step->getType()) &&
7048 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7049 "mismatched bit widths");
7050
7051 // First, consider step signed.
7052 ConstantRange StartSRange = getSignedRange(Start);
7053 ConstantRange StepSRange = getSignedRange(Step);
7054
7055 // If Step can be both positive and negative, we need to find ranges for the
7056 // maximum absolute step values in both directions and union them.
7058 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7060 StartSRange, MaxBECount,
7061 /* Signed = */ true));
7062
7063 // Next, consider step unsigned.
7065 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7066 /* Signed = */ false);
7067
7068 // Finally, intersect signed and unsigned ranges.
7070}
7071
7072ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7073 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7074 ScalarEvolution::RangeSignHint SignHint) {
7075 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7076 assert(AddRec->hasNoSelfWrap() &&
7077 "This only works for non-self-wrapping AddRecs!");
7078 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7079 const SCEV *Step = AddRec->getStepRecurrence(*this);
7080 // Only deal with constant step to save compile time.
7081 if (!isa<SCEVConstant>(Step))
7082 return ConstantRange::getFull(BitWidth);
7083 // Let's make sure that we can prove that we do not self-wrap during
7084 // MaxBECount iterations. We need this because MaxBECount is a maximum
7085 // iteration count estimate, and we might infer nw from some exit for which we
7086 // do not know max exit count (or any other side reasoning).
7087 // TODO: Turn into assert at some point.
7088 if (getTypeSizeInBits(MaxBECount->getType()) >
7089 getTypeSizeInBits(AddRec->getType()))
7090 return ConstantRange::getFull(BitWidth);
7091 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7092 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7093 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7094 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7095 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7096 MaxItersWithoutWrap))
7097 return ConstantRange::getFull(BitWidth);
7098
7099 ICmpInst::Predicate LEPred =
7101 ICmpInst::Predicate GEPred =
7103 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7104
7105 // We know that there is no self-wrap. Let's take Start and End values and
7106 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7107 // the iteration. They either lie inside the range [Min(Start, End),
7108 // Max(Start, End)] or outside it:
7109 //
7110 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7111 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7112 //
7113 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7114 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7115 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7116 // Start <= End and step is positive, or Start >= End and step is negative.
7117 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7118 ConstantRange StartRange = getRangeRef(Start, SignHint);
7119 ConstantRange EndRange = getRangeRef(End, SignHint);
7120 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7121 // If they already cover full iteration space, we will know nothing useful
7122 // even if we prove what we want to prove.
7123 if (RangeBetween.isFullSet())
7124 return RangeBetween;
7125 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7126 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7127 : RangeBetween.isWrappedSet();
7128 if (IsWrappedSet)
7129 return ConstantRange::getFull(BitWidth);
7130
7131 if (isKnownPositive(Step) &&
7132 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7133 return RangeBetween;
7134 if (isKnownNegative(Step) &&
7135 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7136 return RangeBetween;
7137 return ConstantRange::getFull(BitWidth);
7138}
7139
7140ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7141 const SCEV *Step,
7142 const APInt &MaxBECount) {
7143 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7144 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7145
7146 unsigned BitWidth = MaxBECount.getBitWidth();
7147 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7148 getTypeSizeInBits(Step->getType()) == BitWidth &&
7149 "mismatched bit widths");
7150
7151 struct SelectPattern {
7152 Value *Condition = nullptr;
7153 APInt TrueValue;
7154 APInt FalseValue;
7155
7156 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7157 const SCEV *S) {
7158 std::optional<unsigned> CastOp;
7159 APInt Offset(BitWidth, 0);
7160
7162 "Should be!");
7163
7164 // Peel off a constant offset:
7165 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7166 // In the future we could consider being smarter here and handle
7167 // {Start+Step,+,Step} too.
7168 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7169 return;
7170
7171 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7172 S = SA->getOperand(1);
7173 }
7174
7175 // Peel off a cast operation
7176 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7177 CastOp = SCast->getSCEVType();
7178 S = SCast->getOperand();
7179 }
7180
7181 using namespace llvm::PatternMatch;
7182
7183 auto *SU = dyn_cast<SCEVUnknown>(S);
7184 const APInt *TrueVal, *FalseVal;
7185 if (!SU ||
7186 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7187 m_APInt(FalseVal)))) {
7188 Condition = nullptr;
7189 return;
7190 }
7191
7192 TrueValue = *TrueVal;
7193 FalseValue = *FalseVal;
7194
7195 // Re-apply the cast we peeled off earlier
7196 if (CastOp)
7197 switch (*CastOp) {
7198 default:
7199 llvm_unreachable("Unknown SCEV cast type!");
7200
7201 case scTruncate:
7202 TrueValue = TrueValue.trunc(BitWidth);
7203 FalseValue = FalseValue.trunc(BitWidth);
7204 break;
7205 case scZeroExtend:
7206 TrueValue = TrueValue.zext(BitWidth);
7207 FalseValue = FalseValue.zext(BitWidth);
7208 break;
7209 case scSignExtend:
7210 TrueValue = TrueValue.sext(BitWidth);
7211 FalseValue = FalseValue.sext(BitWidth);
7212 break;
7213 }
7214
7215 // Re-apply the constant offset we peeled off earlier
7216 TrueValue += Offset;
7217 FalseValue += Offset;
7218 }
7219
7220 bool isRecognized() { return Condition != nullptr; }
7221 };
7222
7223 SelectPattern StartPattern(*this, BitWidth, Start);
7224 if (!StartPattern.isRecognized())
7225 return ConstantRange::getFull(BitWidth);
7226
7227 SelectPattern StepPattern(*this, BitWidth, Step);
7228 if (!StepPattern.isRecognized())
7229 return ConstantRange::getFull(BitWidth);
7230
7231 if (StartPattern.Condition != StepPattern.Condition) {
7232 // We don't handle this case today; but we could, by considering four
7233 // possibilities below instead of two. I'm not sure if there are cases where
7234 // that will help over what getRange already does, though.
7235 return ConstantRange::getFull(BitWidth);
7236 }
7237
7238 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7239 // construct arbitrary general SCEV expressions here. This function is called
7240 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7241 // say) can end up caching a suboptimal value.
7242
7243 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7244 // C2352 and C2512 (otherwise it isn't needed).
7245
7246 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7247 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7248 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7249 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7250
7251 ConstantRange TrueRange =
7252 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7253 ConstantRange FalseRange =
7254 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7255
7256 return TrueRange.unionWith(FalseRange);
7257}
7258
7259SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7260 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7261 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7262
7263 // Return early if there are no flags to propagate to the SCEV.
7265 if (BinOp->hasNoUnsignedWrap())
7267 if (BinOp->hasNoSignedWrap())
7269 if (Flags == SCEV::FlagAnyWrap)
7270 return SCEV::FlagAnyWrap;
7271
7272 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7273}
7274
7275const Instruction *
7276ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7277 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7278 return &*AddRec->getLoop()->getHeader()->begin();
7279 if (auto *U = dyn_cast<SCEVUnknown>(S))
7280 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7281 return I;
7282 return nullptr;
7283}
7284
7285const Instruction *
7286ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7287 bool &Precise) {
7288 Precise = true;
7289 // Do a bounded search of the def relation of the requested SCEVs.
7292 auto pushOp = [&](const SCEV *S) {
7293 if (!Visited.insert(S).second)
7294 return;
7295 // Threshold of 30 here is arbitrary.
7296 if (Visited.size() > 30) {
7297 Precise = false;
7298 return;
7299 }
7300 Worklist.push_back(S);
7301 };
7302
7303 for (const auto *S : Ops)
7304 pushOp(S);
7305
7306 const Instruction *Bound = nullptr;
7307 while (!Worklist.empty()) {
7308 auto *S = Worklist.pop_back_val();
7309 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7310 if (!Bound || DT.dominates(Bound, DefI))
7311 Bound = DefI;
7312 } else {
7313 for (const auto *Op : S->operands())
7314 pushOp(Op);
7315 }
7316 }
7317 return Bound ? Bound : &*F.getEntryBlock().begin();
7318}
7319
7320const Instruction *
7321ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7322 bool Discard;
7323 return getDefiningScopeBound(Ops, Discard);
7324}
7325
7326bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7327 const Instruction *B) {
7328 if (A->getParent() == B->getParent() &&
7330 B->getIterator()))
7331 return true;
7332
7333 auto *BLoop = LI.getLoopFor(B->getParent());
7334 if (BLoop && BLoop->getHeader() == B->getParent() &&
7335 BLoop->getLoopPreheader() == A->getParent() &&
7337 A->getParent()->end()) &&
7338 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7339 B->getIterator()))
7340 return true;
7341 return false;
7342}
7343
7344bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7345 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7346 visitAll(Op, PC);
7347 return PC.MaybePoison.empty();
7348}
7349
7350bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7351 return !SCEVExprContains(Op, [this](const SCEV *S) {
7352 auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
7353 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7354 // is a non-zero constant, we have to assume the UDiv may be UB.
7355 return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
7356 !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
7357 });
7358}
7359
7360bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7361 // Only proceed if we can prove that I does not yield poison.
7363 return false;
7364
7365 // At this point we know that if I is executed, then it does not wrap
7366 // according to at least one of NSW or NUW. If I is not executed, then we do
7367 // not know if the calculation that I represents would wrap. Multiple
7368 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7369 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7370 // derived from other instructions that map to the same SCEV. We cannot make
7371 // that guarantee for cases where I is not executed. So we need to find a
7372 // upper bound on the defining scope for the SCEV, and prove that I is
7373 // executed every time we enter that scope. When the bounding scope is a
7374 // loop (the common case), this is equivalent to proving I executes on every
7375 // iteration of that loop.
7377 for (const Use &Op : I->operands()) {
7378 // I could be an extractvalue from a call to an overflow intrinsic.
7379 // TODO: We can do better here in some cases.
7380 if (isSCEVable(Op->getType()))
7381 SCEVOps.push_back(getSCEV(Op));
7382 }
7383 auto *DefI = getDefiningScopeBound(SCEVOps);
7384 return isGuaranteedToTransferExecutionTo(DefI, I);
7385}
7386
7387bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7388 // If we know that \c I can never be poison period, then that's enough.
7389 if (isSCEVExprNeverPoison(I))
7390 return true;
7391
7392 // If the loop only has one exit, then we know that, if the loop is entered,
7393 // any instruction dominating that exit will be executed. If any such
7394 // instruction would result in UB, the addrec cannot be poison.
7395 //
7396 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7397 // also handles uses outside the loop header (they just need to dominate the
7398 // single exit).
7399
7400 auto *ExitingBB = L->getExitingBlock();
7401 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7402 return false;
7403
7406
7407 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7408 // things that are known to be poison under that assumption go on the
7409 // Worklist.
7410 KnownPoison.insert(I);
7411 Worklist.push_back(I);
7412
7413 while (!Worklist.empty()) {
7414 const Instruction *Poison = Worklist.pop_back_val();
7415
7416 for (const Use &U : Poison->uses()) {
7417 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7418 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7419 DT.dominates(PoisonUser->getParent(), ExitingBB))
7420 return true;
7421
7422 if (propagatesPoison(U) && L->contains(PoisonUser))
7423 if (KnownPoison.insert(PoisonUser).second)
7424 Worklist.push_back(PoisonUser);
7425 }
7426 }
7427
7428 return false;
7429}
7430
7431ScalarEvolution::LoopProperties
7432ScalarEvolution::getLoopProperties(const Loop *L) {
7433 using LoopProperties = ScalarEvolution::LoopProperties;
7434
7435 auto Itr = LoopPropertiesCache.find(L);
7436 if (Itr == LoopPropertiesCache.end()) {
7437 auto HasSideEffects = [](Instruction *I) {
7438 if (auto *SI = dyn_cast<StoreInst>(I))
7439 return !SI->isSimple();
7440
7441 return I->mayThrow() || I->mayWriteToMemory();
7442 };
7443
7444 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7445 /*HasNoSideEffects*/ true};
7446
7447 for (auto *BB : L->getBlocks())
7448 for (auto &I : *BB) {
7450 LP.HasNoAbnormalExits = false;
7451 if (HasSideEffects(&I))
7452 LP.HasNoSideEffects = false;
7453 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7454 break; // We're already as pessimistic as we can get.
7455 }
7456
7457 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7458 assert(InsertPair.second && "We just checked!");
7459 Itr = InsertPair.first;
7460 }
7461
7462 return Itr->second;
7463}
7464
7466 // A mustprogress loop without side effects must be finite.
7467 // TODO: The check used here is very conservative. It's only *specific*
7468 // side effects which are well defined in infinite loops.
7469 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7470}
7471
7472const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7473 // Worklist item with a Value and a bool indicating whether all operands have
7474 // been visited already.
7477
7478 Stack.emplace_back(V, true);
7479 Stack.emplace_back(V, false);
7480 while (!Stack.empty()) {
7481 auto E = Stack.pop_back_val();
7482 Value *CurV = E.getPointer();
7483
7484 if (getExistingSCEV(CurV))
7485 continue;
7486
7488 const SCEV *CreatedSCEV = nullptr;
7489 // If all operands have been visited already, create the SCEV.
7490 if (E.getInt()) {
7491 CreatedSCEV = createSCEV(CurV);
7492 } else {
7493 // Otherwise get the operands we need to create SCEV's for before creating
7494 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7495 // just use it.
7496 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7497 }
7498
7499 if (CreatedSCEV) {
7500 insertValueToMap(CurV, CreatedSCEV);
7501 } else {
7502 // Queue CurV for SCEV creation, followed by its's operands which need to
7503 // be constructed first.
7504 Stack.emplace_back(CurV, true);
7505 for (Value *Op : Ops)
7506 Stack.emplace_back(Op, false);
7507 }
7508 }
7509
7510 return getExistingSCEV(V);
7511}
7512
7513const SCEV *
7514ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7515 if (!isSCEVable(V->getType()))
7516 return getUnknown(V);
7517
7518 if (Instruction *I = dyn_cast<Instruction>(V)) {
7519 // Don't attempt to analyze instructions in blocks that aren't
7520 // reachable. Such instructions don't matter, and they aren't required
7521 // to obey basic rules for definitions dominating uses which this
7522 // analysis depends on.
7523 if (!DT.isReachableFromEntry(I->getParent()))
7524 return getUnknown(PoisonValue::get(V->getType()));
7525 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7526 return getConstant(CI);
7527 else if (isa<GlobalAlias>(V))
7528 return getUnknown(V);
7529 else if (!isa<ConstantExpr>(V))
7530 return getUnknown(V);
7531
7532 Operator *U = cast<Operator>(V);
7533 if (auto BO =
7534 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7535 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7536 switch (BO->Opcode) {
7537 case Instruction::Add:
7538 case Instruction::Mul: {
7539 // For additions and multiplications, traverse add/mul chains for which we
7540 // can potentially create a single SCEV, to reduce the number of
7541 // get{Add,Mul}Expr calls.
7542 do {
7543 if (BO->Op) {
7544 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7545 Ops.push_back(BO->Op);
7546 break;
7547 }
7548 }
7549 Ops.push_back(BO->RHS);
7550 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7551 dyn_cast<Instruction>(V));
7552 if (!NewBO ||
7553 (BO->Opcode == Instruction::Add &&
7554 (NewBO->Opcode != Instruction::Add &&
7555 NewBO->Opcode != Instruction::Sub)) ||
7556 (BO->Opcode == Instruction::Mul &&
7557 NewBO->Opcode != Instruction::Mul)) {
7558 Ops.push_back(BO->LHS);
7559 break;
7560 }
7561 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7562 // requires a SCEV for the LHS.
7563 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7564 auto *I = dyn_cast<Instruction>(BO->Op);
7565 if (I && programUndefinedIfPoison(I)) {
7566 Ops.push_back(BO->LHS);
7567 break;
7568 }
7569 }
7570 BO = NewBO;
7571 } while (true);
7572 return nullptr;
7573 }
7574 case Instruction::Sub:
7575 case Instruction::UDiv:
7576 case Instruction::URem:
7577 break;
7578 case Instruction::AShr:
7579 case Instruction::Shl:
7580 case Instruction::Xor:
7581 if (!IsConstArg)
7582 return nullptr;
7583 break;
7584 case Instruction::And:
7585 case Instruction::Or:
7586 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7587 return nullptr;
7588 break;
7589 case Instruction::LShr:
7590 return getUnknown(V);
7591 default:
7592 llvm_unreachable("Unhandled binop");
7593 break;
7594 }
7595
7596 Ops.push_back(BO->LHS);
7597 Ops.push_back(BO->RHS);
7598 return nullptr;
7599 }
7600
7601 switch (U->getOpcode()) {
7602 case Instruction::Trunc:
7603 case Instruction::ZExt:
7604 case Instruction::SExt:
7605 case Instruction::PtrToInt:
7606 Ops.push_back(U->getOperand(0));
7607 return nullptr;
7608
7609 case Instruction::BitCast:
7610 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7611 Ops.push_back(U->getOperand(0));
7612 return nullptr;
7613 }
7614 return getUnknown(V);
7615
7616 case Instruction::SDiv:
7617 case Instruction::SRem:
7618 Ops.push_back(U->getOperand(0));
7619 Ops.push_back(U->getOperand(1));
7620 return nullptr;
7621
7622 case Instruction::GetElementPtr:
7623 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7624 "GEP source element type must be sized");
7625 for (Value *Index : U->operands())
7626 Ops.push_back(Index);
7627 return nullptr;
7628
7629 case Instruction::IntToPtr:
7630 return getUnknown(V);
7631
7632 case Instruction::PHI:
7633 // Keep constructing SCEVs' for phis recursively for now.
7634 return nullptr;
7635
7636 case Instruction::Select: {
7637 // Check if U is a select that can be simplified to a SCEVUnknown.
7638 auto CanSimplifyToUnknown = [this, U]() {
7639 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7640 return false;
7641
7642 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7643 if (!ICI)
7644 return false;
7645 Value *LHS = ICI->getOperand(0);
7646 Value *RHS = ICI->getOperand(1);
7647 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7648 ICI->getPredicate() == CmpInst::ICMP_NE) {
7649 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7650 return true;
7651 } else if (getTypeSizeInBits(LHS->getType()) >
7652 getTypeSizeInBits(U->getType()))
7653 return true;
7654 return false;
7655 };
7656 if (CanSimplifyToUnknown())
7657 return getUnknown(U);
7658
7659 for (Value *Inc : U->operands())
7660 Ops.push_back(Inc);
7661 return nullptr;
7662 break;
7663 }
7664 case Instruction::Call:
7665 case Instruction::Invoke:
7666 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7667 Ops.push_back(RV);
7668 return nullptr;
7669 }
7670
7671 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7672 switch (II->getIntrinsicID()) {
7673 case Intrinsic::abs:
7674 Ops.push_back(II->getArgOperand(0));
7675 return nullptr;
7676 case Intrinsic::umax:
7677 case Intrinsic::umin:
7678 case Intrinsic::smax:
7679 case Intrinsic::smin:
7680 case Intrinsic::usub_sat:
7681 case Intrinsic::uadd_sat:
7682 Ops.push_back(II->getArgOperand(0));
7683 Ops.push_back(II->getArgOperand(1));
7684 return nullptr;
7685 case Intrinsic::start_loop_iterations:
7686 case Intrinsic::annotation:
7687 case Intrinsic::ptr_annotation:
7688 Ops.push_back(II->getArgOperand(0));
7689 return nullptr;
7690 default:
7691 break;
7692 }
7693 }
7694 break;
7695 }
7696
7697 return nullptr;
7698}
7699
7700const SCEV *ScalarEvolution::createSCEV(Value *V) {
7701 if (!isSCEVable(V->getType()))
7702 return getUnknown(V);
7703
7704 if (Instruction *I = dyn_cast<Instruction>(V)) {
7705 // Don't attempt to analyze instructions in blocks that aren't
7706 // reachable. Such instructions don't matter, and they aren't required
7707 // to obey basic rules for definitions dominating uses which this
7708 // analysis depends on.
7709 if (!DT.isReachableFromEntry(I->getParent()))
7710 return getUnknown(PoisonValue::get(V->getType()));
7711 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7712 return getConstant(CI);
7713 else if (isa<GlobalAlias>(V))
7714 return getUnknown(V);
7715 else if (!isa<ConstantExpr>(V))
7716 return getUnknown(V);
7717
7718 const SCEV *LHS;
7719 const SCEV *RHS;
7720
7721 Operator *U = cast<Operator>(V);
7722 if (auto BO =
7723 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7724 switch (BO->Opcode) {
7725 case Instruction::Add: {
7726 // The simple thing to do would be to just call getSCEV on both operands
7727 // and call getAddExpr with the result. However if we're looking at a
7728 // bunch of things all added together, this can be quite inefficient,
7729 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7730 // Instead, gather up all the operands and make a single getAddExpr call.
7731 // LLVM IR canonical form means we need only traverse the left operands.
7733 do {
7734 if (BO->Op) {
7735 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7736 AddOps.push_back(OpSCEV);
7737 break;
7738 }
7739
7740 // If a NUW or NSW flag can be applied to the SCEV for this
7741 // addition, then compute the SCEV for this addition by itself
7742 // with a separate call to getAddExpr. We need to do that
7743 // instead of pushing the operands of the addition onto AddOps,
7744 // since the flags are only known to apply to this particular
7745 // addition - they may not apply to other additions that can be
7746 // formed with operands from AddOps.
7747 const SCEV *RHS = getSCEV(BO->RHS);
7748 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7749 if (Flags != SCEV::FlagAnyWrap) {
7750 const SCEV *LHS = getSCEV(BO->LHS);
7751 if (BO->Opcode == Instruction::Sub)
7752 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7753 else
7754 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7755 break;
7756 }
7757 }
7758
7759 if (BO->Opcode == Instruction::Sub)
7760 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7761 else
7762 AddOps.push_back(getSCEV(BO->RHS));
7763
7764 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7765 dyn_cast<Instruction>(V));
7766 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7767 NewBO->Opcode != Instruction::Sub)) {
7768 AddOps.push_back(getSCEV(BO->LHS));
7769 break;
7770 }
7771 BO = NewBO;
7772 } while (true);
7773
7774 return getAddExpr(AddOps);
7775 }
7776
7777 case Instruction::Mul: {
7779 do {
7780 if (BO->Op) {
7781 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7782 MulOps.push_back(OpSCEV);
7783 break;
7784 }
7785
7786 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7787 if (Flags != SCEV::FlagAnyWrap) {
7788 LHS = getSCEV(BO->LHS);
7789 RHS = getSCEV(BO->RHS);
7790 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7791 break;
7792 }
7793 }
7794
7795 MulOps.push_back(getSCEV(BO->RHS));
7796 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7797 dyn_cast<Instruction>(V));
7798 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7799 MulOps.push_back(getSCEV(BO->LHS));
7800 break;
7801 }
7802 BO = NewBO;
7803 } while (true);
7804
7805 return getMulExpr(MulOps);
7806 }
7807 case Instruction::UDiv:
7808 LHS = getSCEV(BO->LHS);
7809 RHS = getSCEV(BO->RHS);
7810 return getUDivExpr(LHS, RHS);
7811 case Instruction::URem:
7812 LHS = getSCEV(BO->LHS);
7813 RHS = getSCEV(BO->RHS);
7814 return getURemExpr(LHS, RHS);
7815 case Instruction::Sub: {
7817 if (BO->Op)
7818 Flags = getNoWrapFlagsFromUB(BO->Op);
7819 LHS = getSCEV(BO->LHS);
7820 RHS = getSCEV(BO->RHS);
7821 return getMinusSCEV(LHS, RHS, Flags);
7822 }
7823 case Instruction::And:
7824 // For an expression like x&255 that merely masks off the high bits,
7825 // use zext(trunc(x)) as the SCEV expression.
7826 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7827 if (CI->isZero())
7828 return getSCEV(BO->RHS);
7829 if (CI->isMinusOne())
7830 return getSCEV(BO->LHS);
7831 const APInt &A = CI->getValue();
7832
7833 // Instcombine's ShrinkDemandedConstant may strip bits out of
7834 // constants, obscuring what would otherwise be a low-bits mask.
7835 // Use computeKnownBits to compute what ShrinkDemandedConstant
7836 // knew about to reconstruct a low-bits mask value.
7837 unsigned LZ = A.countl_zero();
7838 unsigned TZ = A.countr_zero();
7839 unsigned BitWidth = A.getBitWidth();
7840 KnownBits Known(BitWidth);
7841 computeKnownBits(BO->LHS, Known, getDataLayout(),
7842 0, &AC, nullptr, &DT);
7843
7844 APInt EffectiveMask =
7845 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7846 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7847 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7848 const SCEV *LHS = getSCEV(BO->LHS);
7849 const SCEV *ShiftedLHS = nullptr;
7850 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7851 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7852 // For an expression like (x * 8) & 8, simplify the multiply.
7853 unsigned MulZeros = OpC->getAPInt().countr_zero();
7854 unsigned GCD = std::min(MulZeros, TZ);
7855 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7857 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7858 append_range(MulOps, LHSMul->operands().drop_front());
7859 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7860 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7861 }
7862 }
7863 if (!ShiftedLHS)
7864 ShiftedLHS = getUDivExpr(LHS, MulCount);
7865 return getMulExpr(
7867 getTruncateExpr(ShiftedLHS,
7868 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7869 BO->LHS->getType()),
7870 MulCount);
7871 }
7872 }
7873 // Binary `and` is a bit-wise `umin`.
7874 if (BO->LHS->getType()->isIntegerTy(1)) {
7875 LHS = getSCEV(BO->LHS);
7876 RHS = getSCEV(BO->RHS);
7877 return getUMinExpr(LHS, RHS);
7878 }
7879 break;
7880
7881 case Instruction::Or:
7882 // Binary `or` is a bit-wise `umax`.
7883 if (BO->LHS->getType()->isIntegerTy(1)) {
7884 LHS = getSCEV(BO->LHS);
7885 RHS = getSCEV(BO->RHS);
7886 return getUMaxExpr(LHS, RHS);
7887 }
7888 break;
7889
7890 case Instruction::Xor:
7891 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7892 // If the RHS of xor is -1, then this is a not operation.
7893 if (CI->isMinusOne())
7894 return getNotSCEV(getSCEV(BO->LHS));
7895
7896 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7897 // This is a variant of the check for xor with -1, and it handles
7898 // the case where instcombine has trimmed non-demanded bits out
7899 // of an xor with -1.
7900 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7901 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7902 if (LBO->getOpcode() == Instruction::And &&
7903 LCI->getValue() == CI->getValue())
7904 if (const SCEVZeroExtendExpr *Z =
7905 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7906 Type *UTy = BO->LHS->getType();
7907 const SCEV *Z0 = Z->getOperand();
7908 Type *Z0Ty = Z0->getType();
7909 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7910
7911 // If C is a low-bits mask, the zero extend is serving to
7912 // mask off the high bits. Complement the operand and
7913 // re-apply the zext.
7914 if (CI->getValue().isMask(Z0TySize))
7915 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7916
7917 // If C is a single bit, it may be in the sign-bit position
7918 // before the zero-extend. In this case, represent the xor
7919 // using an add, which is equivalent, and re-apply the zext.
7920 APInt Trunc = CI->getValue().trunc(Z0TySize);
7921 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7922 Trunc.isSignMask())
7923 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7924 UTy);
7925 }
7926 }
7927 break;
7928
7929 case Instruction::Shl:
7930 // Turn shift left of a constant amount into a multiply.
7931 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7932 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7933
7934 // If the shift count is not less than the bitwidth, the result of
7935 // the shift is undefined. Don't try to analyze it, because the
7936 // resolution chosen here may differ from the resolution chosen in
7937 // other parts of the compiler.
7938 if (SA->getValue().uge(BitWidth))
7939 break;
7940
7941 // We can safely preserve the nuw flag in all cases. It's also safe to
7942 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7943 // requires special handling. It can be preserved as long as we're not
7944 // left shifting by bitwidth - 1.
7945 auto Flags = SCEV::FlagAnyWrap;
7946 if (BO->Op) {
7947 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7948 if ((MulFlags & SCEV::FlagNSW) &&
7949 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7951 if (MulFlags & SCEV::FlagNUW)
7953 }
7954
7955 ConstantInt *X = ConstantInt::get(
7956 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7957 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7958 }
7959 break;
7960
7961 case Instruction::AShr:
7962 // AShr X, C, where C is a constant.
7963 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7964 if (!CI)
7965 break;
7966
7967 Type *OuterTy = BO->LHS->getType();
7969 // If the shift count is not less than the bitwidth, the result of
7970 // the shift is undefined. Don't try to analyze it, because the
7971 // resolution chosen here may differ from the resolution chosen in
7972 // other parts of the compiler.
7973 if (CI->getValue().uge(BitWidth))
7974 break;
7975
7976 if (CI->isZero())
7977 return getSCEV(BO->LHS); // shift by zero --> noop
7978
7979 uint64_t AShrAmt = CI->getZExtValue();
7980 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7981
7982 Operator *L = dyn_cast<Operator>(BO->LHS);
7983 const SCEV *AddTruncateExpr = nullptr;
7984 ConstantInt *ShlAmtCI = nullptr;
7985 const SCEV *AddConstant = nullptr;
7986
7987 if (L && L->getOpcode() == Instruction::Add) {
7988 // X = Shl A, n
7989 // Y = Add X, c
7990 // Z = AShr Y, m
7991 // n, c and m are constants.
7992
7993 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7994 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7995 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7996 if (AddOperandCI) {
7997 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7998 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7999 // since we truncate to TruncTy, the AddConstant should be of the
8000 // same type, so create a new Constant with type same as TruncTy.
8001 // Also, the Add constant should be shifted right by AShr amount.
8002 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8003 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8004 // we model the expression as sext(add(trunc(A), c << n)), since the
8005 // sext(trunc) part is already handled below, we create a
8006 // AddExpr(TruncExp) which will be used later.
8007 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8008 }
8009 }
8010 } else if (L && L->getOpcode() == Instruction::Shl) {
8011 // X = Shl A, n
8012 // Y = AShr X, m
8013 // Both n and m are constant.
8014
8015 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8016 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8017 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8018 }
8019
8020 if (AddTruncateExpr && ShlAmtCI) {
8021 // We can merge the two given cases into a single SCEV statement,
8022 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8023 // a simpler case. The following code handles the two cases:
8024 //
8025 // 1) For a two-shift sext-inreg, i.e. n = m,
8026 // use sext(trunc(x)) as the SCEV expression.
8027 //
8028 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8029 // expression. We already checked that ShlAmt < BitWidth, so
8030 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8031 // ShlAmt - AShrAmt < Amt.
8032 const APInt &ShlAmt = ShlAmtCI->getValue();
8033 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8035 ShlAmtCI->getZExtValue() - AShrAmt);
8036 const SCEV *CompositeExpr =
8037 getMulExpr(AddTruncateExpr, getConstant(Mul));
8038 if (L->getOpcode() != Instruction::Shl)
8039 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8040
8041 return getSignExtendExpr(CompositeExpr, OuterTy);
8042 }
8043 }
8044 break;
8045 }
8046 }
8047
8048 switch (U->getOpcode()) {
8049 case Instruction::Trunc:
8050 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8051
8052 case Instruction::ZExt:
8053 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8054
8055 case Instruction::SExt:
8056 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8057 dyn_cast<Instruction>(V))) {
8058 // The NSW flag of a subtract does not always survive the conversion to
8059 // A + (-1)*B. By pushing sign extension onto its operands we are much
8060 // more likely to preserve NSW and allow later AddRec optimisations.
8061 //
8062 // NOTE: This is effectively duplicating this logic from getSignExtend:
8063 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8064 // but by that point the NSW information has potentially been lost.
8065 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8066 Type *Ty = U->getType();
8067 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8068 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8069 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8070 }
8071 }
8072 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8073
8074 case Instruction::BitCast:
8075 // BitCasts are no-op casts so we just eliminate the cast.
8076 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8077 return getSCEV(U->getOperand(0));
8078 break;
8079
8080 case Instruction::PtrToInt: {
8081 // Pointer to integer cast is straight-forward, so do model it.
8082 const SCEV *Op = getSCEV(U->getOperand(0));
8083 Type *DstIntTy = U->getType();
8084 // But only if effective SCEV (integer) type is wide enough to represent
8085 // all possible pointer values.
8086 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8087 if (isa<SCEVCouldNotCompute>(IntOp))
8088 return getUnknown(V);
8089 return IntOp;
8090 }
8091 case Instruction::IntToPtr:
8092 // Just don't deal with inttoptr casts.
8093 return getUnknown(V);
8094
8095 case Instruction::SDiv:
8096 // If both operands are non-negative, this is just an udiv.
8097 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8098 isKnownNonNegative(getSCEV(U->getOperand(1))))
8099 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8100 break;
8101
8102 case Instruction::SRem:
8103 // If both operands are non-negative, this is just an urem.
8104 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8105 isKnownNonNegative(getSCEV(U->getOperand(1))))
8106 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8107 break;
8108
8109 case Instruction::GetElementPtr:
8110 return createNodeForGEP(cast<GEPOperator>(U));
8111
8112 case Instruction::PHI:
8113 return createNodeForPHI(cast<PHINode>(U));
8114
8115 case Instruction::Select:
8116 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8117 U->getOperand(2));
8118
8119 case Instruction::Call:
8120 case Instruction::Invoke:
8121 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8122 return getSCEV(RV);
8123
8124 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8125 switch (II->getIntrinsicID()) {
8126 case Intrinsic::abs:
8127 return getAbsExpr(
8128 getSCEV(II->getArgOperand(0)),
8129 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8130 case Intrinsic::umax:
8131 LHS = getSCEV(II->getArgOperand(0));
8132 RHS = getSCEV(II->getArgOperand(1));
8133 return getUMaxExpr(LHS, RHS);
8134 case Intrinsic::umin:
8135 LHS = getSCEV(II->getArgOperand(0));
8136 RHS = getSCEV(II->getArgOperand(1));
8137 return getUMinExpr(LHS, RHS);
8138 case Intrinsic::smax:
8139 LHS = getSCEV(II->getArgOperand(0));
8140 RHS = getSCEV(II->getArgOperand(1));
8141 return getSMaxExpr(LHS, RHS);
8142 case Intrinsic::smin:
8143 LHS = getSCEV(II->getArgOperand(0));
8144 RHS = getSCEV(II->getArgOperand(1));
8145 return getSMinExpr(LHS, RHS);
8146 case Intrinsic::usub_sat: {
8147 const SCEV *X = getSCEV(II->getArgOperand(0));
8148 const SCEV *Y = getSCEV(II->getArgOperand(1));
8149 const SCEV *ClampedY = getUMinExpr(X, Y);
8150 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8151 }
8152 case Intrinsic::uadd_sat: {
8153 const SCEV *X = getSCEV(II->getArgOperand(0));
8154 const SCEV *Y = getSCEV(II->getArgOperand(1));
8155 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8156 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8157 }
8158 case Intrinsic::start_loop_iterations:
8159 case Intrinsic::annotation:
8160 case Intrinsic::ptr_annotation:
8161 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8162 // just equivalent to the first operand for SCEV purposes.
8163 return getSCEV(II->getArgOperand(0));
8164 case Intrinsic::vscale:
8165 return getVScale(II->getType());
8166 default:
8167 break;
8168 }
8169 }
8170 break;
8171 }
8172
8173 return getUnknown(V);
8174}
8175
8176//===----------------------------------------------------------------------===//
8177// Iteration Count Computation Code
8178//
8179
// NOTE(review): listing line 8180, which should carry this definition's
// signature, is missing from this extract. Judging by the delegating call on
// line 8188 this is presumably the one-argument getTripCountFromExitCount
// overload -- confirm against the real source.
//
// Turns an exit count into a trip count. The result is evaluated in an integer
// type one bit wider than the exit count's own type; the actual work is
// delegated to the three-argument overload with a null loop.
// An uncomputable exit count yields getCouldNotCompute().
8181 if (isa<SCEVCouldNotCompute>(ExitCount))
8182 return getCouldNotCompute();
8183
8184 auto *ExitCountType = ExitCount->getType();
8185 assert(ExitCountType->isIntegerTy());
// Evaluation type: the exit count's width plus one bit.
8186 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8187 1 + ExitCountType->getScalarSizeInBits());
8188 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8189}
8190
// NOTE(review): listing line 8191 (the first line of this signature, with the
// return type, function name and the ExitCount parameter) is missing from this
// extract.
//
// Computes ExitCount + 1 in type EvalTy.
//  - An uncomputable ExitCount yields getCouldNotCompute().
//  - When EvalTy is wider than ExitCount's type and the +1 is shown not to
//    overflow in the narrow type, the add is done first and then zero-extended.
//  - Otherwise ExitCount is truncated or zero-extended to EvalTy and one is
//    added there; as the comment below says, that add may wrap.
// L may be null; it is only used to query loop-entry guards.
8192 Type *EvalTy,
8193 const Loop *L) {
8194 if (isa<SCEVCouldNotCompute>(ExitCount))
8195 return getCouldNotCompute();
8196
8197 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8198 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8199
// The +1 cannot overflow if the unsigned range of ExitCount excludes the
// all-ones value, or if a loop was supplied and its entry is guarded by
// ExitCount != -1.
8200 auto CanAddOneWithoutOverflow = [&]() {
8201 ConstantRange ExitCountRange =
8202 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8203 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8204 return true;
8205
8206 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8207 getMinusOne(ExitCount->getType()));
8208 };
8209
8210 // If we need to zero extend the backedge count, check if we can add one to
8211 // it prior to zero extending without overflow. Provided this is safe, it
8212 // allows better simplification of the +1.
8213 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8214 return getZeroExtendExpr(
8215 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8216
8217 // Get the total trip count from the count by adding 1. This may wrap.
8218 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8219}
8220
/// Converts a constant exit count into a trip count (exit count + 1) that fits
/// in an unsigned.
///
/// \param ExitCount constant exit count; may be null.
/// \returns 0 when \p ExitCount is null or its value needs more than 32
/// active bits; otherwise the low 32 bits of the value plus one, which is 0
/// when that addition wraps.
8221static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8222 if (!ExitCount)
8223 return 0;
8224
8225 ConstantInt *ExitConst = ExitCount->getValue();
8226
8227 // Guard against huge trip counts.
8228 if (ExitConst->getValue().getActiveBits() > 32)
8229 return 0;
8230
8231 // In case of integer overflow, this returns 0, which is correct.
8232 return ((unsigned)ExitConst->getZExtValue()) + 1;
8233}
8234
// NOTE(review): listing line 8235, which should carry this definition's
// signature, is missing from this extract; the body uses a loop named L.
//
// Looks up the exact backedge-taken count of L and, if it is a SCEVConstant,
// converts it with getConstantTripCount. A non-constant count makes the
// dyn_cast yield null, which getConstantTripCount maps to 0.
8236 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8237 return getConstantTripCount(ExitCount);
8238}
8239
// Per-exit variant: converts the exit count of L through ExitingBlock with
// getConstantTripCount. A non-constant exit count makes the dyn_cast yield
// null, which getConstantTripCount maps to 0.
// ExitingBlock must be non-null and must be an exiting block of L (asserted).
8240unsigned
// NOTE(review): listing line 8241 (function name and first parameter) is
// missing from this extract.
8242 const BasicBlock *ExitingBlock) {
8243 assert(ExitingBlock && "Must pass a non-null exiting block!");
8244 assert(L->isLoopExiting(ExitingBlock) &&
8245 "Exiting block must actually branch out of the loop!");
8246 const SCEVConstant *ExitCount =
8247 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8248 return getConstantTripCount(ExitCount);
8249}
8250
// NOTE(review): listing line 8251 (return type and function name) is missing
// from this extract.
//
// Converts a constant maximum backedge-taken count of L with
// getConstantTripCount. When Predicates is non-null the predicated query is
// used and is handed *Predicates.
8252 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8253
8254 const auto *MaxExitCount =
8255 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
// NOTE(review): listing line 8256 (the ':' arm of the conditional above, used
// when Predicates is null) is missing from this extract.
8257 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8258}
8259
// NOTE(review): listing line 8260, which should carry this definition's
// signature, is missing from this extract; the body uses a loop named L.
//
// Returns the greatest common divisor of the per-exit trip multiples of all
// exiting blocks of L, or 1 when L has no exiting blocks.
8261 SmallVector<BasicBlock *, 8> ExitingBlocks;
8262 L->getExitingBlocks(ExitingBlocks);
8263
8264 std::optional<unsigned> Res;
8265 for (auto *ExitingBB : ExitingBlocks) {
8266 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
// The first exit seeds the accumulator; the gcd below then folds in Multiple
// (a no-op for the seeding iteration, since gcd(x, x) == x).
8267 if (!Res)
8268 Res = Multiple;
8269 Res = (unsigned)std::gcd(*Res, Multiple);
8270 }
8271 return Res.value_or(1);
8272}
8273
// NOTE(review): listing line 8274 (return type, function name and the loop
// parameter L) is missing from this extract.
//
// Returns a constant that divides the trip count implied by ExitCount:
//  - 1 if ExitCount is the CouldNotCompute node;
//  - otherwise the non-zero constant multiple of the trip-count expression
//    (built from ExitCount after applying L's loop guards), reduced to a power
//    of two below 2^32 when the multiple itself needs more than 32 bits.
8275 const SCEV *ExitCount) {
8276 if (ExitCount == getCouldNotCompute())
8277 return 1;
8278
8279 // Get the trip count
8280 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8281
8282 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8283 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8284 // the greatest power of 2 divisor less than 2^32.
8285 return Multiple.getActiveBits() > 32
8286 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8287 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8288}
8289
8290/// Returns the largest constant divisor of the trip count of this loop as a
8291/// normal unsigned value, if possible. This means that the actual trip count is
8292/// always a multiple of the returned value (don't forget the trip count could
8293/// very well be zero as well!).
8294///
8295/// Returns 1 if the trip count is unknown or not guaranteed to be the
8296/// multiple of a constant (which is also the case if the trip count is simply
8297/// constant, use getSmallConstantTripCount for that case), Will also return 1
8298/// if the trip count is very large (>= 2^32).
8299///
8300/// As explained in the comments for getSmallConstantTripCount, this assumes
8301/// that control exits the loop via ExitingBlock.
8302unsigned
// NOTE(review): listing line 8303 (function name and first parameter) is
// missing from this extract.
8304 const BasicBlock *ExitingBlock) {
8305 assert(ExitingBlock && "Must pass a non-null exiting block!");
8306 assert(L->isLoopExiting(ExitingBlock) &&
8307 "Exiting block must actually branch out of the loop!");
// Fetch this exit's count and defer to the (L, ExitCount) overload.
8308 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8309 return getSmallConstantTripMultiple(L, ExitCount);
8310}
8311
// NOTE(review): listing line 8312 (return type, function name and the loop
// parameter L) is missing from this extract.
//
// Returns the exit count of L through ExitingBlock, taken from the
// unpredicated backedge-taken info of L. Kind selects which of the three
// stored counts (exact, symbolic maximum, constant maximum) is returned.
8313 const BasicBlock *ExitingBlock,
8314 ExitCountKind Kind) {
8315 switch (Kind) {
8316 case Exact:
8317 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8318 case SymbolicMaximum:
8319 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8320 case ConstantMaximum:
8321 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8322 };
// Every enumerator returns above, so falling out of the switch means Kind
// held an invalid value.
8323 llvm_unreachable("Invalid ExitCountKind!");
8324}
8325
// NOTE(review): listing line 8326 (return type and function name) is missing
// from this extract.
//
// Predicated counterpart of the per-exit exit-count query: the count comes
// from the predicated backedge-taken info of L, and Predicates is forwarded
// to the selected accessor. Kind selects exact, symbolic maximum or constant
// maximum.
8327 const Loop *L, const BasicBlock *ExitingBlock,
// NOTE(review): listing line 8328 (the remaining parameters, which must
// include Predicates and Kind as used below) is missing from this extract.
8329 switch (Kind) {
8330 case Exact:
8331 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8332 Predicates);
8333 case SymbolicMaximum:
8334 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8335 Predicates);
8336 case ConstantMaximum:
8337 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8338 Predicates);
8339 };
8340 llvm_unreachable("Invalid ExitCountKind!");
8341}
8342
// NOTE(review): listing lines 8343-8344, which should carry this definition's
// signature, are missing from this extract; the body uses a loop L and a
// predicate container Preds.
//
// Returns the exact whole-loop backedge-taken count from the predicated
// backedge-taken info of L; &Preds is passed to getExact, which appends the
// predicates of each exit it uses.
8345 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8346}
8347
// NOTE(review): listing line 8348 (return type, function name and the loop
// parameter L) is missing from this extract.
//
// Returns the whole-loop backedge-taken count of L from its unpredicated
// backedge-taken info. Kind selects exact, constant maximum or symbolic
// maximum.
8349 ExitCountKind Kind) {
8350 switch (Kind) {
8351 case Exact:
8352 return getBackedgeTakenInfo(L).getExact(L, this);
8353 case ConstantMaximum:
8354 return getBackedgeTakenInfo(L).getConstantMax(this);
8355 case SymbolicMaximum:
8356 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8357 };
8358 llvm_unreachable("Invalid ExitCountKind!");
8359}
8360
// NOTE(review): listing lines 8361-8362, which should carry this definition's
// signature, are missing from this extract; the body uses a loop L and a
// predicate container Preds.
//
// Returns the symbolic maximum backedge-taken count from the predicated
// backedge-taken info of L, passing &Preds to getSymbolicMax.
8363 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8364}
8365
// NOTE(review): listing lines 8366-8367, which should carry this definition's
// signature, are missing from this extract; the body uses a loop L and a
// predicate container Preds.
//
// Returns the constant maximum backedge-taken count from the predicated
// backedge-taken info of L, passing &Preds to getConstantMax.
8368 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8369}
8370
// NOTE(review): listing line 8371, which should carry this definition's
// signature, is missing from this extract; the body uses a loop named L.
//
// Forwards to isConstantMaxOrZero on the unpredicated backedge-taken info
// of L.
8372 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8373}
8374
8375/// Push PHI nodes in the header of the given loop onto the given Worklist.
/// A PHI is pushed only if inserting it into Visited succeeds, i.e. it has not
/// been visited before.
8376static void PushLoopPHIs(const Loop *L,
// NOTE(review): listing lines 8377-8378 (the remaining parameters, which must
// include Worklist and Visited as used below) are missing from this extract.
8379 BasicBlock *Header = L->getHeader();
8380
8381 // Push all Loop-header PHIs onto the Worklist stack.
8382 for (PHINode &PN : Header->phis())
8383 if (Visited.insert(&PN).second)
8384 Worklist.push_back(&PN);
8385}
8386
/// Returns the backedge-taken info for \p L that may rely on predicates.
///
/// If the unpredicated info already reports hasFullInfo(), it is returned
/// as-is. Otherwise the result is looked up in, or computed (with predicates
/// allowed) and stored into, PredicatedBackedgeTakenCounts.
8387ScalarEvolution::BackedgeTakenInfo &
8388ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8389 auto &BTI = getBackedgeTakenInfo(L);
8390 if (BTI.hasFullInfo())
8391 return BTI;
8392
// Insert a default-constructed placeholder; a failed insertion means an entry
// for L already exists.
8393 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8394
8395 if (!Pair.second)
8396 return Pair.first->second;
8397
8398 BackedgeTakenInfo Result =
8399 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8400
// Look the entry up again instead of reusing Pair.first before storing the
// computed result.
8401 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8402}
8403
/// Returns the cached (unpredicated) backedge-taken info for \p L, computing
/// and caching it in BackedgeTakenCounts on first request.
8404ScalarEvolution::BackedgeTakenInfo &
8405ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8406 // Initially insert an invalid entry for this loop. If the insertion
8407 // succeeds, proceed to actually compute a backedge-taken count and
8408 // update the value. The temporary CouldNotCompute value tells SCEV
8409 // code elsewhere that it shouldn't attempt to request a new
8410 // backedge-taken count, which could result in infinite recursion.
8411 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8412 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8413 if (!Pair.second)
8414 return Pair.first->second;
8415
8416 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8417 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8418 // must be cleared in this scope.
8419 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8420
8421 // Now that we know more about the trip count for this loop, forget any
8422 // existing SCEV values for PHI nodes in this loop since they are only
8423 // conservative estimates made without the benefit of trip count
8424 // information. This invalidation is not necessary for correctness, and is
8425 // only done to produce more precise results.
8426 if (Result.hasAnyInfo()) {
8427 // Invalidate any expression using an addrec in this loop.
// NOTE(review): listing line 8428 is missing from this extract; it presumably
// declares the ToForget container used below -- confirm.
8429 auto LoopUsersIt = LoopUsers.find(L);
8430 if (LoopUsersIt != LoopUsers.end())
8431 append_range(ToForget, LoopUsersIt->second);
8432 forgetMemoizedResults(ToForget);
8433
8434 // Invalidate constant-evolved loop header phis.
8435 for (PHINode &PN : L->getHeader()->phis())
8436 ConstantEvolutionLoopExitValue.erase(&PN);
8437 }
8438
8439 // Re-lookup the insert position, since the call to
8440 // computeBackedgeTakenCount above could result in a
8441 // recursive call to getBackedgeTakenInfo (on a different
8442 // loop), which would invalidate the iterator computed
8443 // earlier.
8444 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8445}
8446
// NOTE(review): listing line 8447, which should carry this definition's
// signature, is missing from this extract.
//
// Clears every cache listed below unconditionally; nothing is recomputed here.
8448 // This method is intended to forget all info about loops. It should
8449 // invalidate caches as if the following happened:
8450 // - The trip counts of all loops have changed arbitrarily
8451 // - Every llvm::Value has been updated in place to produce a different
8452 // result.
8453 BackedgeTakenCounts.clear();
8454 PredicatedBackedgeTakenCounts.clear();
8455 BECountUsers.clear();
8456 LoopPropertiesCache.clear();
8457 ConstantEvolutionLoopExitValue.clear();
8458 ValueExprMap.clear();
8459 ValuesAtScopes.clear();
8460 ValuesAtScopesUsers.clear();
8461 LoopDispositions.clear();
8462 BlockDispositions.clear();
8463 UnsignedRanges.clear();
8464 SignedRanges.clear();
8465 ExprValueMap.clear();
8466 HasRecMap.clear();
8467 ConstantMultipleCache.clear();
8468 PredicatedSCEVRewrites.clear();
8469 FoldCache.clear();
8470 FoldCacheUser.clear();
8471}
// Drains Worklist. For each popped instruction that is SCEVable or a
// WithOverflowInst: if it has an entry in ValueExprMap, that entry is erased,
// its SCEV is appended to ToForget, and (for PHIs) its
// ConstantEvolutionLoopExitValue entry is erased. PushDefUseChildren is then
// called on the instruction to extend the walk.
8472void ScalarEvolution::visitAndClearUsers(
// NOTE(review): listing lines 8473-8475 (the parameter list, which must
// include Worklist, Visited and ToForget as used below) are missing from this
// extract.
8476 while (!Worklist.empty()) {
8477 Instruction *I = Worklist.pop_back_val();
// Such instructions are skipped entirely: nothing is erased and the walk does
// not continue through them.
8478 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8479 continue;
8480
// NOTE(review): listing line 8481 is missing from this extract; it presumably
// declares the iterator It initialised by the expression on the next line.
8482 ValueExprMap.find_as(static_cast<Value *>(I));
8483 if (It != ValueExprMap.end()) {
8484 eraseValueFromMap(It->first);
8485 ToForget.push_back(It->second);
8486 if (PHINode *PN = dyn_cast<PHINode>(I))
8487 ConstantEvolutionLoopExitValue.erase(PN);
8488 }
8489
8490 PushDefUseChildren(I, Worklist, Visited);
8491 }
8492}
8493
// NOTE(review): listing line 8494, which should carry this definition's
// signature, is missing from this extract; the body uses a loop named L.
//
// Drops cached SCEV information for L and, via the loop worklist, for every
// loop nested inside it: backedge-taken counts (predicated and unpredicated),
// predicated rewrites keyed on the loop, expressions recorded in LoopUsers,
// values reached from the loop-header PHIs, and the LoopPropertiesCache entry.
8495 SmallVector<const Loop *, 16> LoopWorklist(1, L);
// NOTE(review): listing lines 8496-8498 are missing from this extract; they
// presumably declare Worklist, Visited and ToForget, all used below.
8499
8500 // Iterate over all the loops and sub-loops to drop SCEV information.
8501 while (!LoopWorklist.empty()) {
8502 auto *CurrL = LoopWorklist.pop_back_val();
8503
8504 // Drop any stored trip count value.
8505 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8506 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8507
8508 // Drop information about predicated SCEV rewrites for this loop.
8509 for (auto I = PredicatedSCEVRewrites.begin();
8510 I != PredicatedSCEVRewrites.end();) {
8511 std::pair<const SCEV *, const Loop *> Entry = I->first;
// Post-increment so the loop iterator has already moved on when the element
// is erased.
8512 if (Entry.second == CurrL)
8513 PredicatedSCEVRewrites.erase(I++);
8514 else
8515 ++I;
8516 }
8517
8518 auto LoopUsersItr = LoopUsers.find(CurrL);
8519 if (LoopUsersItr != LoopUsers.end()) {
8520 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8521 LoopUsersItr->second.end());
8522 }
8523
8524 // Drop information about expressions based on loop-header PHIs.
8525 PushLoopPHIs(CurrL, Worklist, Visited);
8526 visitAndClearUsers(Worklist, Visited, ToForget);
8527
8528 LoopPropertiesCache.erase(CurrL);
8529 // Forget all contained loops too, to avoid dangling entries in the
8530 // ValuesAtScopes map.
8531 LoopWorklist.append(CurrL->begin(), CurrL->end());
8532 }
// A single invalidation pass over everything collected across all loops.
8533 forgetMemoizedResults(ToForget);
8534}
8535
// NOTE(review): listing line 8536, which should carry this definition's
// signature, is missing from this extract; the body uses a loop named L.
//
// Forgets the outermost loop that contains L.
8537 forgetLoop(L->getOutermostLoop());
8538}
8539
// NOTE(review): listing line 8540, which should carry this definition's
// signature, is missing from this extract; the body uses a value named V.
//
// Drops cached SCEV information reachable from V. Only instructions are
// handled; any other kind of value returns immediately.
8541 Instruction *I = dyn_cast<Instruction>(V);
8542 if (!I) return;
8543
8544 // Seed the walk with I itself; visitAndClearUsers continues from there.
// NOTE(review): listing lines 8545-8547 are missing from this extract; they
// presumably declare Worklist, Visited and ToForget, all used below.
8548 Worklist.push_back(I);
8549 Visited.insert(I);
8550 visitAndClearUsers(Worklist, Visited, ToForget);
8551
8552 forgetMemoizedResults(ToForget);
8553}
8554
// NOTE(review): listing line 8555, which should carry this definition's
// signature, is missing from this extract; the body uses a loop L and a
// value V.
//
// Invalidation for V in the situation the comment below describes. Values of
// non-SCEVable type are ignored. Otherwise the SCEVUnknowns and AddRecs inside
// L that V's existing SCEV refers to are forgotten, and forgetValue(V) is then
// called.
8556 if (!isSCEVable(V->getType()))
8557 return;
8558
8559 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8560 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8561 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8562 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8563 if (const SCEV *S = getExistingSCEV(V)) {
// Visitor handed to visitAll; it records the sub-expressions to invalidate.
8564 struct InvalidationRootCollector {
8565 Loop *L;
// NOTE(review): listing line 8566 is missing from this extract; it presumably
// declares the Roots container filled in follow() -- confirm.
8567
8568 InvalidationRootCollector(Loop *L) : L(L) {}
8569
// Records S if it is a SCEVUnknown wrapping an instruction inside L, or an
// AddRec whose loop is contained in L. Always returns true.
8570 bool follow(const SCEV *S) {
8571 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8572 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8573 if (L->contains(I))
8574 Roots.push_back(S);
8575 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8576 if (L->contains(AddRec->getLoop()))
8577 Roots.push_back(S);
8578 }
8579 return true;
8580 }
8581 bool isDone() const { return false; }
8582 };
8583
8584 InvalidationRootCollector C(L);
8585 visitAll(S, C);
8586 forgetMemoizedResults(C.Roots);
8587 }
8588
8589 // Also perform the normal invalidation.
8590 forgetValue(V);
8591}
8592
/// Clears the whole LoopDispositions cache.
8593void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8594
// NOTE(review): listing line 8595, which should carry this definition's
// signature, is missing from this extract; the body uses a value named V that
// may be null.
//
// Invalidates cached block and loop dispositions. With a null V both caches
// are cleared entirely. Otherwise only the entries for V's existing SCEV and,
// transitively, its users are erased; nothing happens when V is not SCEVable
// or has no existing SCEV.
8596 // Unless a specific value is passed to invalidation, completely clear both
8597 // caches.
8598 if (!V) {
8599 BlockDispositions.clear();
8600 LoopDispositions.clear();
8601 return;
8602 }
8603
8604 if (!isSCEVable(V->getType()))
8605 return;
8606
8607 const SCEV *S = getExistingSCEV(V);
8608 if (!S)
8609 return;
8610
8611 // Invalidate the block and loop dispositions cached for S. Dispositions of
8612 // S's users may change if S's disposition changes (i.e. a user may change to
8613 // loop-invariant, if S changes to loop invariant), so also invalidate
8614 // dispositions of S's users recursively.
8615 SmallVector<const SCEV *, 8> Worklist = {S};
// NOTE(review): listing line 8616 is missing from this extract; it presumably
// declares the Seen set used below -- confirm.
8617 while (!Worklist.empty()) {
8618 const SCEV *Curr = Worklist.pop_back_val();
8619 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8620 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
// If neither cache held an entry for Curr, its users are not visited.
8621 if (!LoopDispoRemoved && !BlockDispoRemoved)
8622 continue;
8623 auto Users = SCEVUsers.find(Curr);
8624 if (Users != SCEVUsers.end())
8625 for (const auto *User : Users->second)
8626 if (Seen.insert(User).second)
8627 Worklist.push_back(User);
8628 }
8629}
8630
8631/// Get the exact loop backedge taken count considering all loop exits. A
8632/// computable result can only be returned for loops with all exiting blocks
8633/// dominating the latch. howFarToZero assumes that the limit of each loop test
8634/// is never skipped. This is a valid assumption as long as the loop exits via
8635/// that test. For precise results, it is the caller's responsibility to specify
8636/// the relevant loop exiting block using getExact(ExitingBlock, SE).
///
/// Returns SE->getCouldNotCompute() when the info is not complete, when no
/// exit counts are recorded, or when \p L has no single latch. When Preds is
/// non-null, every exit's predicates are appended to it.
8637const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8638 const Loop *L, ScalarEvolution *SE,
// NOTE(review): listing line 8639 (the last parameter, which must be the
// Preds pointer used below, and the end of the signature) is missing from
// this extract.
8640 // If any exits were not computable, the loop is not computable.
8641 if (!isComplete() || ExitNotTaken.empty())
8642 return SE->getCouldNotCompute();
8643
8644 const BasicBlock *Latch = L->getLoopLatch();
8645 // All exiting blocks we have collected must dominate the only backedge.
8646 if (!Latch)
8647 return SE->getCouldNotCompute();
8648
8649 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8650 // count is simply a minimum out of all these calculated exit counts.
// NOTE(review): listing line 8651 is missing from this extract; it presumably
// declares the Ops container filled below -- confirm.
8652 for (const auto &ENT : ExitNotTaken) {
8653 const SCEV *BECount = ENT.ExactNotTaken;
8654 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8655 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8656 "We should only have known counts for exiting blocks that dominate "
8657 "latch!");
8658
8659 Ops.push_back(BECount);
8660
8661 if (Preds)
8662 append_range(*Preds, ENT.Predicates);
8663
8664 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8665 "Predicate should be always true!");
8666 }
8667
8668 // If an earlier exit exits on the first iteration (exit count zero), then
8669 // a later poison exit count should not propagate into the result. These are
8670 // exactly the semantics provided by umin_seq.
8671 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8672}
8673
/// Finds the recorded exit info for \p ExitingBlock.
///
/// An entry whose predicate is always true is returned directly. An entry that
/// depends on predicates is returned only when \p Predicates is non-null, in
/// which case its predicates are appended to \p Predicates first.
///
/// \returns the matching entry, or nullptr if none matches or the only match
/// needs predicates and \p Predicates is null.
8674const ScalarEvolution::ExitNotTakenInfo *
8675ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8676 const BasicBlock *ExitingBlock,
8677 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8678 for (const auto &ENT : ExitNotTaken)
8679 if (ENT.ExitingBlock == ExitingBlock) {
8680 if (ENT.hasAlwaysTruePredicate())
8681 return &ENT;
8682 else if (Predicates) {
8683 append_range(*Predicates, ENT.Predicates);
8684 return &ENT;
8685 }
// A predicated match with no Predicates sink is not returned; the scan
// simply continues.
8686 }
8687
8688 return nullptr;
8689}
8690
8691/// getConstantMax - Get the constant max backedge taken count for the loop.
///
/// Returns SE->getCouldNotCompute() if no constant max is stored, or if some
/// exit depends on predicates and \p Predicates is null. When \p Predicates is
/// non-null, the predicates of every such exit are appended to it.
8692const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8693 ScalarEvolution *SE,
8694 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8695 if (!getConstantMax())
8696 return SE->getCouldNotCompute();
8697
8698 for (const auto &ENT : ExitNotTaken)
8699 if (!ENT.hasAlwaysTruePredicate()) {
8700 if (!Predicates)
8701 return SE->getCouldNotCompute();
8702 append_range(*Predicates, ENT.Predicates);
8703 }
8704
8705 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8706 isa<SCEVConstant>(getConstantMax())) &&
8707 "No point in having a non-constant max backedge taken count!");
8708 return getConstantMax();
8709}
8710
/// Returns the symbolic maximum backedge-taken count, computing it on first
/// use and storing it in SymbolicMax. The exit list is only walked, and
/// predicates only appended to Predicates, while SymbolicMax is still unset.
8711const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8712 const Loop *L, ScalarEvolution *SE,
// NOTE(review): listing line 8713 (the last parameter, which must be the
// Predicates pointer used below, and the end of the signature) is missing
// from this extract.
8714 if (!SymbolicMax) {
8715 // Form an expression for the maximum exit count possible for this loop. We
8716 // merge the max and exact information to approximate a version of
8717 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8718 // constants.
// NOTE(review): listing line 8719 is missing from this extract; it presumably
// declares the ExitCounts container filled below -- confirm.
8720
8721 for (const auto &ENT : ExitNotTaken) {
8722 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
// Exits without a computable symbolic max are left out of the minimum.
8723 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8724 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8725 "We should only have known counts for exiting blocks that "
8726 "dominate latch!");
8727 ExitCounts.push_back(ExitCount);
8728 if (Predicates)
8729 append_range(*Predicates, ENT.Predicates);
8730
8731 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8732 "Predicate should be always true!");
8733 }
8734 }
8735 if (ExitCounts.empty())
8736 SymbolicMax = SE->getCouldNotCompute();
8737 else
8738 SymbolicMax =
8739 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8740 }
8741 return SymbolicMax;
8742}
8743
/// Returns true when the MaxOrZero flag is set and no recorded exit depends on
/// a predicate. \p SE is not used by this body.
8744bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8745 ScalarEvolution *SE) const {
8746 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8747 return !ENT.hasAlwaysTruePredicate();
8748 };
8749 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8750}
8751
// NOTE(review): listing line 8752, which should carry this constructor's
// signature, is missing from this extract; the initializer uses a value
// named E.
//
// Delegating constructor: E serves as the exact, constant-max and symbolic-max
// count alike, with MaxOrZero set to false.
8753 : ExitLimit(E, E, E, false) {}
8754
// NOTE(review): listing line 8755 (the start of this constructor's signature)
// is missing from this extract.
//
// Main ExitLimit constructor: stores the three counts and the MaxOrZero flag,
// checks the precision ordering between the counts with assertions, and
// collects the predicates from PredLists with duplicates dropped.
8756 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8757 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
// NOTE(review): listing line 8758 (the last parameter, which must be the
// PredLists range iterated below) is missing from this extract.
8759 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8760 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8761 // If we prove the max count is zero, so is the symbolic bound. This happens
8762 // in practice due to differences in a) how context sensitive we've chosen
8763 // to be and b) how we reason about bounds implied by UB.
8764 if (ConstantMaxNotTaken->isZero()) {
// NOTE(review): listing line 8765 is missing from this extract.
// The assignment below updates both the member and the same-named parameter,
// so the assertions that follow see the new symbolic max.
8766 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8767 }
8768
8769 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8770 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8771 "Exact is not allowed to be less precise than Constant Max");
8772 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8773 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8774 "Exact is not allowed to be less precise than Symbolic Max");
8775 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8776 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8777 "Symbolic Max is not allowed to be less precise than Constant Max");
8778 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8779 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8780 "No point in having a non-constant max backedge taken count!");
// NOTE(review): listing line 8781 is missing from this extract; it presumably
// declares the SeenPreds set used below -- confirm.
8782 for (const auto PredList : PredLists)
8783 for (const auto *P : PredList) {
// Each predicate is kept once, in first-seen order.
8784 if (SeenPreds.contains(P))
8785 continue;
8786 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8787 SeenPreds.insert(P);
8788 Predicates.push_back(P);
8789 }
8790 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8791 "Backedge count should be int");
8792 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
// NOTE(review): listing line 8793 (the second operand of this assertion's
// condition) is missing from this extract.
8794 "Max backedge count should be int");
8795}
8796
// NOTE(review): listing line 8797 (the start of this constructor's signature,
// including the parameter E) is missing from this extract.
//
// Convenience constructor for a single predicate list: wraps PredList in a
// one-element ArrayRef and delegates to the constructor that takes a range of
// predicate lists.
8798 const SCEV *ConstantMaxNotTaken,
8799 const SCEV *SymbolicMaxNotTaken,
8800 bool MaxOrZero,
// NOTE(review): listing line 8801 (the last parameter, which must be the
// PredList used below) is missing from this extract.
8802 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8803 ArrayRef({PredList})) {}
8804
8805/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8806/// computable exit into a persistent ExitNotTakenInfo array.
///
/// ConstantMax must be either a SCEVCouldNotCompute or a SCEVConstant
/// (asserted).
8807ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
// NOTE(review): listing line 8808 (the first parameter, which must be the
// ExitCounts range used below) is missing from this extract.
8809 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8810 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8811 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8812
8813 ExitNotTaken.reserve(ExitCounts.size());
// One ExitNotTakenInfo per (exiting block, ExitLimit) pair, carrying the
// limit's three counts and its predicates.
8814 std::transform(ExitCounts.begin(), ExitCounts.end(),
8815 std::back_inserter(ExitNotTaken),
8816 [&](const EdgeExitInfo &EEI) {
8817 BasicBlock *ExitBB = EEI.first;
8818 const ExitLimit &EL = EEI.second;
8819 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8820 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8821 EL.Predicates);
8822 });
8823 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8824 isa<SCEVConstant>(ConstantMax)) &&
8825 "No point in having a non-constant max backedge taken count!");
8826}
8827
8828/// Compute the number of times the backedge of the specified loop will execute.
///
/// Computes an ExitLimit for every exiting block of \p L (skipping exits whose
/// branch condition is a constant that never leaves the loop), records the
/// usable ones, derives a loop-wide maximum backedge-taken count from them and
/// registers the non-constant counts in BECountUsers.
///
/// \param AllowPredicates forwarded to computeExitLimit; when false, every
/// returned exit limit must be predicate-free (asserted).
8829ScalarEvolution::BackedgeTakenInfo
8830ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8831 bool AllowPredicates) {
8832 SmallVector<BasicBlock *, 8> ExitingBlocks;
8833 L->getExitingBlocks(ExitingBlocks);
8834
8835 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8836
// NOTE(review): listing line 8837 is missing from this extract; it presumably
// declares the ExitCounts container filled below -- confirm.
8838 bool CouldComputeBECount = true;
8839 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8840 const SCEV *MustExitMaxBECount = nullptr;
8841 const SCEV *MayExitMaxBECount = nullptr;
8842 bool MustExitMaxOrZero = false;
8843 bool IsOnlyExit = ExitingBlocks.size() == 1;
8844
8845 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8846 // and compute maxBECount.
8847 // Do a union of all the predicates here.
8848 for (BasicBlock *ExitBB : ExitingBlocks) {
8849 // We canonicalize untaken exits to br (constant), ignore them so that
8850 // proving an exit untaken doesn't negatively impact our ability to reason
8851 // about the loop as whole.
8852 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8853 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
// ExitIfTrue: the true successor (successor 0) lies outside the loop. The
// exit is never taken when the constant condition selects the in-loop
// successor.
8854 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8855 if (ExitIfTrue == CI->isZero())
8856 continue;
8857 }
8858
8859 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8860
8861 assert((AllowPredicates || EL.Predicates.empty()) &&
8862 "Predicated exit limit when predicates are not allowed!");
8863
8864 // 1. For each exit that can be computed, add an entry to ExitCounts.
8865 // CouldComputeBECount is true only if all exits can be computed.
8866 if (EL.ExactNotTaken != getCouldNotCompute())
8867 ++NumExitCountsComputed;
8868 else
8869 // We couldn't compute an exact value for this exit, so
8870 // we won't be able to compute an exact value for the loop.
8871 CouldComputeBECount = false;
8872 // Remember exit count if either exact or symbolic is known. Because
8873 // Exact always implies symbolic, only check symbolic.
8874 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8875 ExitCounts.emplace_back(ExitBB, EL);
8876 else {
8877 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8878 "Exact is known but symbolic isn't?");
8879 ++NumExitCountsNotComputed;
8880 }
8881
8882 // 2. Derive the loop's MaxBECount from each exit's max number of
8883 // non-exiting iterations. Partition the loop exits into two kinds:
8884 // LoopMustExits and LoopMayExits.
8885 //
8886 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8887 // is a LoopMayExit. If any computable LoopMustExit is found, then
8888 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8889 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8890 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8891 // any
8892 // computable EL.ConstantMaxNotTaken.
8893 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8894 DT.dominates(ExitBB, Latch)) {
8895 if (!MustExitMaxBECount) {
8896 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8897 MustExitMaxOrZero = EL.MaxOrZero;
8898 } else {
8899 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8900 EL.ConstantMaxNotTaken);
8901 }
// Once MayExitMaxBECount has become CouldNotCompute it stays that way.
8902 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8903 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8904 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8905 else {
8906 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8907 EL.ConstantMaxNotTaken);
8908 }
8909 }
8910 }
// Prefer the bound from must-exits, then the one from may-exits, and fall
// back to CouldNotCompute when neither was set.
8911 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8912 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8913 // The loop backedge will be taken the maximum or zero times if there's
8914 // a single exit that must be taken the maximum or zero times.
8915 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8916
8917 // Remember which SCEVs are used in exit limits for invalidation purposes.
8918 // We only care about non-constant SCEVs here, so we can ignore
8919 // EL.ConstantMaxNotTaken
8920 // and MaxBECount, which must be SCEVConstant.
8921 for (const auto &Pair : ExitCounts) {
8922 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8923 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8924 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8925 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8926 {L, AllowPredicates});
8927 }
8928 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8929 MaxBECount, MaxOrZero);
8930}
8931
8933ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8934 bool IsOnlyExit, bool AllowPredicates) {
8935 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8936 // If our exiting block does not dominate the latch, then its connection with
8937 // loop's exit limit may be far from trivial.
8938 const BasicBlock *Latch = L->getLoopLatch();
8939 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8940 return getCouldNotCompute();
8941
8942 Instruction *Term = ExitingBlock->getTerminator();
8943 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8944 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8945 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8946 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8947 "It should have one successor in loop and one exit block!");
8948 // Proceed to the next level to examine the exit condition expression.
8949 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8950 /*ControlsOnlyExit=*/IsOnlyExit,
8951 AllowPredicates);
8952 }
8953
8954 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8955 // For switch, make sure that there is a single exit from the loop.
8956 BasicBlock *Exit = nullptr;
8957 for (auto *SBB : successors(ExitingBlock))
8958 if (!L->contains(SBB)) {
8959 if (Exit) // Multiple exit successors.
8960 return getCouldNotCompute();
8961 Exit = SBB;
8962 }
8963 assert(Exit && "Exiting block must have at least one exit");
8964 return computeExitLimitFromSingleExitSwitch(
8965 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8966 }
8967
8968 return getCouldNotCompute();
8969}
8970
8972 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8973 bool AllowPredicates) {
8974 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8975 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8976 ControlsOnlyExit, AllowPredicates);
8977}
8978
8979std::optional<ScalarEvolution::ExitLimit>
8980ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8981 bool ExitIfTrue, bool ControlsOnlyExit,
8982 bool AllowPredicates) {
8983 (void)this->L;
8984 (void)this->ExitIfTrue;
8985 (void)this->AllowPredicates;
8986
8987 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8988 this->AllowPredicates == AllowPredicates &&
8989 "Variance in assumed invariant key components!");
8990 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8991 if (Itr == TripCountMap.end())
8992 return std::nullopt;
8993 return Itr->second;
8994}
8995
8996void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8997 bool ExitIfTrue,
8998 bool ControlsOnlyExit,
8999 bool AllowPredicates,
9000 const ExitLimit &EL) {
9001 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9002 this->AllowPredicates == AllowPredicates &&
9003 "Variance in assumed invariant key components!");
9004
9005 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9006 assert(InsertResult.second && "Expected successful insertion!");
9007 (void)InsertResult;
9008 (void)ExitIfTrue;
9009}
9010
9011ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9012 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9013 bool ControlsOnlyExit, bool AllowPredicates) {
9014
9015 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9016 AllowPredicates))
9017 return *MaybeEL;
9018
9019 ExitLimit EL = computeExitLimitFromCondImpl(
9020 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9021 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9022 return EL;
9023}
9024
9025ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9026 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9027 bool ControlsOnlyExit, bool AllowPredicates) {
9028 // Handle BinOp conditions (And, Or).
9029 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9030 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9031 return *LimitFromBinOp;
9032
9033 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9034 // Proceed to the next level to examine the icmp.
9035 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9036 ExitLimit EL =
9037 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9038 if (EL.hasFullInfo() || !AllowPredicates)
9039 return EL;
9040
9041 // Try again, but use SCEV predicates this time.
9042 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9043 ControlsOnlyExit,
9044 /*AllowPredicates=*/true);
9045 }
9046
9047 // Check for a constant condition. These are normally stripped out by
9048 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9049 // preserve the CFG and is temporarily leaving constant conditions
9050 // in place.
9051 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9052 if (ExitIfTrue == !CI->getZExtValue())
9053 // The backedge is always taken.
9054 return getCouldNotCompute();
9055 // The backedge is never taken.
9056 return getZero(CI->getType());
9057 }
9058
9059 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9060 // with a constant step, we can form an equivalent icmp predicate and figure
9061 // out how many iterations will be taken before we exit.
9062 const WithOverflowInst *WO;
9063 const APInt *C;
9064 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9065 match(WO->getRHS(), m_APInt(C))) {
9066 ConstantRange NWR =
9068 WO->getNoWrapKind());
9069 CmpInst::Predicate Pred;
9070 APInt NewRHSC, Offset;
9071 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9072 if (!ExitIfTrue)
9073 Pred = ICmpInst::getInversePredicate(Pred);
9074 auto *LHS = getSCEV(WO->getLHS());
9075 if (Offset != 0)
9077 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9078 ControlsOnlyExit, AllowPredicates);
9079 if (EL.hasAnyInfo())
9080 return EL;
9081 }
9082
9083 // If it's not an integer or pointer comparison then compute it the hard way.
9084 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9085}
9086
9087std::optional<ScalarEvolution::ExitLimit>
9088ScalarEvolution::computeExitLimitFromCondFromBinOp(
9089 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9090 bool ControlsOnlyExit, bool AllowPredicates) {
9091 // Check if the controlling expression for this loop is an And or Or.
9092 Value *Op0, *Op1;
9093 bool IsAnd = false;
9094 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9095 IsAnd = true;
9096 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9097 IsAnd = false;
9098 else
9099 return std::nullopt;
9100
9101 // EitherMayExit is true in these two cases:
9102 // br (and Op0 Op1), loop, exit
9103 // br (or Op0 Op1), exit, loop
9104 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9105 ExitLimit EL0 = computeExitLimitFromCondCached(
9106 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9107 AllowPredicates);
9108 ExitLimit EL1 = computeExitLimitFromCondCached(
9109 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9110 AllowPredicates);
9111
9112 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9113 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9114 if (isa<ConstantInt>(Op1))
9115 return Op1 == NeutralElement ? EL0 : EL1;
9116 if (isa<ConstantInt>(Op0))
9117 return Op0 == NeutralElement ? EL1 : EL0;
9118
9119 const SCEV *BECount = getCouldNotCompute();
9120 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9121 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9122 if (EitherMayExit) {
9123 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9124 // Both conditions must be same for the loop to continue executing.
9125 // Choose the less conservative count.
9126 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9127 EL1.ExactNotTaken != getCouldNotCompute()) {
9128 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9129 UseSequentialUMin);
9130 }
9131 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9132 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9133 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9134 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9135 else
9136 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9137 EL1.ConstantMaxNotTaken);
9138 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9139 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9140 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9141 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9142 else
9143 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9144 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9145 } else {
9146 // Both conditions must be same at the same time for the loop to exit.
9147 // For now, be conservative.
9148 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9149 BECount = EL0.ExactNotTaken;
9150 }
9151
9152 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9153 // to be more aggressive when computing BECount than when computing
9154 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9155 // and
9156 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9157 // EL1.ConstantMaxNotTaken to not.
9158 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9159 !isa<SCEVCouldNotCompute>(BECount))
9160 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9161 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9162 SymbolicMaxBECount =
9163 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9164 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9165 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9166}
9167
9168ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9169 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9170 bool AllowPredicates) {
9171 // If the condition was exit on true, convert the condition to exit on false
9172 CmpPredicate Pred;
9173 if (!ExitIfTrue)
9174 Pred = ExitCond->getCmpPredicate();
9175 else
9176 Pred = ExitCond->getInverseCmpPredicate();
9177 const ICmpInst::Predicate OriginalPred = Pred;
9178
9179 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9180 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9181
9182 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9183 AllowPredicates);
9184 if (EL.hasAnyInfo())
9185 return EL;
9186
9187 auto *ExhaustiveCount =
9188 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9189
9190 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9191 return ExhaustiveCount;
9192
9193 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9194 ExitCond->getOperand(1), L, OriginalPred);
9195}
9196ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9197 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9198 bool ControlsOnlyExit, bool AllowPredicates) {
9199
9200 // Try to evaluate any dependencies out of the loop.
9201 LHS = getSCEVAtScope(LHS, L);
9202 RHS = getSCEVAtScope(RHS, L);
9203
9204 // At this point, we would like to compute how many iterations of the
9205 // loop the predicate will return true for these inputs.
9206 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9207 // If there is a loop-invariant, force it into the RHS.
9208 std::swap(LHS, RHS);
9210 }
9211
9212 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9214 // Simplify the operands before analyzing them.
9215 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9216
9217 // If we have a comparison of a chrec against a constant, try to use value
9218 // ranges to answer this query.
9219 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9220 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9221 if (AddRec->getLoop() == L) {
9222 // Form the constant range.
9223 ConstantRange CompRange =
9224 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9225
9226 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9227 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9228 }
9229
9230 // If this loop must exit based on this condition (or execute undefined
9231 // behaviour), see if we can improve wrap flags. This is essentially
9232 // a must execute style proof.
9233 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9234 // If we can prove the test sequence produced must repeat the same values
9235 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9236 // because if it did, we'd have an infinite (undefined) loop.
9237 // TODO: We can peel off any functions which are invertible *in L*. Loop
9238 // invariant terms are effectively constants for our purposes here.
9239 auto *InnerLHS = LHS;
9240 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9241 InnerLHS = ZExt->getOperand();
9242 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9243 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9244 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9245 /*OrNegative=*/true)) {
9246 auto Flags = AR->getNoWrapFlags();
9247 Flags = setFlags(Flags, SCEV::FlagNW);
9250 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9251 }
9252
9253 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9254 // From no-self-wrap, this follows trivially from the fact that every
9255 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9256 // last value before (un)signed wrap. Since we know that last value
9257 // didn't exit, nor will any smaller one.
9258 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9259 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9260 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9261 AR && AR->getLoop() == L && AR->isAffine() &&
9262 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9263 isKnownPositive(AR->getStepRecurrence(*this))) {
9264 auto Flags = AR->getNoWrapFlags();
9265 Flags = setFlags(Flags, WrapType);
9268 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9269 }
9270 }
9271 }
9272
9273 switch (Pred) {
9274 case ICmpInst::ICMP_NE: { // while (X != Y)
9275 // Convert to: while (X-Y != 0)
9276 if (LHS->getType()->isPointerTy()) {
9278 if (isa<SCEVCouldNotCompute>(LHS))
9279 return LHS;
9280 }
9281 if (RHS->getType()->isPointerTy()) {
9283 if (isa<SCEVCouldNotCompute>(RHS))
9284 return RHS;
9285 }
9286 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9287 AllowPredicates);
9288 if (EL.hasAnyInfo())
9289 return EL;
9290 break;
9291 }
9292 case ICmpInst::ICMP_EQ: { // while (X == Y)
9293 // Convert to: while (X-Y == 0)
9294 if (LHS->getType()->isPointerTy()) {
9296 if (isa<SCEVCouldNotCompute>(LHS))
9297 return LHS;
9298 }
9299 if (RHS->getType()->isPointerTy()) {
9301 if (isa<SCEVCouldNotCompute>(RHS))
9302 return RHS;
9303 }
9304 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9305 if (EL.hasAnyInfo()) return EL;
9306 break;
9307 }
9308 case ICmpInst::ICMP_SLE:
9309 case ICmpInst::ICMP_ULE:
9310 // Since the loop is finite, an invariant RHS cannot include the boundary
9311 // value, otherwise it would loop forever.
9312 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9313 !isLoopInvariant(RHS, L)) {
9314 // Otherwise, perform the addition in a wider type, to avoid overflow.
9315 // If the LHS is an addrec with the appropriate nowrap flag, the
9316 // extension will be sunk into it and the exit count can be analyzed.
9317 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9318 if (!OldType)
9319 break;
9320 // Prefer doubling the bitwidth over adding a single bit to make it more
9321 // likely that we use a legal type.
9322 auto *NewType =
9323 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9324 if (ICmpInst::isSigned(Pred)) {
9325 LHS = getSignExtendExpr(LHS, NewType);
9326 RHS = getSignExtendExpr(RHS, NewType);
9327 } else {
9328 LHS = getZeroExtendExpr(LHS, NewType);
9329 RHS = getZeroExtendExpr(RHS, NewType);
9330 }
9331 }
9332 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9333 [[fallthrough]];
9334 case ICmpInst::ICMP_SLT:
9335 case ICmpInst::ICMP_ULT: { // while (X < Y)
9336 bool IsSigned = ICmpInst::isSigned(Pred);
9337 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9338 AllowPredicates);
9339 if (EL.hasAnyInfo())
9340 return EL;
9341 break;
9342 }
9343 case ICmpInst::ICMP_SGE:
9344 case ICmpInst::ICMP_UGE:
9345 // Since the loop is finite, an invariant RHS cannot include the boundary
9346 // value, otherwise it would loop forever.
9347 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9348 !isLoopInvariant(RHS, L))
9349 break;
9350 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9351 [[fallthrough]];
9352 case ICmpInst::ICMP_SGT:
9353 case ICmpInst::ICMP_UGT: { // while (X > Y)
9354 bool IsSigned = ICmpInst::isSigned(Pred);
9355 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9356 AllowPredicates);
9357 if (EL.hasAnyInfo())
9358 return EL;
9359 break;
9360 }
9361 default:
9362 break;
9363 }
9364
9365 return getCouldNotCompute();
9366}
9367
9369ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9370 SwitchInst *Switch,
9371 BasicBlock *ExitingBlock,
9372 bool ControlsOnlyExit) {
9373 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9374
9375 // Give up if the exit is the default dest of a switch.
9376 if (Switch->getDefaultDest() == ExitingBlock)
9377 return getCouldNotCompute();
9378
9379 assert(L->contains(Switch->getDefaultDest()) &&
9380 "Default case must not exit the loop!");
9381 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9382 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9383
9384 // while (X != Y) --> while (X-Y != 0)
9385 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9386 if (EL.hasAnyInfo())
9387 return EL;
9388
9389 return getCouldNotCompute();
9390}
9391
9392static ConstantInt *
9394 ScalarEvolution &SE) {
9395 const SCEV *InVal = SE.getConstant(C);
9396 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9397 assert(isa<SCEVConstant>(Val) &&
9398 "Evaluation of SCEV at constant didn't fold correctly?");
9399 return cast<SCEVConstant>(Val)->getValue();
9400}
9401
9402ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9403 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9404 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9405 if (!RHS)
9406 return getCouldNotCompute();
9407
9408 const BasicBlock *Latch = L->getLoopLatch();
9409 if (!Latch)
9410 return getCouldNotCompute();
9411
9412 const BasicBlock *Predecessor = L->getLoopPredecessor();
9413 if (!Predecessor)
9414 return getCouldNotCompute();
9415
9416 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9417 // Return LHS in OutLHS and shift_opt in OutOpCode.
9418 auto MatchPositiveShift =
9419 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9420
9421 using namespace PatternMatch;
9422
9423 ConstantInt *ShiftAmt;
9424 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9425 OutOpCode = Instruction::LShr;
9426 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9427 OutOpCode = Instruction::AShr;
9428 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9429 OutOpCode = Instruction::Shl;
9430 else
9431 return false;
9432
9433 return ShiftAmt->getValue().isStrictlyPositive();
9434 };
9435
9436 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9437 //
9438 // loop:
9439 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9440 // %iv.shifted = lshr i32 %iv, <positive constant>
9441 //
9442 // Return true on a successful match. Return the corresponding PHI node (%iv
9443 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9444 auto MatchShiftRecurrence =
9445 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9446 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9447
9448 {
9450 Value *V;
9451
9452 // If we encounter a shift instruction, "peel off" the shift operation,
9453 // and remember that we did so. Later when we inspect %iv's backedge
9454 // value, we will make sure that the backedge value uses the same
9455 // operation.
9456 //
9457 // Note: the peeled shift operation does not have to be the same
9458 // instruction as the one feeding into the PHI's backedge value. We only
9459 // really care about it being the same *kind* of shift instruction --
9460 // that's all that is required for our later inferences to hold.
9461 if (MatchPositiveShift(LHS, V, OpC)) {
9462 PostShiftOpCode = OpC;
9463 LHS = V;
9464 }
9465 }
9466
9467 PNOut = dyn_cast<PHINode>(LHS);
9468 if (!PNOut || PNOut->getParent() != L->getHeader())
9469 return false;
9470
9471 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9472 Value *OpLHS;
9473
9474 return
9475 // The backedge value for the PHI node must be a shift by a positive
9476 // amount
9477 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9478
9479 // of the PHI node itself
9480 OpLHS == PNOut &&
9481
9482 // and the kind of shift should be match the kind of shift we peeled
9483 // off, if any.
9484 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9485 };
9486
9487 PHINode *PN;
9489 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9490 return getCouldNotCompute();
9491
9492 const DataLayout &DL = getDataLayout();
9493
9494 // The key rationale for this optimization is that for some kinds of shift
9495 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9496 // within a finite number of iterations. If the condition guarding the
9497 // backedge (in the sense that the backedge is taken if the condition is true)
9498 // is false for the value the shift recurrence stabilizes to, then we know
9499 // that the backedge is taken only a finite number of times.
9500
9501 ConstantInt *StableValue = nullptr;
9502 switch (OpCode) {
9503 default:
9504 llvm_unreachable("Impossible case!");
9505
9506 case Instruction::AShr: {
9507 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9508 // bitwidth(K) iterations.
9509 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9510 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9511 Predecessor->getTerminator(), &DT);
9512 auto *Ty = cast<IntegerType>(RHS->getType());
9513 if (Known.isNonNegative())
9514 StableValue = ConstantInt::get(Ty, 0);
9515 else if (Known.isNegative())
9516 StableValue = ConstantInt::get(Ty, -1, true);
9517 else
9518 return getCouldNotCompute();
9519
9520 break;
9521 }
9522 case Instruction::LShr:
9523 case Instruction::Shl:
9524 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9525 // stabilize to 0 in at most bitwidth(K) iterations.
9526 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9527 break;
9528 }
9529
9530 auto *Result =
9531 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9532 assert(Result->getType()->isIntegerTy(1) &&
9533 "Otherwise cannot be an operand to a branch instruction");
9534
9535 if (Result->isZeroValue()) {
9536 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9537 const SCEV *UpperBound =
9539 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9540 }
9541
9542 return getCouldNotCompute();
9543}
9544
9545/// Return true if we can constant fold an instruction of the specified type,
9546/// assuming that all operands were constants.
9547static bool CanConstantFold(const Instruction *I) {
9548 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9549 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9550 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9551 return true;
9552
9553 if (const CallInst *CI = dyn_cast<CallInst>(I))
9554 if (const Function *F = CI->getCalledFunction())
9555 return canConstantFoldCallTo(CI, F);
9556 return false;
9557}
9558
9559/// Determine whether this instruction can constant evolve within this loop
9560/// assuming its operands can all constant evolve.
9561static bool canConstantEvolve(Instruction *I, const Loop *L) {
9562 // An instruction outside of the loop can't be derived from a loop PHI.
9563 if (!L->contains(I)) return false;
9564
9565 if (isa<PHINode>(I)) {
9566 // We don't currently keep track of the control flow needed to evaluate
9567 // PHIs, so we cannot handle PHIs inside of loops.
9568 return L->getHeader() == I->getParent();
9569 }
9570
9571 // If we won't be able to constant fold this expression even if the operands
9572 // are constants, bail early.
9573 return CanConstantFold(I);
9574}
9575
9576/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9577/// recursing through each instruction operand until reaching a loop header phi.
9578static PHINode *
9581 unsigned Depth) {
9583 return nullptr;
9584
9585 // Otherwise, we can evaluate this instruction if all of its operands are
9586 // constant or derived from a PHI node themselves.
9587 PHINode *PHI = nullptr;
9588 for (Value *Op : UseInst->operands()) {
9589 if (isa<Constant>(Op)) continue;
9590
9591 Instruction *OpInst = dyn_cast<Instruction>(Op);
9592 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9593
9594 PHINode *P = dyn_cast<PHINode>(OpInst);
9595 if (!P)
9596 // If this operand is already visited, reuse the prior result.
9597 // We may have P != PHI if this is the deepest point at which the
9598 // inconsistent paths meet.
9599 P = PHIMap.lookup(OpInst);
9600 if (!P) {
9601 // Recurse and memoize the results, whether a phi is found or not.
9602 // This recursive call invalidates pointers into PHIMap.
9603 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9604 PHIMap[OpInst] = P;
9605 }
9606 if (!P)
9607 return nullptr; // Not evolving from PHI
9608 if (PHI && PHI != P)
9609 return nullptr; // Evolving from multiple different PHIs.
9610 PHI = P;
9611 }
9612 // This is a expression evolving from a constant PHI!
9613 return PHI;
9614}
9615
9616/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9617/// in the loop that V is derived from. We allow arbitrary operations along the
9618/// way, but the operands of an operation must either be constants or a value
9619/// derived from a constant PHI. If this expression does not fit with these
9620/// constraints, return null.
9622 Instruction *I = dyn_cast<Instruction>(V);
9623 if (!I || !canConstantEvolve(I, L)) return nullptr;
9624
9625 if (PHINode *PN = dyn_cast<PHINode>(I))
9626 return PN;
9627
9628 // Record non-constant instructions contained by the loop.
9630 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9631}
9632
9633/// EvaluateExpression - Given an expression that passes the
9634/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9635/// in the loop has the value PHIVal. If we can't fold this expression for some
9636/// reason, return null.
9639 const DataLayout &DL,
9640 const TargetLibraryInfo *TLI) {
9641 // Convenient constant check, but redundant for recursive calls.
9642 if (Constant *C = dyn_cast<Constant>(V)) return C;
9643 Instruction *I = dyn_cast<Instruction>(V);
9644 if (!I) return nullptr;
9645
9646 if (Constant *C = Vals.lookup(I)) return C;
9647
9648 // An instruction inside the loop depends on a value outside the loop that we
9649 // weren't given a mapping for, or a value such as a call inside the loop.
9650 if (!canConstantEvolve(I, L)) return nullptr;
9651
9652 // An unmapped PHI can be due to a branch or another loop inside this loop,
9653 // or due to this not being the initial iteration through a loop where we
9654 // couldn't compute the evolution of this particular PHI last time.
9655 if (isa<PHINode>(I)) return nullptr;
9656
9657 std::vector<Constant*> Operands(I->getNumOperands());
9658
9659 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9660 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9661 if (!Operand) {
9662 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9663 if (!Operands[i]) return nullptr;
9664 continue;
9665 }
9666 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9667 Vals[Operand] = C;
9668 if (!C) return nullptr;
9669 Operands[i] = C;
9670 }
9671
9672 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9673 /*AllowNonDeterministic=*/false);
9674}
9675
9676
9677// If every incoming value to PN except the one for BB is a specific Constant,
9678// return that, else return nullptr.
9680 Constant *IncomingVal = nullptr;
9681
9682 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9683 if (PN->getIncomingBlock(i) == BB)
9684 continue;
9685
9686 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9687 if (!CurrentVal)
9688 return nullptr;
9689
9690 if (IncomingVal != CurrentVal) {
9691 if (IncomingVal)
9692 return nullptr;
9693 IncomingVal = CurrentVal;
9694 }
9695 }
9696
9697 return IncomingVal;
9698}
9699
9700/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9701/// in the header of its containing loop, we know the loop executes a
9702/// constant number of times, and the PHI node is just a recurrence
9703/// involving constants, fold it.
///
/// \param PN  The PHI to evaluate; asserted below to live in L's header.
/// \param BEs Number of backedge iterations to simulate.
/// \param L   Loop whose header PHIs and latch drive the simulation.
/// \returns the constant PN holds after BEs iterations (or the value at which
/// all tracked header PHIs stop changing), or nullptr if it could not be
/// evaluated. The entry in ConstantEvolutionLoopExitValue is created up front,
/// so a repeated query for the same PN returns the stored value.
///
/// NOTE(review): the embedded numbering skips 9712, 9717, 9747 and 9759. The
/// condition guarding the first `return nullptr` and the declarations of
/// CurrentIterVals, NextIterVals and PHIsToCompute are therefore not visible
/// in this listing.
9704Constant *
9705ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9706 const APInt &BEs,
9707 const Loop *L) {
9708 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9709 if (!Inserted)
9710 return I->second;
9711
9713 return nullptr; // Not going to evaluate it.
9714
9715 Constant *&RetVal = I->second;
9716
9718 BasicBlock *Header = L->getHeader();
9719 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9720
9721 BasicBlock *Latch = L->getLoopLatch();
9722 if (!Latch)
9723 return nullptr;
9724
      // Seed the simulation with every header PHI whose non-latch incoming
      // values agree on one constant.
9725 for (PHINode &PHI : Header->phis()) {
9726 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9727 CurrentIterVals[&PHI] = StartCST;
9728 }
      // Without a constant start value for PN itself there is nothing to fold.
9729 if (!CurrentIterVals.count(PN))
9730 return RetVal = nullptr;
9731
9732 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9733
9734 // Execute the loop symbolically to determine the exit value.
9735 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9736 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9737
9738 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9739 unsigned IterationNum = 0;
9740 const DataLayout &DL = getDataLayout();
9741 for (; ; ++IterationNum) {
9742 if (IterationNum == NumIterations)
9743 return RetVal = CurrentIterVals[PN]; // Got exit value!
9744
9745 // Compute the value of the PHIs for the next iteration.
9746 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9748 Constant *NextPHI =
9749 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9750 if (!NextPHI)
9751 return nullptr; // Couldn't evaluate!
9752 NextIterVals[PN] = NextPHI;
9753
9754 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9755
9756 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9757 // cease to be able to evaluate one of them or if they stop evolving,
9758 // because that doesn't necessarily prevent us from computing PN.
9760 for (const auto &I : CurrentIterVals) {
9761 PHINode *PHI = dyn_cast<PHINode>(I.first);
9762 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9763 PHIsToCompute.emplace_back(PHI, I.second);
9764 }
9765 // We use two distinct loops because EvaluateExpression may invalidate any
9766 // iterators into CurrentIterVals.
9767 for (const auto &I : PHIsToCompute) {
9768 PHINode *PHI = I.first;
9769 Constant *&NextPHI = NextIterVals[PHI];
9770 if (!NextPHI) { // Not already computed.
9771 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9772 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9773 }
      // I.second is this PHI's value in the current iteration.
9774 if (NextPHI != I.second)
9775 StoppedEvolving = false;
9776 }
9777
9778 // If all entries in CurrentIterVals == NextIterVals then we can stop
9779 // iterating, the loop can't continue to change.
9780 if (StoppedEvolving)
9781 return RetVal = CurrentIterVals[PN];
9782
9783 CurrentIterVals.swap(NextIterVals);
9784 }
9785}
9786
/// Brute-force an exit count by simulating the loop with constants: starting
/// from the constant start values of the header PHIs, evaluate \p Cond once
/// per iteration and return the (i32) index of the first iteration on which it
/// equals \p ExitWhen. Returns getCouldNotCompute() if the condition cannot be
/// evaluated to a ConstantInt or if MaxBruteForceIterations iterations pass
/// without a match.
///
/// NOTE(review): the embedded numbering skips 9790, 9797 and 9829, so the
/// definitions of PN, CurrentIterVals and NextIterVals are not visible in this
/// listing.
9787const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9788 Value *Cond,
9789 bool ExitWhen) {
9791 if (!PN) return getCouldNotCompute();
9792
9793 // If the loop is canonicalized, the PHI will have exactly two entries.
9794 // That's the only form we support here.
9795 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9796
9798 BasicBlock *Header = L->getHeader();
9799 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9800
9801 BasicBlock *Latch = L->getLoopLatch();
9802 assert(Latch && "Should follow from NumIncomingValues == 2!");
9803
      // Seed the simulation with every header PHI that has a constant start
      // value on its non-latch edge(s).
9804 for (PHINode &PHI : Header->phis()) {
9805 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9806 CurrentIterVals[&PHI] = StartCST;
9807 }
9808 if (!CurrentIterVals.count(PN))
9809 return getCouldNotCompute();
9810
9811 // Okay, we find a PHI node that defines the trip count of this loop. Execute
9812 // the loop symbolically to determine when the condition gets a value of
9813 // "ExitWhen".
9814 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9815 const DataLayout &DL = getDataLayout();
9816 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9817 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9818 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9819
9820 // Couldn't symbolically evaluate.
9821 if (!CondVal) return getCouldNotCompute();
9822
9823 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9824 ++NumBruteForceTripCountsComputed;
9825 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9826 }
9827
9828 // Update all the PHI nodes for the next iteration.
9830
9831 // Create a list of which PHIs we need to compute. We want to do this before
9832 // calling EvaluateExpression on them because that may invalidate iterators
9833 // into CurrentIterVals.
9834 SmallVector<PHINode *, 8> PHIsToCompute;
9835 for (const auto &I : CurrentIterVals) {
9836 PHINode *PHI = dyn_cast<PHINode>(I.first);
9837 if (!PHI || PHI->getParent() != Header) continue;
9838 PHIsToCompute.push_back(PHI);
9839 }
9840 for (PHINode *PHI : PHIsToCompute) {
9841 Constant *&NextPHI = NextIterVals[PHI];
9842 if (NextPHI) continue; // Already computed!
9843
9844 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9845 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9846 }
      // The freshly computed values become the current ones for the next
      // iteration.
9847 CurrentIterVals.swap(NextIterVals);
9848 }
9849
9850 // Too many iterations were needed to evaluate.
9851 return getCouldNotCompute();
9852}
9853
/// Return \p V evaluated at the scope of loop \p L, caching the answer per
/// (V, L) pair in ValuesAtScopes. A cached entry whose value is still null
/// makes the lookup return \p V unchanged.
///
/// NOTE(review): the embedded numbering skips 9855; the first visible body
/// line is the tail of the statement that defines `Values`.
9854const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9856 ValuesAtScopes[V];
9857 // Check to see if we've folded this expression at this loop before.
9858 for (auto &LS : Values)
9859 if (LS.first == L)
9860 return LS.second ? LS.second : V;
9861
      // Insert a null placeholder for L before computing.
9862 Values.emplace_back(L, nullptr);
9863
9864 // Otherwise compute it.
9865 const SCEV *C = computeSCEVAtScope(V, L);
      // ValuesAtScopes[V] is looked up again here rather than reusing `Values`;
      // the scan runs from the back and stops at the first entry for L.
9866 for (auto &LS : reverse(ValuesAtScopes[V]))
9867 if (LS.first == L) {
9868 LS.second = C;
      // Non-constant results are additionally recorded in ValuesAtScopesUsers.
9869 if (!isa<SCEVConstant>(C))
9870 ValuesAtScopesUsers[C].push_back({L, V});
9871 break;
9872 }
9873 return C;
9874}
9875
9876/// This builds up a Constant using the ConstantExpr interface. That way, we
9877/// will return Constants for objects which aren't represented by a
9878/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9879/// Returns NULL if the SCEV isn't representable as a Constant.
///
/// Only constant, unknown, ptrtoint, truncate and add expressions are
/// converted; every other SCEV kind listed in the switch yields nullptr.
///
/// NOTE(review): the embedded numbering skips 9880, 9907, 9919 and 9935. The
/// signature, the definition of OpC, the first half of the pointer-add
/// statement and one line before the final `return nullptr` are therefore not
/// visible in this listing.
9881 switch (V->getSCEVType()) {
9882 case scCouldNotCompute:
9883 case scAddRecExpr:
9884 case scVScale:
9885 return nullptr;
9886 case scConstant:
9887 return cast<SCEVConstant>(V)->getValue();
9888 case scUnknown:
      // Yields nullptr when the wrapped Value is not a Constant.
9889 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9890 case scPtrToInt: {
9891 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9892 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9893 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9894
9895 return nullptr;
9896 }
9897 case scTruncate: {
9898 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9899 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9900 return ConstantExpr::getTrunc(CastOp, ST->getType());
9901 return nullptr;
9902 }
9903 case scAddExpr: {
9904 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
      // C accumulates the sum of the operands processed so far.
9905 Constant *C = nullptr;
9906 for (const SCEV *Op : SA->operands()) {
9908 if (!OpC)
9909 return nullptr;
9910 if (!C) {
9911 C = OpC;
9912 continue;
9913 }
9914 assert(!C->getType()->isPointerTy() &&
9915 "Can only have one pointer, and it must be last");
9916 if (OpC->getType()->isPointerTy()) {
9917 // The offsets have been converted to bytes. We can add bytes using
9918 // an i8 GEP.
9920 OpC, C);
9921 } else {
9922 C = ConstantExpr::getAdd(C, OpC);
9923 }
9924 }
9925 return C;
9926 }
9927 case scMulExpr:
9928 case scSignExtend:
9929 case scZeroExtend:
9930 case scUDivExpr:
9931 case scSMaxExpr:
9932 case scUMaxExpr:
9933 case scSMinExpr:
9934 case scUMinExpr:
9936 return nullptr;
9937 }
9938 llvm_unreachable("Unknown SCEV kind!");
9939}
9940
/// Build an expression of the same kind as \p S from the operands in NewOps.
/// Casts keep S's type; add, mul and addrec expressions are created with the
/// no-wrap flags read from S (and, for addrecs, S's loop). Constant, vscale
/// and unknown expressions are returned unchanged. Passing a
/// SCEVCouldNotCompute is a programming error.
///
/// NOTE(review): the embedded numbering skips 9943 and 9965, so the remainder
/// of the parameter list and the case label(s) in front of the
/// getSequentialMinMaxExpr return are not visible in this listing.
9941const SCEV *
9942ScalarEvolution::getWithOperands(const SCEV *S,
9944 switch (S->getSCEVType()) {
9945 case scTruncate:
9946 case scZeroExtend:
9947 case scSignExtend:
9948 case scPtrToInt:
9949 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9950 case scAddRecExpr: {
9951 auto *AddRec = cast<SCEVAddRecExpr>(S);
9952 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9953 }
9954 case scAddExpr:
9955 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9956 case scMulExpr:
9957 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9958 case scUDivExpr:
9959 return getUDivExpr(NewOps[0], NewOps[1]);
9960 case scUMaxExpr:
9961 case scSMaxExpr:
9962 case scUMinExpr:
9963 case scSMinExpr:
9964 return getMinMaxExpr(S->getSCEVType(), NewOps);
9966 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9967 case scConstant:
9968 case scVScale:
9969 case scUnknown:
      // These kinds are handed back as-is; NewOps is not consulted.
9970 return S;
9971 case scCouldNotCompute:
9972 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9973 }
9974 llvm_unreachable("Unknown SCEV kind!");
9975}
9976
/// Uncached worker behind getSCEVAtScope: compute the value of \p V at the
/// scope of loop \p L, dispatching on the SCEV kind.
///  - Constants and vscale are returned unchanged.
///  - Addrecs have their operands evaluated at scope; if the addrec's loop
///    does not contain L and its backedge-taken count is computable, the
///    addrec is evaluated at that count.
///  - Casts and n-ary expressions are rebuilt via getWithOperands as soon as
///    one operand changes at scope; otherwise V is returned.
///  - Unknowns that are instructions are handled by the header-PHI exit-value
///    cases below and, failing that, by constant-folding the instruction once
///    all of its operands evaluate to constants at scope.
/// Whenever no improvement is found, V itself is returned.
///
/// NOTE(review): the embedded numbering skips 9996, 10049, 10132 and 10151,
/// so the declarations of both NewOps vectors, of Operands, and of the
/// per-operand constant C are not visible in this listing.
9977const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9978 switch (V->getSCEVType()) {
9979 case scConstant:
9980 case scVScale:
9981 return V;
9982 case scAddRecExpr: {
9983 // If this is a loop recurrence for a loop that does not contain L, then we
9984 // are dealing with the final value computed by the loop.
9985 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9986 // First, attempt to evaluate each operand.
9987 // Avoid performing the look-up in the common case where the specified
9988 // expression has no loop-variant portions.
9989 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9990 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9991 if (OpAtScope == AddRec->getOperand(i))
9992 continue;
9993
9994 // Okay, at least one of these operands is loop variant but might be
9995 // foldable. Build a new instance of the folded commutative expression.
9997 NewOps.reserve(AddRec->getNumOperands());
       // Operands before index i were unchanged and are copied verbatim.
9998 append_range(NewOps, AddRec->operands().take_front(i));
9999 NewOps.push_back(OpAtScope);
10000 for (++i; i != e; ++i)
10001 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10002
       // The rebuilt addrec is created with getNoWrapFlags(SCEV::FlagNW), not
       // with the full set of flags of the original.
10003 const SCEV *FoldedRec = getAddRecExpr(
10004 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10005 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10006 // The addrec may be folded to a nonrecurrence, for example, if the
10007 // induction variable is multiplied by zero after constant folding. Go
10008 // ahead and return the folded value.
10009 if (!AddRec)
10010 return FoldedRec;
10011 break;
10012 }
10013
10014 // If the scope is outside the addrec's loop, evaluate it by using the
10015 // loop exit value of the addrec.
10016 if (!AddRec->getLoop()->contains(L)) {
10017 // To evaluate this recurrence, we need to know how many times the AddRec
10018 // loop iterates. Compute this now.
10019 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10020 if (BackedgeTakenCount == getCouldNotCompute())
10021 return AddRec;
10022
10023 // Then, evaluate the AddRec.
10024 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10025 }
10026
10027 return AddRec;
10028 }
10029 case scTruncate:
10030 case scZeroExtend:
10031 case scSignExtend:
10032 case scPtrToInt:
10033 case scAddExpr:
10034 case scMulExpr:
10035 case scUDivExpr:
10036 case scUMaxExpr:
10037 case scSMaxExpr:
10038 case scUMinExpr:
10039 case scSMinExpr:
10040 case scSequentialUMinExpr: {
10041 ArrayRef<const SCEV *> Ops = V->operands();
10042 // Avoid performing the look-up in the common case where the specified
10043 // expression has no loop-variant portions.
10044 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10045 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10046 if (OpAtScope != Ops[i]) {
10047 // Okay, at least one of these operands is loop variant but might be
10048 // foldable. Build a new instance of the folded commutative expression.
10050 NewOps.reserve(Ops.size());
10051 append_range(NewOps, Ops.take_front(i));
10052 NewOps.push_back(OpAtScope);
10053
10054 for (++i; i != e; ++i) {
10055 OpAtScope = getSCEVAtScope(Ops[i], L);
10056 NewOps.push_back(OpAtScope);
10057 }
10058
10059 return getWithOperands(V, NewOps);
10060 }
10061 }
10062 // If we got here, all operands are loop invariant.
10063 return V;
10064 }
10065 case scUnknown: {
10066 // If this instruction is evolved from a constant-evolving PHI, compute the
10067 // exit value from the loop without using SCEVs.
10068 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10069 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10070 if (!I)
10071 return V; // This is some other type of SCEVUnknown, just return it.
10072
10073 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10074 const Loop *CurrLoop = this->LI[I->getParent()];
10075 // Looking for loop exit value.
       // Only a PHI in the header of a loop whose parent loop is L qualifies.
10076 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10077 PN->getParent() == CurrLoop->getHeader()) {
10078 // Okay, there is no closed form solution for the PHI node. Check
10079 // to see if the loop that contains it has a known backedge-taken
10080 // count. If so, we may be able to force computation of the exit
10081 // value.
10082 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10083 // This trivial case can show up in some degenerate cases where
10084 // the incoming IR has not yet been fully simplified.
10085 if (BackedgeTakenCount->isZero()) {
       // With a zero backedge-taken count the PHI's value is its unique
       // incoming value from outside the loop, if there is exactly one.
10086 Value *InitValue = nullptr;
10087 bool MultipleInitValues = false;
10088 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10089 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10090 if (!InitValue)
10091 InitValue = PN->getIncomingValue(i);
10092 else if (InitValue != PN->getIncomingValue(i)) {
10093 MultipleInitValues = true;
10094 break;
10095 }
10096 }
10097 }
10098 if (!MultipleInitValues && InitValue)
10099 return getSCEV(InitValue);
10100 }
10101 // Do we have a loop invariant value flowing around the backedge
10102 // for a loop which must execute the backedge?
10103 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10104 isKnownNonZero(BackedgeTakenCount) &&
10105 PN->getNumIncomingValues() == 2) {
10106
       // Of the two incoming edges, pick the one whose block is inside the loop.
10107 unsigned InLoopPred =
10108 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10109 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10110 if (CurrLoop->isLoopInvariant(BackedgeVal))
10111 return getSCEV(BackedgeVal);
10112 }
10113 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10114 // Okay, we know how many times the containing loop executes. If
10115 // this is a constant evolving PHI node, get the final value at
10116 // the specified iteration number.
10117 Constant *RV =
10118 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10119 if (RV)
10120 return getSCEV(RV);
10121 }
10122 }
10123 }
10124
10125 // Okay, this is an expression that we cannot symbolically evaluate
10126 // into a SCEV. Check to see if it's possible to symbolically evaluate
10127 // the arguments into constants, and if so, try to constant propagate the
10128 // result. This is particularly useful for computing loop exit values.
10129 if (!CanConstantFold(I))
10130 return V; // This is some other type of SCEVUnknown, just return it.
10131
10133 Operands.reserve(I->getNumOperands());
10134 bool MadeImprovement = false;
10135 for (Value *Op : I->operands()) {
10136 if (Constant *C = dyn_cast<Constant>(Op)) {
10137 Operands.push_back(C);
10138 continue;
10139 }
10140
10141 // If any of the operands is non-constant and if they are
10142 // non-integer and non-pointer, don't even try to analyze them
10143 // with scev techniques.
10144 if (!isSCEVable(Op->getType()))
10145 return V;
10146
10147 const SCEV *OrigV = getSCEV(Op);
10148 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10149 MadeImprovement |= OrigV != OpV;
10150
       // An operand without a constant value at scope aborts the fold.
10152 if (!C)
10153 return V;
10154 assert(C->getType() == Op->getType() && "Type mismatch");
10155 Operands.push_back(C);
10156 }
10157
10158 // Check to see if getSCEVAtScope actually made an improvement.
10159 if (!MadeImprovement)
10160 return V; // This is some other type of SCEVUnknown, just return it.
10161
10162 Constant *C = nullptr;
10163 const DataLayout &DL = getDataLayout();
10164 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10165 /*AllowNonDeterministic=*/false);
10166 if (!C)
10167 return V;
10168 return getSCEV(C);
10169 }
10170 case scCouldNotCompute:
10171 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10172 }
10173 llvm_unreachable("Unknown SCEV type!");
10174}
10175
// Convenience form: obtain the SCEV for V via getSCEV and evaluate it at the
// scope of L.
// NOTE(review): the embedded numbering skips 10176, so the signature line of
// this definition is not visible in this listing.
10177 return getSCEVAtScope(getSCEV(V), L);
10178}
10179
10180const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10181 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10182 return stripInjectiveFunctions(ZExt->getOperand());
10183 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10184 return stripInjectiveFunctions(SExt->getOperand());
10185 return S;
10186}
10187
10188/// Finds the minimum unsigned root of the following equation:
10189///
10190/// A * X = B (mod N)
10191///
10192/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10193/// A and B isn't important.
10194///
10195/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10196static const SCEV *
// NOTE(review): the embedded numbering skips 10197 and 10198, so the function
// name and the leading parameters (A, B and Predicates are used in the body)
// are not visible in this listing; only the trailing `ScalarEvolution &SE`
// parameter is.
10199
10200 ScalarEvolution &SE) {
10201 uint32_t BW = A.getBitWidth();
10202 assert(BW == SE.getTypeSizeInBits(B->getType()));
10203 assert(A != 0 && "A must be non-zero.");
10204
10205 // 1. D = gcd(A, N)
10206 //
10207 // The gcd of A and N may have only one prime factor: 2. The number of
10208 // trailing zeros in A is its multiplicity
10209 uint32_t Mult2 = A.countr_zero();
10210 // D = 2^Mult2
10211
10212 // 2. Check if B is divisible by D.
10213 //
10214 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10215 // is not less than multiplicity of this prime factor for D.
10216 if (SE.getMinTrailingZeros(B) < Mult2) {
10217 // Check if we can prove there's no remainder using URem.
10218 const SCEV *URem =
10219 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10220 const SCEV *Zero = SE.getZero(B->getType());
10221 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10222 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
       // A null Predicates pointer means no predicate may be added.
10223 if (!Predicates)
10224 return SE.getCouldNotCompute();
10225
10226 // Avoid adding a predicate that is known to be false.
10227 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10228 return SE.getCouldNotCompute();
10229 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10230 }
10231 }
10232
10233 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10234 // modulo (N / D).
10235 //
10236 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10237 // (N / D) in general. The inverse itself always fits into BW bits, though,
10238 // so we immediately truncate it.
10239 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10240 APInt I = AD.multiplicativeInverse().zext(BW);
10241
10242 // 4. Compute the minimum unsigned root of the equation:
10243 // I * (B / D) mod (N / D)
10244 // To simplify the computation, we factor out the divide by D:
10245 // (I * B mod N) / D
10246 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10247 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10248}
10249
10250/// For a given quadratic addrec, generate coefficients of the corresponding
10251/// quadratic equation, multiplied by a common value to ensure that they are
10252/// integers.
10253/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10254/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10255/// were multiplied by, and BitWidth is the bit width of the original addrec
10256/// coefficients.
10257/// This function returns std::nullopt if the addrec coefficients are not
10258/// compile-time constants.
///
/// NOTE(review): the embedded numbering skips 10260, so the line carrying the
/// function name and parameter list is not visible in this listing.
10259static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10261 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10262 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10263 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10264 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10265 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10266 << *AddRec << '\n');
10267
10268 // We currently can only solve this if the coefficients are constants.
10269 if (!LC || !MC || !NC) {
10270 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10271 return std::nullopt;
10272 }
10273
10274 APInt L = LC->getAPInt();
10275 APInt M = MC->getAPInt();
10276 APInt N = NC->getAPInt();
10277 assert(!N.isZero() && "This is not a quadratic addrec");
10278
       // All arithmetic below is carried out with one extra bit of width.
10279 unsigned BitWidth = LC->getAPInt().getBitWidth();
10280 unsigned NewWidth = BitWidth + 1;
10281 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10282 << BitWidth << '\n');
10283 // The sign-extension (as opposed to a zero-extension) here matches the
10284 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10285 N = N.sext(NewWidth);
10286 M = M.sext(NewWidth);
10287 L = L.sext(NewWidth);
10288
10289 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10290 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10291 // L+M, L+2M+N, L+3M+3N, ...
10292 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10293 //
10294 // The equation Acc = 0 is then
10295 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10296 // In a quadratic form it becomes:
10297 // N n^2 + (2M-N) n + 2L = 0.
10298
10299 APInt A = N;
10300 APInt B = 2 * M - A;
10301 APInt C = 2 * L;
       // T is the common multiplier (2) reported back as the tuple's M element.
10302 APInt T = APInt(NewWidth, 2);
10303 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10304 << "x + " << C << ", coeff bw: " << NewWidth
10305 << ", multiplied by " << T << '\n');
10306 return std::make_tuple(A, B, C, T, BitWidth);
10307}
10308
10309/// Helper function to compare optional APInts:
10310/// (a) if X and Y both exist, return min(X, Y),
10311/// (b) if neither X nor Y exist, return std::nullopt,
10312/// (c) if exactly one of X and Y exists, return that value.
10313static std::optional<APInt> MinOptional(std::optional<APInt> X,
10314 std::optional<APInt> Y) {
10315 if (X && Y) {
10316 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10317 APInt XW = X->sext(W);
10318 APInt YW = Y->sext(W);
10319 return XW.slt(YW) ? *X : *Y;
10320 }
10321 if (!X && !Y)
10322 return std::nullopt;
10323 return X ? *X : *Y;
10324}
10325
10326/// Helper function to truncate an optional APInt to a given BitWidth.
10327/// When solving addrec-related equations, it is preferable to return a value
10328/// that has the same bit width as the original addrec's coefficients. If the
10329/// solution fits in the original bit width, truncate it (except for i1).
10330/// Returning a value of a different bit width may inhibit some optimizations.
10331///
10332/// In general, a solution to a quadratic equation generated from an addrec
10333/// may require BW+1 bits, where BW is the bit width of the addrec's
10334/// coefficients. The reason is that the coefficients of the quadratic
10335/// equation are BW+1 bits wide (to avoid truncation when converting from
10336/// the addrec to the equation).
10337static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10338 unsigned BitWidth) {
10339 if (!X)
10340 return std::nullopt;
10341 unsigned W = X->getBitWidth();
10342 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10343 return X->trunc(BitWidth);
10344 return X;
10345}
10346
10347/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10348/// iterations. The values L, M, N are assumed to be signed, and they
10349/// should all have the same bit widths.
10350/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10351/// where BW is the bit width of the addrec's coefficients.
10352/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10353/// returned as such, otherwise the bit width of the returned value may
10354/// be greater than BW.
10355///
10356/// This function returns std::nullopt if
10357/// (a) the addrec coefficients are not constant, or
10358/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10359/// like x^2 = 5, no integer solutions exist, in other cases an integer
10360/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10361static std::optional<APInt>
// NOTE(review): the embedded numbering skips 10362 and 10372, so the line
// carrying the function name and parameters, and the expression that
// initializes X, are not visible in this listing.
10363 APInt A, B, C, M;
10364 unsigned BitWidth;
10365 auto T = GetQuadraticEquation(AddRec);
10366 if (!T)
10367 return std::nullopt;
10368
10369 std::tie(A, B, C, M, BitWidth) = *T;
10370 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10371 std::optional<APInt> X =
10373 if (!X)
10374 return std::nullopt;
10375
       // Accept the candidate only if the addrec evaluates to exactly zero at X.
10376 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10377 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10378 if (!V->isZero())
10379 return std::nullopt;
10380
10381 return TruncIfPossible(X, BitWidth);
10382}
10383
10384/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10385/// iterations. The values M, N are assumed to be signed, and they
10386/// should all have the same bit widths.
10387/// Find the least n such that c(n) does not belong to the given range,
10388/// while c(n-1) does.
10389///
10390/// This function returns std::nullopt if
10391/// (a) the addrec coefficients are not constant, or
10392/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10393/// bounds of the range.
10394static std::optional<APInt>
// NOTE(review): the embedded numbering skips 10395, 10432 and 10437, so the
// line carrying the function name and first parameter, and the expressions
// that assign SO and initialize UO, are not visible in this listing.
10396 const ConstantRange &Range, ScalarEvolution &SE) {
10397 assert(AddRec->getOperand(0)->isZero() &&
10398 "Starting value of addrec should be 0");
10399 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10400 << Range << ", addrec " << *AddRec << '\n');
10401 // This case is handled in getNumIterationsInRange. Here we can assume that
10402 // we start in the range.
10403 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10404 "Addrec's initial value should be in range");
10405
10406 APInt A, B, C, M;
10407 unsigned BitWidth;
10408 auto T = GetQuadraticEquation(AddRec);
10409 if (!T)
10410 return std::nullopt;
10411
10412 // Be careful about the return value: there can be two reasons for not
10413 // returning an actual number. First, if no solutions to the equations
10414 // were found, and second, if the solutions don't leave the given range.
10415 // The first case means that the actual solution is "unknown", the second
10416 // means that it's known, but not valid. If the solution is unknown, we
10417 // cannot make any conclusions.
10418 // Return a pair: the optional solution and a flag indicating if the
10419 // solution was found.
       // The lambda captures by reference; A, B, M and BitWidth are assigned by
       // the std::tie below, before the lambda is first called.
10420 auto SolveForBoundary =
10421 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10422 // Solve for signed overflow and unsigned overflow, pick the lower
10423 // solution.
10424 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10425 << Bound << " (before multiplying by " << M << ")\n");
10426 Bound *= M; // The quadratic equation multiplier.
10427
10428 std::optional<APInt> SO;
10429 if (BitWidth > 1) {
10430 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10431 "signed overflow\n");
10433 }
10434 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10435 "unsigned overflow\n");
10436 std::optional<APInt> UO =
10438
       // True iff the addrec's value at X is outside Range while its value at
       // X-1 is inside.
10439 auto LeavesRange = [&] (const APInt &X) {
10440 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10441 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10442 if (Range.contains(V0->getValue()))
10443 return false;
10444 // X should be at least 1, so X-1 is non-negative.
10445 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10446 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10447 if (Range.contains(V1->getValue()))
10448 return true;
10449 return false;
10450 };
10451
10452 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10453 // can be a solution, but the function failed to find it. We cannot treat it
10454 // as "no solution".
10455 if (!SO || !UO)
10456 return {std::nullopt, false};
10457
10458 // Check the smaller value first to see if it leaves the range.
10459 // At this point, both SO and UO must have values.
10460 std::optional<APInt> Min = MinOptional(SO, UO);
10461 if (LeavesRange(*Min))
10462 return { Min, true };
10463 std::optional<APInt> Max = Min == SO ? UO : SO;
10464 if (LeavesRange(*Max))
10465 return { Max, true };
10466
10467 // Solutions were found, but were eliminated, hence the "true".
10468 return {std::nullopt, true};
10469 };
10470
10471 std::tie(A, B, C, M, BitWidth) = *T;
10472 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10473 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10474 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10475 auto SL = SolveForBoundary(Lower);
10476 auto SU = SolveForBoundary(Upper);
10477 // If any of the solutions was unknown, no meaningful conclusions can
10478 // be made.
10479 if (!SL.second || !SU.second)
10480 return std::nullopt;
10481
10482 // Claim: The correct solution is not some value between Min and Max.
10483 //
10484 // Justification: Assuming that Min and Max are different values, one of
10485 // them is when the first signed overflow happens, the other is when the
10486 // first unsigned overflow happens. Crossing the range boundary is only
10487 // possible via an overflow (treating 0 as a special case of it, modeling
10488 // an overflow as crossing k*2^W for some k).
10489 //
10490 // The interesting case here is when Min was eliminated as an invalid
10491 // solution, but Max was not. The argument is that if there was another
10492 // overflow between Min and Max, it would also have been eliminated if
10493 // it was considered.
10494 //
10495 // For a given boundary, it is possible to have two overflows of the same
10496 // type (signed/unsigned) without having the other type in between: this
10497 // can happen when the vertex of the parabola is between the iterations
10498 // corresponding to the overflows. This is only possible when the two
10499 // overflows cross k*2^W for the same k. In such case, if the second one
10500 // left the range (and was the first one to do so), the first overflow
10501 // would have to enter the range, which would mean that either we had left
10502 // the range before or that we started outside of it. Both of these cases
10503 // are contradictions.
10504 //
10505 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10506 // solution is not some value between the Max for this boundary and the
10507 // Min of the other boundary.
10508 //
10509 // Justification: Assume that we had such Max_A and Min_B corresponding
10510 // to range boundaries A and B and such that Max_A < Min_B. If there was
10511 // a solution between Max_A and Min_B, it would have to be caused by an
10512 // overflow corresponding to either A or B. It cannot correspond to B,
10513 // since Min_B is the first occurrence of such an overflow. If it
10514 // corresponded to A, it would have to be either a signed or an unsigned
10515 // overflow that is larger than both eliminated overflows for A. But
10516 // between the eliminated overflows and this overflow, the values would
10517 // cover the entire value space, thus crossing the other boundary, which
10518 // is a contradiction.
10519
10520 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10521}
10522
10523ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10524 const Loop *L,
10525 bool ControlsOnlyExit,
10526 bool AllowPredicates) {
10527
10528 // This is only used for loops with a "x != y" exit test. The exit condition
10529 // is now expressed as a single expression, V = x-y. So the exit test is
10530 // effectively V != 0. We know and take advantage of the fact that this
10531 // expression is only used in a comparison-with-zero context.
10532
10534 // If the value is a constant
10535 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10536 // If the value is already zero, the branch will execute zero times.
10537 if (C->getValue()->isZero()) return C;
10538 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10539 }
10540
10541 const SCEVAddRecExpr *AddRec =
10542 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10543
10544 if (!AddRec && AllowPredicates)
10545 // Try to make this an AddRec using runtime tests, in the first X
10546 // iterations of this loop, where X is the SCEV expression found by the
10547 // algorithm below.
10548 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10549
10550 if (!AddRec || AddRec->getLoop() != L)
10551 return getCouldNotCompute();
10552
10553 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10554 // the quadratic equation to solve it.
10555 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10556 // We can only use this value if the chrec ends up with an exact zero
10557 // value at this index. When solving for "X*X != 5", for example, we
10558 // should not accept a root of 2.
10559 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10560 const auto *R = cast<SCEVConstant>(getConstant(*S));
10561 return ExitLimit(R, R, R, false, Predicates);
10562 }
10563 return getCouldNotCompute();
10564 }
10565
10566 // Otherwise we can only handle this if it is affine.
10567 if (!AddRec->isAffine())
10568 return getCouldNotCompute();
10569
10570 // If this is an affine expression, the execution count of this branch is
10571 // the minimum unsigned root of the following equation:
10572 //
10573 // Start + Step*N = 0 (mod 2^BW)
10574 //
10575 // equivalent to:
10576 //
10577 // Step*N = -Start (mod 2^BW)
10578 //
10579 // where BW is the common bit width of Start and Step.
10580
10581 // Get the initial value for the loop.
10582 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10583 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10584
10585 if (!isLoopInvariant(Step, L))
10586 return getCouldNotCompute();
10587
10588 LoopGuards Guards = LoopGuards::collect(L, *this);
10589 // Specialize step for this loop so we get context sensitive facts below.
10590 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10591
10592 // For positive steps (counting up until unsigned overflow):
10593 // N = -Start/Step (as unsigned)
10594 // For negative steps (counting down to zero):
10595 // N = Start/-Step
10596 // First compute the unsigned distance from zero in the direction of Step.
10597 bool CountDown = isKnownNegative(StepWLG);
10598 if (!CountDown && !isKnownNonNegative(StepWLG))
10599 return getCouldNotCompute();
10600
10601 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10602 // Handle unitary steps, which cannot wraparound.
10603 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10604 // N = Distance (as unsigned)
10605
10606 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10607 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10608 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10609
10610 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10611 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10612 // case, and see if we can improve the bound.
10613 //
10614 // Explicitly handling this here is necessary because getUnsignedRange
10615 // isn't context-sensitive; it doesn't know that we only care about the
10616 // range inside the loop.
10617 const SCEV *Zero = getZero(Distance->getType());
10618 const SCEV *One = getOne(Distance->getType());
10619 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10620 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10621 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10622 // as "unsigned_max(Distance + 1) - 1".
10623 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10624 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10625 }
10626 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10627 Predicates);
10628 }
10629
10630 // If the condition controls loop exit (the loop exits only if the expression
10631 // is true) and the addition is no-wrap we can use unsigned divide to
10632 // compute the backedge count. In this case, the step may not divide the
10633 // distance, but we don't care because if the condition is "missed" the loop
10634 // will have undefined behavior due to wrapping.
10635 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10636 loopHasNoAbnormalExits(AddRec->getLoop())) {
10637
10638 // If the stride is zero, the loop must be infinite. In C++, most loops
10639 // are finite by assumption, in which case the step being zero implies
10640 // UB must execute if the loop is entered.
10641 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10642 return getCouldNotCompute();
10643
10644 const SCEV *Exact =
10645 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10646 const SCEV *ConstantMax = getCouldNotCompute();
10647 if (Exact != getCouldNotCompute()) {
10649 ConstantMax =
10651 }
10652 const SCEV *SymbolicMax =
10653 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10654 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10655 }
10656
10657 // Solve the general equation.
10658 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10659 if (!StepC || StepC->getValue()->isZero())
10660 return getCouldNotCompute();
10662 StepC->getAPInt(), getNegativeSCEV(Start),
10663 AllowPredicates ? &Predicates : nullptr, *this);
10664
10665 const SCEV *M = E;
10666 if (E != getCouldNotCompute()) {
10667 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10668 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10669 }
10670 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10671 return ExitLimit(E, M, S, false, Predicates);
10672}
10673
/// Compute how many iterations of loop \p L run before \p V becomes non-zero.
///
/// Only the constant case is handled: a non-zero constant yields a zero of
/// the constant's type; everything else yields getCouldNotCompute().
/// \p L is not consulted by the visible body.
///
/// NOTE(review): the return-type line (embedded line 10674) is absent from
/// this listing.
10675ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10676 // Loops that look like: while (X == 0) are very strange indeed. We don't
10677 // handle them yet except for the trivial case. This could be expanded in the
10678 // future as needed.
10679
10680 // If the value is a constant, check to see if it is known to be non-zero
10681 // already. If so, the backedge will execute zero times.
10682 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10683 if (!C->getValue()->isZero())
10684 return getZero(C->getType());
10685 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10686 }
10687
10688 // We could implement others, but I really doubt anyone writes loops like
10689 // this, and if they did, they would already be constant folded.
10690 return getCouldNotCompute();
10691}
10692
10693std::pair<const BasicBlock *, const BasicBlock *>
10694ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10695 const {
10696 // If the block has a unique predecessor, then there is no path from the
10697 // predecessor to the block that does not go through the direct edge
10698 // from the predecessor to the block.
10699 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10700 return {Pred, BB};
10701
10702 // A loop's header is defined to be a block that dominates the loop.
10703 // If the header has a unique predecessor outside the loop, it must be
10704 // a block that has exactly one successor that can reach the loop.
10705 if (const Loop *L = LI.getLoopFor(BB))
10706 return {L->getLoopPredecessor(), L->getHeader()};
10707
10708 return {nullptr, BB};
10709}
10710
10711/// SCEV structural equivalence is usually sufficient for testing whether two
10712/// expressions are equal, however for the purposes of looking for a condition
10713/// guarding a loop, it can be useful to be a little more general, since a
10714/// front-end may have replicated the controlling expression.
10715static bool HasSameValue(const SCEV *A, const SCEV *B) {
10716 // Quick check to see if they are the same SCEV.
10717 if (A == B) return true;
10718
10719 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10720 // Not all instructions that are "identical" compute the same value. For
10721 // instance, two distinct alloca instructions allocating the same type are
10722 // identical and do not read memory; but compute distinct values.
10723 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10724 };
10725
10726 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10727 // two different instructions with the same value. Check for this case.
10728 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10729 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10730 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10731 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10732 if (ComputesEqualValues(AI, BI))
10733 return true;
10734
10735 // Otherwise assume they may have a different value.
10736 return false;
10737}
10738
10739static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10740 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10741 if (!Add || Add->getNumOperands() != 2)
10742 return false;
10743 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10744 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10745 LHS = Add->getOperand(1);
10746 RHS = ME->getOperand(1);
10747 return true;
10748 }
10749 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10750 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10751 LHS = Add->getOperand(0);
10752 RHS = ME->getOperand(1);
10753 return true;
10754 }
10755 return false;
10756}
10757
// Canonicalize the comparison "LHS Pred RHS" in place: move constants to the
// right, move an addrec to the left when the other side is invariant in (and
// properly dominates) its loop, fold constant/boundary cases to a trivial
// form, and turn *-or-equal predicates into strict ones when ranges allow.
// Recursion is capped (Depth >= 3 returns false without changes).
//
// NOTE(review): the first signature line (embedded line 10758) and embedded
// lines 10764, 10782, 10793, 10806, 10871, 10881, 10886, 10894, 10899, 10907
// and 10923 are absent from this listing, so several statements below appear
// truncated (e.g. `ExactCR` has no visible declaration, and some getAddExpr
// calls end in a trailing comma).
//
// NOTE(review): when this level made a change, the function returns the
// result of the recursive call rather than `true`, so the returned flag
// reflects only what the deepest invocation reports.
10759 const SCEV *&RHS, unsigned Depth) {
10760 bool Changed = false;
10761 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10762 // '0 != 0'.
10763 auto TrivialCase = [&](bool TriviallyTrue) {
10765 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10766 return true;
10767 };
10768 // If we hit the max recursion limit bail out.
10769 if (Depth >= 3)
10770 return false;
10771
10772 // Canonicalize a constant to the right side.
10773 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10774 // Check for both operands constant.
10775 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10776 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10777 return TrivialCase(false);
10778 return TrivialCase(true);
10779 }
10780 // Otherwise swap the operands to put the constant on the right.
10781 std::swap(LHS, RHS);
10783 Changed = true;
10784 }
10785
10786 // If we're comparing an addrec with a value which is loop-invariant in the
10787 // addrec's loop, put the addrec on the left. Also make a dominance check,
10788 // as both operands could be addrecs loop-invariant in each other's loop.
10789 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10790 const Loop *L = AR->getLoop();
10791 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10792 std::swap(LHS, RHS);
10794 Changed = true;
10795 }
10796 }
10797
10798 // If there's a constant operand, canonicalize comparisons with boundary
10799 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10800 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10801 const APInt &RA = RC->getAPInt();
10802
10803 bool SimplifiedByConstantRange = false;
10804
10805 if (!ICmpInst::isEquality(Pred)) {
10807 if (ExactCR.isFullSet())
10808 return TrivialCase(true);
10809 if (ExactCR.isEmptySet())
10810 return TrivialCase(false);
10811
10812 APInt NewRHS;
10813 CmpInst::Predicate NewPred;
10814 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10815 ICmpInst::isEquality(NewPred)) {
10816 // We were able to convert an inequality to an equality.
10817 Pred = NewPred;
10818 RHS = getConstant(NewRHS);
10819 Changed = SimplifiedByConstantRange = true;
10820 }
10821 }
10822
10823 if (!SimplifiedByConstantRange) {
10824 switch (Pred) {
10825 default:
10826 break;
10827 case ICmpInst::ICMP_EQ:
10828 case ICmpInst::ICMP_NE:
10829 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10830 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10831 Changed = true;
10832 break;
10833
10834 // The "Should have been caught earlier!" messages refer to the fact
10835 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10836 // should have fired on the corresponding cases, and canonicalized the
10837 // check to trivial case.
10838
10839 case ICmpInst::ICMP_UGE:
10840 assert(!RA.isMinValue() && "Should have been caught earlier!");
10841 Pred = ICmpInst::ICMP_UGT;
10842 RHS = getConstant(RA - 1);
10843 Changed = true;
10844 break;
10845 case ICmpInst::ICMP_ULE:
10846 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10847 Pred = ICmpInst::ICMP_ULT;
10848 RHS = getConstant(RA + 1);
10849 Changed = true;
10850 break;
10851 case ICmpInst::ICMP_SGE:
10852 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10853 Pred = ICmpInst::ICMP_SGT;
10854 RHS = getConstant(RA - 1);
10855 Changed = true;
10856 break;
10857 case ICmpInst::ICMP_SLE:
10858 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10859 Pred = ICmpInst::ICMP_SLT;
10860 RHS = getConstant(RA + 1);
10861 Changed = true;
10862 break;
10863 }
10864 }
10865 }
10866
10867 // Check for obvious equality.
10868 if (HasSameValue(LHS, RHS)) {
10869 if (ICmpInst::isTrueWhenEqual(Pred))
10870 return TrivialCase(true);
10872 return TrivialCase(false);
10873 }
10874
10875 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10876 // adding or subtracting 1 from one of the operands.
// Each case first tries to adjust RHS, and only if RHS's range touches the
// relevant extreme does it try to adjust LHS in the opposite direction.
10877 switch (Pred) {
10878 case ICmpInst::ICMP_SLE:
10879 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10880 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10882 Pred = ICmpInst::ICMP_SLT;
10883 Changed = true;
10884 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10885 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10887 Pred = ICmpInst::ICMP_SLT;
10888 Changed = true;
10889 }
10890 break;
10891 case ICmpInst::ICMP_SGE:
10892 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10893 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10895 Pred = ICmpInst::ICMP_SGT;
10896 Changed = true;
10897 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10898 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10900 Pred = ICmpInst::ICMP_SGT;
10901 Changed = true;
10902 }
10903 break;
10904 case ICmpInst::ICMP_ULE:
10905 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10906 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10908 Pred = ICmpInst::ICMP_ULT;
10909 Changed = true;
10910 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10911 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10912 Pred = ICmpInst::ICMP_ULT;
10913 Changed = true;
10914 }
10915 break;
10916 case ICmpInst::ICMP_UGE:
10917 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10918 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10919 Pred = ICmpInst::ICMP_UGT;
10920 Changed = true;
10921 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10922 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10924 Pred = ICmpInst::ICMP_UGT;
10925 Changed = true;
10926 }
10927 break;
10928 default:
10929 break;
10930 }
10931
10932 // TODO: More simplifications are possible here.
10933
10934 // Recursively simplify until we either hit a recursion limit or nothing
10935 // changes.
10936 if (Changed)
10937 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10938
10939 return Changed;
10940}
10941
// Remnants of four small sign-query definitions. Their signature lines
// (embedded lines 10942, 10946, 10950, 10954) and two bodies (10947, 10955)
// are absent from this listing.
// NOTE(review): presumably these are ScalarEvolution::isKnownNegative,
// isKnownPositive, isKnownNonNegative and isKnownNonPositive, in that order --
// confirm against the original file.

// True when the maximum of S's signed range is negative.
10943 return getSignedRangeMax(S).isNegative();
10944}
10945
// NOTE(review): body not visible in this listing.
10948}
10949
// True when the minimum of S's signed range is not negative.
10951 return !getSignedRangeMin(S).isNegative();
10952}
10953
// NOTE(review): body not visible in this listing.
10956}
10957
// Body of a non-zero query on the SCEV `S`; the signature line (embedded
// line 10958) is absent from this listing. The recursive call suggests the
// enclosing function is isKnownNonZero -- confirm.
// A sign-extension is answered by asking about its operand; anything else is
// non-zero when the minimum of its unsigned range is not 0.
10959 // Query push down for cases where the unsigned range is
10960 // less than sufficient.
10961 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10962 return isKnownNonZero(SExt->getOperand(0));
10963 return getUnsignedRangeMin(S) != 0;
10964}
10965
// Power-of-two query on the SCEV `S`.
// NOTE(review): the first signature line (embedded line 10966) is absent; `S`
// and `OrZero`, used below, are presumably parameters declared there, and the
// function is presumably ScalarEvolution::isKnownToBeAPowerOfTwo -- confirm.
//
// Returns true when S itself passes the leaf test below, or when S is a
// multiply whose operands all pass it and either OrZero is set or S is known
// non-zero.
10967 bool OrNegative) {
// Leaf test: a constant that is a power of two (or, when OrNegative is set,
// a negated power of two), or a vscale expression in a function that carries
// the vscale_range attribute.
10968 auto NonRecursive = [this, OrNegative](const SCEV *S) {
10969 if (auto *C = dyn_cast<SCEVConstant>(S))
10970 return C->getAPInt().isPowerOf2() ||
10971 (OrNegative && C->getAPInt().isNegatedPowerOf2());
10972
10973 // The vscale_range indicates vscale is a power-of-two.
10974 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
10975 };
10976
10977 if (NonRecursive(S))
10978 return true;
10979
10980 auto *Mul = dyn_cast<SCEVMulExpr>(S);
10981 if (!Mul)
10982 return false;
10983 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
10984}
10985
// Split the SCEV `S` into its value on entry to loop `L` and its
// post-increment value for `L`, returned as {Start, PostInc}. If the entry
// value cannot be computed, both elements are getCouldNotCompute().
// NOTE(review): the line naming the function and its parameters (embedded
// line 10987) is absent; callers below refer to it as SplitIntoInitAndPostInc.
10986std::pair<const SCEV *, const SCEV *>
10988 // Compute SCEV on entry of loop L.
10989 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10990 if (Start == getCouldNotCompute())
10991 return { Start, Start };
10992 // Compute post increment SCEV for loop L.
10993 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10994 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10995 return { Start, PostInc };
10996}
10997
// Try to prove "LHS Pred RHS" by induction over the most-dominated loop used
// by either operand: the predicate must hold for the values on loop entry
// and, via the backedge guard, for the post-increment values.
// Returns false when no loops are used, when either operand cannot be split,
// or when an init value is not available at the loop entry.
// NOTE(review): the first signature line (embedded line 10998) and the
// declaration of `LoopsUsed` (embedded line 11001) are absent from this
// listing; the function is presumably ScalarEvolution::isKnownViaInduction.
10999 const SCEV *RHS) {
11000 // First collect all loops.
11002 getUsedLoops(LHS, LoopsUsed);
11003 getUsedLoops(RHS, LoopsUsed);
11004
11005 if (LoopsUsed.empty())
11006 return false;
11007
11008 // Domination relationship must be a linear order on collected loops.
11009#ifndef NDEBUG
11010 for (const auto *L1 : LoopsUsed)
11011 for (const auto *L2 : LoopsUsed)
11012 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11013 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11014 "Domination relationship is not a linear order");
11015#endif
11016
// MDL is the loop whose header is properly dominated by the headers of all
// the other collected loops (the maximum under "properly dominates").
11017 const Loop *MDL =
11018 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11019 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11020 });
11021
11022 // Get init and post increment value for LHS.
11023 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11024 // if LHS contains unknown non-invariant SCEV then bail out.
11025 if (SplitLHS.first == getCouldNotCompute())
11026 return false;
11027 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11028 // Get init and post increment value for RHS.
11029 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11030 // if RHS contains unknown non-invariant SCEV then bail out.
11031 if (SplitRHS.first == getCouldNotCompute())
11032 return false;
11033 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11034 // It is possible that init SCEV contains an invariant load but it does
11035 // not dominate MDL and is not available at MDL loop entry, so we should
11036 // check it here.
11037 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11038 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11039 return false;
11040
11041 // It seems backedge guard check is faster than entry one so in some cases
11042 // it can speed up whole estimation by short circuit
11043 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11044 SplitRHS.second) &&
11045 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11046}
11047
// Return true if "LHS Pred RHS" can be proven. The operands are canonicalized
// first (the simplifier's return value is deliberately ignored), then three
// strategies are tried in order: induction, splitting, and non-recursive
// reasoning.
// NOTE(review): the first signature line (embedded line 11048) is absent;
// callers below refer to this function as isKnownPredicate.
11049 const SCEV *RHS) {
11050 // Canonicalize the inputs first.
11051 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11052
11053 if (isKnownViaInduction(Pred, LHS, RHS))
11054 return true;
11055
11056 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11057 return true;
11058
11059 // Otherwise see what can be done with some simple reasoning.
11060 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11061}
11062
// Three-way evaluation of "LHS Pred RHS": true when the predicate is known,
// std::nullopt when nothing can be established.
// NOTE(review): the first signature line (embedded line 11063) and the
// condition guarding `return false` (embedded line 11068) are absent from
// this listing; callers below refer to this function as evaluatePredicate.
11064 const SCEV *LHS,
11065 const SCEV *RHS) {
11066 if (isKnownPredicate(Pred, LHS, RHS))
11067 return true;
11069 return false;
11070 return std::nullopt;
11071}
11072
// Context-sensitive variant of isKnownPredicate taking an instruction CtxI.
// NOTE(review): the first signature line (embedded line 11073) and the
// right-hand operand of `||` (embedded line 11078) are absent from this
// listing; callers below refer to this function as isKnownPredicateAt.
11074 const SCEV *RHS,
11075 const Instruction *CtxI) {
11076 // TODO: Analyze guards and assumes from Context's block.
11077 return isKnownPredicate(Pred, LHS, RHS) ||
11079}
11080
// Three-way evaluation of "LHS Pred RHS" at instruction CtxI: first the
// context-free evaluatePredicate, then a check of whether entry to CtxI's
// basic block is guarded by the condition. std::nullopt means unknown.
// NOTE(review): the line naming the function (embedded line 11082) and the
// condition guarding `return false` (embedded lines 11090-11091) are absent
// from this listing.
11081std::optional<bool>
11083 const SCEV *RHS, const Instruction *CtxI) {
11084 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11085 if (KnownWithoutContext)
11086 return KnownWithoutContext;
11087
11088 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11089 return true;
11092 return false;
11093 return std::nullopt;
11094}
11095
// Return true if "LHS Pred RHS" holds on every iteration of LHS's loop: the
// start value must satisfy it under the loop-entry guards and the
// post-increment value must satisfy it under the backedge guards.
// NOTE(review): the first signature line (embedded line 11096) is absent;
// presumably ScalarEvolution::isKnownOnEveryIteration -- confirm.
11097 const SCEVAddRecExpr *LHS,
11098 const SCEV *RHS) {
11099 const Loop *L = LHS->getLoop();
11100 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11101 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11102}
11103
// Wrapper around getMonotonicPredicateTypeImpl that, in builds with
// assertions enabled, also checks that the swapped predicate yields the
// opposite monotonicity.
// NOTE(review): the line naming the function and its first parameter
// (embedded line 11105) is absent from this listing.
// NOTE(review): the debug check dereferences ResultSwapped without testing
// that it holds a value.
11104std::optional<ScalarEvolution::MonotonicPredicateType>
11106 ICmpInst::Predicate Pred) {
11107 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11108
11109#ifndef NDEBUG
11110 // Verify an invariant: swapping the predicate should turn a monotonically
11111 // increasing change to a monotonically decreasing one, and vice versa.
11112 if (Result) {
11113 auto ResultSwapped =
11114 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11115
11116 assert(*ResultSwapped != *Result &&
11117 "monotonicity should flip as we flip the predicate");
11118 }
11119#endif
11120
11121 return Result;
11122}
11123
// Classify how the predicate "LHS Pred <invariant>" evolves as the add
// recurrence LHS iterates. Returns std::nullopt for non-relational
// predicates, when LHS lacks the no-wrap flag matching Pred's signedness, or
// when the sign of the step is unknown.
// NOTE(review): embedded lines 11149, 11159 and 11162 are absent from this
// listing. As shown, the unsigned branch has no visible `else`/closing
// structure and the two step-sign checks have no visible return statements;
// the values returned for each step sign cannot be read from here.
11124std::optional<ScalarEvolution::MonotonicPredicateType>
11125ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11126 ICmpInst::Predicate Pred) {
11127 // A zero step value for LHS means the induction variable is essentially a
11128 // loop invariant value. We don't really depend on the predicate actually
11129 // flipping from false to true (for increasing predicates, and the other way
11130 // around for decreasing predicates), all we care about is that *if* the
11131 // predicate changes then it only changes from false to true.
11132 //
11133 // A zero step value in itself is not very useful, but there may be places
11134 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11135 // as general as possible.
11136
11137 // Only handle LE/LT/GE/GT predicates.
11138 if (!ICmpInst::isRelational(Pred))
11139 return std::nullopt;
11140
11141 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11142 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11143 "Should be greater or less!");
11144
11145 // Check that AR does not wrap.
11146 if (ICmpInst::isUnsigned(Pred)) {
11147 if (!LHS->hasNoUnsignedWrap())
11148 return std::nullopt;
11150 }
11151 assert(ICmpInst::isSigned(Pred) &&
11152 "Relational predicate is either signed or unsigned!");
11153 if (!LHS->hasNoSignedWrap())
11154 return std::nullopt;
11155
11156 const SCEV *Step = LHS->getStepRecurrence(*this);
11157
11158 if (isKnownNonNegative(Step))
11160
11161 if (isKnownNonPositive(Step))
11163
11164 return std::nullopt;
11165}
11166
// Try to replace the loop-varying comparison "LHS Pred RHS" in loop L with an
// equivalent loop-invariant one. After forcing the invariant operand into
// RHS, LHS must be an add recurrence on L. Returns std::nullopt when no
// invariant form is found; CtxI (may be null) enables a context-based proof
// for ULE/ULT.
// NOTE(review): embedded lines 11168 (function name/first parameters), 11177,
// 11204 (declaration of `Increasing`), 11207-11208, 11242 and 11244 are
// absent from this listing, so the predicate adjustment after the swap, the
// monotonic-case return and parts of the ULE/ULT condition are not visible.
11167std::optional<ScalarEvolution::LoopInvariantPredicate>
11169 const SCEV *RHS, const Loop *L,
11170 const Instruction *CtxI) {
11171 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11172 if (!isLoopInvariant(RHS, L)) {
11173 if (!isLoopInvariant(LHS, L))
11174 return std::nullopt;
11175
11176 std::swap(LHS, RHS);
11178 }
11179
11180 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11181 if (!ArLHS || ArLHS->getLoop() != L)
11182 return std::nullopt;
11183
11184 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11185 if (!MonotonicType)
11186 return std::nullopt;
11187 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11188 // true as the loop iterates, and the backedge is control dependent on
11189 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11190 //
11191 // * if the predicate was false in the first iteration then the predicate
11192 // is never evaluated again, since the loop exits without taking the
11193 // backedge.
11194 // * if the predicate was true in the first iteration then it will
11195 // continue to be true for all future iterations since it is
11196 // monotonically increasing.
11197 //
11198 // For both the above possibilities, we can replace the loop varying
11199 // predicate with its value on the first iteration of the loop (which is
11200 // loop invariant).
11201 //
11202 // A similar reasoning applies for a monotonically decreasing predicate, by
11203 // replacing true with false and false with true in the above two bullets.
// For a decreasing predicate the inverse predicate is used instead.
11205 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11206
11209 RHS);
11210
11211 if (!CtxI)
11212 return std::nullopt;
11213 // Try to prove via context.
11214 // TODO: Support other cases.
11215 switch (Pred) {
11216 default:
11217 break;
11218 case ICmpInst::ICMP_ULE:
11219 case ICmpInst::ICMP_ULT: {
11220 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11221 // Given preconditions
11222 // (1) ArLHS does not cross the border of positive and negative parts of
11223 // range because of:
11224 // - Positive step; (TODO: lift this limitation)
11225 // - nuw - does not cross zero boundary;
11226 // - nsw - does not cross SINT_MAX boundary;
11227 // (2) ArLHS <s RHS
11228 // (3) RHS >=s 0
11229 // we can replace the loop variant ArLHS <u RHS condition with loop
11230 // invariant Start(ArLHS) <u RHS.
11231 //
11232 // Because of (1) there are two options:
11233 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11234 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11235 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11236 // Because of (2) ArLHS <u RHS is trivially true.
11237 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11238 // We can strengthen this to Start(ArLHS) <u RHS.
11239 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11240 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11241 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11243 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11245 RHS);
11246 }
11247 }
11248
11249 return std::nullopt;
11250}
11251
// Look for a loop-invariant form of "LHS Pred RHS" that is valid during the
// first MaxIter iterations of loop L. Tries MaxIter directly first; if that
// fails and MaxIter is a umin expression, tries each of its operands in turn.
// Returns std::nullopt when nothing works.
// NOTE(review): embedded lines 11253 (function name), 11256 and 11266 (the
// heads of the two `if (auto LIP = ...` calls) are absent from this listing;
// the callee is presumably the ...Impl function defined next -- confirm.
11252std::optional<ScalarEvolution::LoopInvariantPredicate>
11254 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11255 const Instruction *CtxI, const SCEV *MaxIter) {
11257 Pred, LHS, RHS, L, CtxI, MaxIter))
11258 return LIP;
11259 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11260 // Number of iterations expressed as UMIN isn't always great for expressing
11261 // the value on the last iteration. If the straightforward approach didn't
11262 // work, try the following trick: if a predicate is invariant for X, it
11263 // is also invariant for umin(X, ...). So try to find something that works
11264 // among subexpressions of MaxIter expressed as umin.
11265 for (auto *Op : UMin->operands())
11267 Pred, LHS, RHS, L, CtxI, Op))
11268 return LIP;
11269 return std::nullopt;
11270}
11271
// Single-attempt worker: try to prove that "LHS Pred RHS" can be replaced by
// the loop-invariant "Start(LHS) Pred RHS" for the first MaxIter iterations
// of loop L. Requires a relational predicate, an add recurrence on L with
// step +1 or -1, and MaxIter of the same type as the recurrence. Returns
// std::nullopt whenever any requirement or proof step fails.
// NOTE(review): embedded lines 11273 (function name), 11290 (the statement
// after the operand swap) and 11325 (the initial value of NoOverflowPred) are
// absent from this listing.
11272std::optional<ScalarEvolution::LoopInvariantPredicate>
11274 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11275 const Instruction *CtxI, const SCEV *MaxIter) {
11276 // Try to prove the following set of facts:
11277 // - The predicate is monotonic in the iteration space.
11278 // - If the check does not fail on the 1st iteration:
11279 // - No overflow will happen during first MaxIter iterations;
11280 // - It will not fail on the MaxIter'th iteration.
11281 // If the check does fail on the 1st iteration, we leave the loop and no
11282 // other checks matter.
11283
11284 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11285 if (!isLoopInvariant(RHS, L)) {
11286 if (!isLoopInvariant(LHS, L))
11287 return std::nullopt;
11288
11289 std::swap(LHS, RHS);
11291 }
11292
11293 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11294 if (!AR || AR->getLoop() != L)
11295 return std::nullopt;
11296
11297 // The predicate must be relational (i.e. <, <=, >=, >).
11298 if (!ICmpInst::isRelational(Pred))
11299 return std::nullopt;
11300
11301 // TODO: Support steps other than +/- 1.
11302 const SCEV *Step = AR->getStepRecurrence(*this);
11303 auto *One = getOne(Step->getType());
11304 auto *MinusOne = getNegativeSCEV(One);
11305 if (Step != One && Step != MinusOne)
11306 return std::nullopt;
11307
11308 // Type mismatch here means that MaxIter is potentially larger than max
11309 // unsigned value in start type, which mean we cannot prove no wrap for the
11310 // indvar.
11311 if (AR->getType() != MaxIter->getType())
11312 return std::nullopt;
11313
11314 // Value of IV on suggested last iteration.
11315 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11316 // Does it still meet the requirement?
11317 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11318 return std::nullopt;
11319 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11320 // not exceed max unsigned value of this type), this effectively proves
11321 // that there is no wrap during the iteration. To prove that there is no
11322 // signed/unsigned wrap, we need to check that
11323 // Start <= Last for step = 1 or Start >= Last for step = -1.
11324 ICmpInst::Predicate NoOverflowPred =
11326 if (Step == MinusOne)
11327 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11328 const SCEV *Start = AR->getStart();
11329 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11330 return std::nullopt;
11331
11332 // Everything is fine.
11333 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11334}
11335
11336bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11337 const SCEV *LHS,
11338 const SCEV *RHS) {
11339 if (HasSameValue(LHS, RHS))
11340 return ICmpInst::isTrueWhenEqual(Pred);
11341
11342 // This code is split out from isKnownPredicate because it is called from
11343 // within isLoopEntryGuardedByCond.
11344
11345 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11346 const ConstantRange &RangeRHS) {
11347 return RangeLHS.icmp(Pred, RangeRHS);
11348 };
11349
11350 // The check at the top of the function catches the case where the values are
11351 // known to be equal.
11352 if (Pred == CmpInst::ICMP_EQ)
11353 return false;
11354
11355 if (Pred == CmpInst::ICMP_NE) {
11356 auto SL = getSignedRange(LHS);
11357 auto SR = getSignedRange(RHS);
11358 if (CheckRanges(SL, SR))
11359 return true;
11360 auto UL = getUnsignedRange(LHS);
11361 auto UR = getUnsignedRange(RHS);
11362 if (CheckRanges(UL, UR))
11363 return true;
11364 auto *Diff = getMinusSCEV(LHS, RHS);
11365 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11366 }
11367
11368 if (CmpInst::isSigned(Pred)) {
11369 auto SL = getSignedRange(LHS);
11370 auto SR = getSignedRange(RHS);
11371 return CheckRanges(SL, SR);
11372 }
11373
11374 auto UL = getUnsignedRange(LHS);
11375 auto UR = getUnsignedRange(RHS);
11376 return CheckRanges(UL, UR);
11377}
11378
11379bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11380 const SCEV *LHS,
11381 const SCEV *RHS) {
11382 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11383 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11384 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11385 // OutC1 and OutC2.
11386 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11387 APInt &OutC1, APInt &OutC2,
11388 SCEV::NoWrapFlags ExpectedFlags) {
11389 const SCEV *XNonConstOp, *XConstOp;
11390 const SCEV *YNonConstOp, *YConstOp;
11391 SCEV::NoWrapFlags XFlagsPresent;
11392 SCEV::NoWrapFlags YFlagsPresent;
11393
11394 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11395 XConstOp = getZero(X->getType());
11396 XNonConstOp = X;
11397 XFlagsPresent = ExpectedFlags;
11398 }
11399 if (!isa<SCEVConstant>(XConstOp) ||
11400 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11401 return false;
11402
11403 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11404 YConstOp = getZero(Y->getType());
11405 YNonConstOp = Y;
11406 YFlagsPresent = ExpectedFlags;
11407 }
11408
11409 if (!isa<SCEVConstant>(YConstOp) ||
11410 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11411 return false;
11412
11413 if (YNonConstOp != XNonConstOp)
11414 return false;
11415
11416 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11417 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11418
11419 return true;
11420 };
11421
11422 APInt C1;
11423 APInt C2;
11424
11425 switch (Pred) {
11426 default:
11427 break;
11428
11429 case ICmpInst::ICMP_SGE:
11430 std::swap(LHS, RHS);
11431 [[fallthrough]];
11432 case ICmpInst::ICMP_SLE:
11433 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11434 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11435 return true;
11436
11437 break;
11438
11439 case ICmpInst::ICMP_SGT:
11440 std::swap(LHS, RHS);
11441 [[fallthrough]];
11442 case ICmpInst::ICMP_SLT:
11443 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11444 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11445 return true;
11446
11447 break;
11448
11449 case ICmpInst::ICMP_UGE:
11450 std::swap(LHS, RHS);
11451 [[fallthrough]];
11452 case ICmpInst::ICMP_ULE:
11453 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11454 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11455 return true;
11456
11457 break;
11458
11459 case ICmpInst::ICMP_UGT:
11460 std::swap(LHS, RHS);
11461 [[fallthrough]];
11462 case ICmpInst::ICMP_ULT:
11463 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11464 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11465 return true;
11466 break;
11467 }
11468
11469 return false;
11470}
11471
11472bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11473 const SCEV *LHS,
11474 const SCEV *RHS) {
11475 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11476 return false;
11477
11478 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11479 // the stack can result in exponential time complexity.
11480 SaveAndRestore Restore(ProvingSplitPredicate, true);
11481
11482 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11483 //
11484 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11485 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11486 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11487 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11488 // use isKnownPredicate later if needed.
11489 return isKnownNonNegative(RHS) &&
11492}
11493
11494bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11495 const SCEV *LHS, const SCEV *RHS) {
11496 // No need to even try if we know the module has no guards.
11497 if (!HasGuards)
11498 return false;
11499
11500 return any_of(*BB, [&](const Instruction &I) {
11501 using namespace llvm::PatternMatch;
11502
11503 Value *Condition;
11504 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11505 m_Value(Condition))) &&
11506 isImpliedCond(Pred, LHS, RHS, Condition, false);
11507 });
11508}
11509
11510/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11511/// protected by a conditional between LHS and RHS. This is used to
11512/// to eliminate casts.
11514 CmpPredicate Pred,
11515 const SCEV *LHS,
11516 const SCEV *RHS) {
11517 // Interpret a null as meaning no loop, where there is obviously no guard
11518 // (interprocedural conditions notwithstanding). Do not bother about
11519 // unreachable loops.
11520 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11521 return true;
11522
11523 if (VerifyIR)
11524 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11525 "This cannot be done on broken IR!");
11526
11527
11528 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11529 return true;
11530
11531 BasicBlock *Latch = L->getLoopLatch();
11532 if (!Latch)
11533 return false;
11534
11535 BranchInst *LoopContinuePredicate =
11536 dyn_cast<BranchInst>(Latch->getTerminator());
11537 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11538 isImpliedCond(Pred, LHS, RHS,
11539 LoopContinuePredicate->getCondition(),
11540 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11541 return true;
11542
11543 // We don't want more than one activation of the following loops on the stack
11544 // -- that can lead to O(n!) time complexity.
11545 if (WalkingBEDominatingConds)
11546 return false;
11547
11548 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11549
11550 // See if we can exploit a trip count to prove the predicate.
11551 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11552 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11553 if (LatchBECount != getCouldNotCompute()) {
11554 // We know that Latch branches back to the loop header exactly
11555 // LatchBECount times. This means the backdege condition at Latch is
11556 // equivalent to "{0,+,1} u< LatchBECount".
11557 Type *Ty = LatchBECount->getType();
11558 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11559 const SCEV *LoopCounter =
11560 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11561 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11562 LatchBECount))
11563 return true;
11564 }
11565
11566 // Check conditions due to any @llvm.assume intrinsics.
11567 for (auto &AssumeVH : AC.assumptions()) {
11568 if (!AssumeVH)
11569 continue;
11570 auto *CI = cast<CallInst>(AssumeVH);
11571 if (!DT.dominates(CI, Latch->getTerminator()))
11572 continue;
11573
11574 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11575 return true;
11576 }
11577
11578 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11579 return true;
11580
11581 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11582 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11583 assert(DTN && "should reach the loop header before reaching the root!");
11584
11585 BasicBlock *BB = DTN->getBlock();
11586 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11587 return true;
11588
11589 BasicBlock *PBB = BB->getSinglePredecessor();
11590 if (!PBB)
11591 continue;
11592
11593 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11594 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11595 continue;
11596
11597 Value *Condition = ContinuePredicate->getCondition();
11598
11599 // If we have an edge `E` within the loop body that dominates the only
11600 // latch, the condition guarding `E` also guards the backedge. This
11601 // reasoning works only for loops with a single latch.
11602
11603 BasicBlockEdge DominatingEdge(PBB, BB);
11604 if (DominatingEdge.isSingleEdge()) {
11605 // We're constructively (and conservatively) enumerating edges within the
11606 // loop body that dominate the latch. The dominator tree better agree
11607 // with us on this:
11608 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11609
11610 if (isImpliedCond(Pred, LHS, RHS, Condition,
11611 BB != ContinuePredicate->getSuccessor(0)))
11612 return true;
11613 }
11614 }
11615
11616 return false;
11617}
11618
11620 CmpPredicate Pred,
11621 const SCEV *LHS,
11622 const SCEV *RHS) {
11623 // Do not bother proving facts for unreachable code.
11624 if (!DT.isReachableFromEntry(BB))
11625 return true;
11626 if (VerifyIR)
11627 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11628 "This cannot be done on broken IR!");
11629
11630 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11631 // the facts (a >= b && a != b) separately. A typical situation is when the
11632 // non-strict comparison is known from ranges and non-equality is known from
11633 // dominating predicates. If we are proving strict comparison, we always try
11634 // to prove non-equality and non-strict comparison separately.
11635 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11636 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11637 bool ProvedNonStrictComparison = false;
11638 bool ProvedNonEquality = false;
11639
11640 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11641 if (!ProvedNonStrictComparison)
11642 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11643 if (!ProvedNonEquality)
11644 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11645 if (ProvedNonStrictComparison && ProvedNonEquality)
11646 return true;
11647 return false;
11648 };
11649
11650 if (ProvingStrictComparison) {
11651 auto ProofFn = [&](CmpPredicate P) {
11652 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11653 };
11654 if (SplitAndProve(ProofFn))
11655 return true;
11656 }
11657
11658 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11659 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11660 const Instruction *CtxI = &BB->front();
11661 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11662 return true;
11663 if (ProvingStrictComparison) {
11664 auto ProofFn = [&](CmpPredicate P) {
11665 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11666 };
11667 if (SplitAndProve(ProofFn))
11668 return true;
11669 }
11670 return false;
11671 };
11672
11673 // Starting at the block's predecessor, climb up the predecessor chain, as long
11674 // as there are predecessors that can be found that have unique successors
11675 // leading to the original block.
11676 const Loop *ContainingLoop = LI.getLoopFor(BB);
11677 const BasicBlock *PredBB;
11678 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11679 PredBB = ContainingLoop->getLoopPredecessor();
11680 else
11681 PredBB = BB->getSinglePredecessor();
11682 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11683 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11684 const BranchInst *BlockEntryPredicate =
11685 dyn_cast<BranchInst>(Pair.first->getTerminator());
11686 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11687 continue;
11688
11689 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11690 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11691 return true;
11692 }
11693
11694 // Check conditions due to any @llvm.assume intrinsics.
11695 for (auto &AssumeVH : AC.assumptions()) {
11696 if (!AssumeVH)
11697 continue;
11698 auto *CI = cast<CallInst>(AssumeVH);
11699 if (!DT.dominates(CI, BB))
11700 continue;
11701
11702 if (ProveViaCond(CI->getArgOperand(0), false))
11703 return true;
11704 }
11705
11706 // Check conditions due to any @llvm.experimental.guard intrinsics.
11707 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11708 F.getParent(), Intrinsic::experimental_guard);
11709 if (GuardDecl)
11710 for (const auto *GU : GuardDecl->users())
11711 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11712 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11713 if (ProveViaCond(Guard->getArgOperand(0), false))
11714 return true;
11715 return false;
11716}
11717
11719 const SCEV *LHS,
11720 const SCEV *RHS) {
11721 // Interpret a null as meaning no loop, where there is obviously no guard
11722 // (interprocedural conditions notwithstanding).
11723 if (!L)
11724 return false;
11725
11726 // Both LHS and RHS must be available at loop entry.
11728 "LHS is not available at Loop Entry");
11730 "RHS is not available at Loop Entry");
11731
11732 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11733 return true;
11734
11735 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11736}
11737
11738bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11739 const SCEV *RHS,
11740 const Value *FoundCondValue, bool Inverse,
11741 const Instruction *CtxI) {
11742 // False conditions implies anything. Do not bother analyzing it further.
11743 if (FoundCondValue ==
11744 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11745 return true;
11746
11747 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11748 return false;
11749
11750 auto ClearOnExit =
11751 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11752
11753 // Recursively handle And and Or conditions.
11754 const Value *Op0, *Op1;
11755 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11756 if (!Inverse)
11757 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11758 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11759 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11760 if (Inverse)
11761 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11762 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11763 }
11764
11765 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11766 if (!ICI) return false;
11767
11768 // Now that we found a conditional branch that dominates the loop or controls
11769 // the loop latch. Check to see if it is the comparison we are looking for.
11770 CmpPredicate FoundPred;
11771 if (Inverse)
11772 FoundPred = ICI->getInverseCmpPredicate();
11773 else
11774 FoundPred = ICI->getCmpPredicate();
11775
11776 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11777 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11778
11779 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11780}
11781
11782bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11783 const SCEV *RHS, CmpPredicate FoundPred,
11784 const SCEV *FoundLHS, const SCEV *FoundRHS,
11785 const Instruction *CtxI) {
11786 // Balance the types.
11787 if (getTypeSizeInBits(LHS->getType()) <
11788 getTypeSizeInBits(FoundLHS->getType())) {
11789 // For unsigned and equality predicates, try to prove that both found
11790 // operands fit into narrow unsigned range. If so, try to prove facts in
11791 // narrow types.
11792 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11793 !FoundRHS->getType()->isPointerTy()) {
11794 auto *NarrowType = LHS->getType();
11795 auto *WideType = FoundLHS->getType();
11796 auto BitWidth = getTypeSizeInBits(NarrowType);
11797 const SCEV *MaxValue = getZeroExtendExpr(
11799 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11800 MaxValue) &&
11801 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11802 MaxValue)) {
11803 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11804 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11805 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11806 TruncFoundRHS, CtxI))
11807 return true;
11808 }
11809 }
11810
11811 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11812 return false;
11813 if (CmpInst::isSigned(Pred)) {
11814 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11815 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11816 } else {
11817 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11818 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11819 }
11820 } else if (getTypeSizeInBits(LHS->getType()) >
11821 getTypeSizeInBits(FoundLHS->getType())) {
11822 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11823 return false;
11824 if (CmpInst::isSigned(FoundPred)) {
11825 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11826 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11827 } else {
11828 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11829 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11830 }
11831 }
11832 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11833 FoundRHS, CtxI);
11834}
11835
11836bool ScalarEvolution::isImpliedCondBalancedTypes(
11837 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11838 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11840 getTypeSizeInBits(FoundLHS->getType()) &&
11841 "Types should be balanced!");
11842 // Canonicalize the query to match the way instcombine will have
11843 // canonicalized the comparison.
11844 if (SimplifyICmpOperands(Pred, LHS, RHS))
11845 if (LHS == RHS)
11846 return CmpInst::isTrueWhenEqual(Pred);
11847 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11848 if (FoundLHS == FoundRHS)
11849 return CmpInst::isFalseWhenEqual(FoundPred);
11850
11851 // Check to see if we can make the LHS or RHS match.
11852 if (LHS == FoundRHS || RHS == FoundLHS) {
11853 if (isa<SCEVConstant>(RHS)) {
11854 std::swap(FoundLHS, FoundRHS);
11855 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11856 } else {
11857 std::swap(LHS, RHS);
11859 }
11860 }
11861
11862 // Check whether the found predicate is the same as the desired predicate.
11863 // FIXME: use CmpPredicate::getMatching here.
11864 if (FoundPred == static_cast<CmpInst::Predicate>(Pred))
11865 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11866
11867 // Check whether swapping the found predicate makes it the same as the
11868 // desired predicate.
11869 // FIXME: use CmpPredicate::getMatching here.
11870 if (ICmpInst::getSwappedCmpPredicate(FoundPred) ==
11871 static_cast<CmpInst::Predicate>(Pred)) {
11872 // We can write the implication
11873 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11874 // using one of the following ways:
11875 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11876 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11877 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11878 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11879 // Forms 1. and 2. require swapping the operands of one condition. Don't
11880 // do this if it would break canonical constant/addrec ordering.
11881 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11882 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11883 CtxI);
11884 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11885 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11886
11887 // There's no clear preference between forms 3. and 4., try both. Avoid
11888 // forming getNotSCEV of pointer values as the resulting subtract is
11889 // not legal.
11890 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11891 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11892 FoundLHS, FoundRHS, CtxI))
11893 return true;
11894
11895 if (!FoundLHS->getType()->isPointerTy() &&
11896 !FoundRHS->getType()->isPointerTy() &&
11897 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11898 getNotSCEV(FoundRHS), CtxI))
11899 return true;
11900
11901 return false;
11902 }
11903
11904 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11905 CmpInst::Predicate P2) {
11906 assert(P1 != P2 && "Handled earlier!");
11907 return CmpInst::isRelational(P2) &&
11909 };
11910 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11911 // Unsigned comparison is the same as signed comparison when both the
11912 // operands are non-negative or negative.
11913 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11914 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11915 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11916 // Create local copies that we can freely swap and canonicalize our
11917 // conditions to "le/lt".
11918 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11919 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11920 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11921 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11922 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
11923 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
11924 std::swap(CanonicalLHS, CanonicalRHS);
11925 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11926 }
11927 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11928 "Must be!");
11929 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11930 ICmpInst::isLE(CanonicalFoundPred)) &&
11931 "Must be!");
11932 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11933 // Use implication:
11934 // x <u y && y >=s 0 --> x <s y.
11935 // If we can prove the left part, the right part is also proven.
11936 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11937 CanonicalRHS, CanonicalFoundLHS,
11938 CanonicalFoundRHS);
11939 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11940 // Use implication:
11941 // x <s y && y <s 0 --> x <u y.
11942 // If we can prove the left part, the right part is also proven.
11943 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11944 CanonicalRHS, CanonicalFoundLHS,
11945 CanonicalFoundRHS);
11946 }
11947
11948 // Check if we can make progress by sharpening ranges.
11949 if (FoundPred == ICmpInst::ICMP_NE &&
11950 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11951
11952 const SCEVConstant *C = nullptr;
11953 const SCEV *V = nullptr;
11954
11955 if (isa<SCEVConstant>(FoundLHS)) {
11956 C = cast<SCEVConstant>(FoundLHS);
11957 V = FoundRHS;
11958 } else {
11959 C = cast<SCEVConstant>(FoundRHS);
11960 V = FoundLHS;
11961 }
11962
11963 // The guarding predicate tells us that C != V. If the known range
11964 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11965 // range we consider has to correspond to same signedness as the
11966 // predicate we're interested in folding.
11967
11968 APInt Min = ICmpInst::isSigned(Pred) ?
11970
11971 if (Min == C->getAPInt()) {
11972 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11973 // This is true even if (Min + 1) wraps around -- in case of
11974 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11975
11976 APInt SharperMin = Min + 1;
11977
11978 switch (Pred) {
11979 case ICmpInst::ICMP_SGE:
11980 case ICmpInst::ICMP_UGE:
11981 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11982 // RHS, we're done.
11983 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11984 CtxI))
11985 return true;
11986 [[fallthrough]];
11987
11988 case ICmpInst::ICMP_SGT:
11989 case ICmpInst::ICMP_UGT:
11990 // We know from the range information that (V `Pred` Min ||
11991 // V == Min). We know from the guarding condition that !(V
11992 // == Min). This gives us
11993 //
11994 // V `Pred` Min || V == Min && !(V == Min)
11995 // => V `Pred` Min
11996 //
11997 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11998
11999 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12000 return true;
12001 break;
12002
12003 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12004 case ICmpInst::ICMP_SLE:
12005 case ICmpInst::ICMP_ULE:
12006 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12007 LHS, V, getConstant(SharperMin), CtxI))
12008 return true;
12009 [[fallthrough]];
12010
12011 case ICmpInst::ICMP_SLT:
12012 case ICmpInst::ICMP_ULT:
12013 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12014 LHS, V, getConstant(Min), CtxI))
12015 return true;
12016 break;
12017
12018 default:
12019 // No change
12020 break;
12021 }
12022 }
12023 }
12024
12025 // Check whether the actual condition is beyond sufficient.
12026 if (FoundPred == ICmpInst::ICMP_EQ)
12027 if (ICmpInst::isTrueWhenEqual(Pred))
12028 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12029 return true;
12030 if (Pred == ICmpInst::ICMP_NE)
12031 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12032 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12033 return true;
12034
12035 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12036 return true;
12037
12038 // Otherwise assume the worst.
12039 return false;
12040}
12041
12042bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12043 const SCEV *&L, const SCEV *&R,
12044 SCEV::NoWrapFlags &Flags) {
12045 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12046 if (!AE || AE->getNumOperands() != 2)
12047 return false;
12048
12049 L = AE->getOperand(0);
12050 R = AE->getOperand(1);
12051 Flags = AE->getNoWrapFlags();
12052 return true;
12053}
12054
12055std::optional<APInt>
12057 // We avoid subtracting expressions here because this function is usually
12058 // fairly deep in the call stack (i.e. is called many times).
12059
12060 unsigned BW = getTypeSizeInBits(More->getType());
12061 APInt Diff(BW, 0);
12062 APInt DiffMul(BW, 1);
12063 // Try various simplifications to reduce the difference to a constant. Limit
12064 // the number of allowed simplifications to keep compile-time low.
12065 for (unsigned I = 0; I < 8; ++I) {
12066 if (More == Less)
12067 return Diff;
12068
12069 // Reduce addrecs with identical steps to their start value.
12070 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12071 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12072 const auto *MAR = cast<SCEVAddRecExpr>(More);
12073
12074 if (LAR->getLoop() != MAR->getLoop())
12075 return std::nullopt;
12076
12077 // We look at affine expressions only; not for correctness but to keep
12078 // getStepRecurrence cheap.
12079 if (!LAR->isAffine() || !MAR->isAffine())
12080 return std::nullopt;
12081
12082 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12083 return std::nullopt;
12084
12085 Less = LAR->getStart();
12086 More = MAR->getStart();
12087 continue;
12088 }
12089
12090 // Try to match a common constant multiply.
12091 auto MatchConstMul =
12092 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12093 auto *M = dyn_cast<SCEVMulExpr>(S);
12094 if (!M || M->getNumOperands() != 2 ||
12095 !isa<SCEVConstant>(M->getOperand(0)))
12096 return std::nullopt;
12097 return {
12098 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12099 };
12100 if (auto MatchedMore = MatchConstMul(More)) {
12101 if (auto MatchedLess = MatchConstMul(Less)) {
12102 if (MatchedMore->second == MatchedLess->second) {
12103 More = MatchedMore->first;
12104 Less = MatchedLess->first;
12105 DiffMul *= MatchedMore->second;
12106 continue;
12107 }
12108 }
12109 }
12110
12111 // Try to cancel out common factors in two add expressions.
12113 auto Add = [&](const SCEV *S, int Mul) {
12114 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12115 if (Mul == 1) {
12116 Diff += C->getAPInt() * DiffMul;
12117 } else {
12118 assert(Mul == -1);
12119 Diff -= C->getAPInt() * DiffMul;
12120 }
12121 } else
12122 Multiplicity[S] += Mul;
12123 };
12124 auto Decompose = [&](const SCEV *S, int Mul) {
12125 if (isa<SCEVAddExpr>(S)) {
12126 for (const SCEV *Op : S->operands())
12127 Add(Op, Mul);
12128 } else
12129 Add(S, Mul);
12130 };
12131 Decompose(More, 1);
12132 Decompose(Less, -1);
12133
12134 // Check whether all the non-constants cancel out, or reduce to new
12135 // More/Less values.
12136 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12137 for (const auto &[S, Mul] : Multiplicity) {
12138 if (Mul == 0)
12139 continue;
12140 if (Mul == 1) {
12141 if (NewMore)
12142 return std::nullopt;
12143 NewMore = S;
12144 } else if (Mul == -1) {
12145 if (NewLess)
12146 return std::nullopt;
12147 NewLess = S;
12148 } else
12149 return std::nullopt;
12150 }
12151
12152 // Values stayed the same, no point in trying further.
12153 if (NewMore == More || NewLess == Less)
12154 return std::nullopt;
12155
12156 More = NewMore;
12157 Less = NewLess;
12158
12159 // Reduced to constant.
12160 if (!More && !Less)
12161 return Diff;
12162
12163 // Left with variable on only one side, bail out.
12164 if (!More || !Less)
12165 return std::nullopt;
12166 }
12167
12168 // Did not reduce to constant.
12169 return std::nullopt;
12170}
12171
12172bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12173 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12174 const SCEV *FoundRHS, const Instruction *CtxI) {
12175 // Try to recognize the following pattern:
12176 //
12177 // FoundRHS = ...
12178 // ...
12179 // loop:
12180 // FoundLHS = {Start,+,W}
12181 // context_bb: // Basic block from the same loop
12182 // known(Pred, FoundLHS, FoundRHS)
12183 //
12184 // If some predicate is known in the context of a loop, it is also known on
12185 // each iteration of this loop, including the first iteration. Therefore, in
12186 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12187 // prove the original pred using this fact.
12188 if (!CtxI)
12189 return false;
12190 const BasicBlock *ContextBB = CtxI->getParent();
12191 // Make sure AR varies in the context block.
12192 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12193 const Loop *L = AR->getLoop();
12194 // Make sure that context belongs to the loop and executes on 1st iteration
12195 // (if it ever executes at all).
12196 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12197 return false;
12198 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12199 return false;
12200 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12201 }
12202
12203 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12204 const Loop *L = AR->getLoop();
12205 // Make sure that context belongs to the loop and executes on 1st iteration
12206 // (if it ever executes at all).
12207 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12208 return false;
12209 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12210 return false;
12211 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12212 }
12213
12214 return false;
12215}
12216
12217bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12218 const SCEV *LHS,
12219 const SCEV *RHS,
12220 const SCEV *FoundLHS,
12221 const SCEV *FoundRHS) {
12222 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12223 return false;
12224
12225 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12226 if (!AddRecLHS)
12227 return false;
12228
12229 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12230 if (!AddRecFoundLHS)
12231 return false;
12232
12233 // We'd like to let SCEV reason about control dependencies, so we constrain
12234 // both the inequalities to be about add recurrences on the same loop. This
12235 // way we can use isLoopEntryGuardedByCond later.
12236
12237 const Loop *L = AddRecFoundLHS->getLoop();
12238 if (L != AddRecLHS->getLoop())
12239 return false;
12240
12241 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12242 //
12243 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12244 // ... (2)
12245 //
12246 // Informal proof for (2), assuming (1) [*]:
12247 //
12248 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12249 //
12250 // Then
12251 //
12252 // FoundLHS s< FoundRHS s< INT_MIN - C
12253 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12254 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12255 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12256 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12257 // <=> FoundLHS + C s< FoundRHS + C
12258 //
12259 // [*]: (1) can be proved by ruling out overflow.
12260 //
12261 // [**]: This can be proved by analyzing all the four possibilities:
12262 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12263 // (A s>= 0, B s>= 0).
12264 //
12265 // Note:
12266 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12267 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12268 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12269 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12270 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12271 // C)".
12272
12273 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12274 if (!LDiff)
12275 return false;
12276 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12277 if (!RDiff || *LDiff != *RDiff)
12278 return false;
12279
12280 if (LDiff->isMinValue())
12281 return true;
12282
12283 APInt FoundRHSLimit;
12284
12285 if (Pred == CmpInst::ICMP_ULT) {
12286 FoundRHSLimit = -(*RDiff);
12287 } else {
12288 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12289 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12290 }
12291
12292 // Try to prove (1) or (2), as needed.
12293 return isAvailableAtLoopEntry(FoundRHS, L) &&
12294 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12295 getConstant(FoundRHSLimit));
12296}
12297
12298bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12299 const SCEV *RHS, const SCEV *FoundLHS,
12300 const SCEV *FoundRHS, unsigned Depth) {
12301 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12302
12303 auto ClearOnExit = make_scope_exit([&]() {
12304 if (LPhi) {
12305 bool Erased = PendingMerges.erase(LPhi);
12306 assert(Erased && "Failed to erase LPhi!");
12307 (void)Erased;
12308 }
12309 if (RPhi) {
12310 bool Erased = PendingMerges.erase(RPhi);
12311 assert(Erased && "Failed to erase RPhi!");
12312 (void)Erased;
12313 }
12314 });
12315
12316 // Find respective Phis and check that they are not being pending.
12317 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12318 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12319 if (!PendingMerges.insert(Phi).second)
12320 return false;
12321 LPhi = Phi;
12322 }
12323 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12324 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12325 // If we detect a loop of Phi nodes being processed by this method, for
12326 // example:
12327 //
12328 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12329 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12330 //
12331 // we don't want to deal with a case that complex, so return conservative
12332 // answer false.
12333 if (!PendingMerges.insert(Phi).second)
12334 return false;
12335 RPhi = Phi;
12336 }
12337
12338 // If none of LHS, RHS is a Phi, nothing to do here.
12339 if (!LPhi && !RPhi)
12340 return false;
12341
12342 // If there is a SCEVUnknown Phi we are interested in, make it left.
12343 if (!LPhi) {
12344 std::swap(LHS, RHS);
12345 std::swap(FoundLHS, FoundRHS);
12346 std::swap(LPhi, RPhi);
12348 }
12349
12350 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12351 const BasicBlock *LBB = LPhi->getParent();
12352 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12353
12354 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12355 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12356 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12357 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12358 };
12359
12360 if (RPhi && RPhi->getParent() == LBB) {
12361 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12362 // If we compare two Phis from the same block, and for each entry block
12363 // the predicate is true for incoming values from this block, then the
12364 // predicate is also true for the Phis.
12365 for (const BasicBlock *IncBB : predecessors(LBB)) {
12366 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12367 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12368 if (!ProvedEasily(L, R))
12369 return false;
12370 }
12371 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12372 // Case two: RHS is also a Phi from the same basic block, and it is an
12373 // AddRec. It means that there is a loop which has both AddRec and Unknown
12374 // PHIs, for it we can compare incoming values of AddRec from above the loop
12375 // and latch with their respective incoming values of LPhi.
12376 // TODO: Generalize to handle loops with many inputs in a header.
12377 if (LPhi->getNumIncomingValues() != 2) return false;
12378
12379 auto *RLoop = RAR->getLoop();
12380 auto *Predecessor = RLoop->getLoopPredecessor();
12381 assert(Predecessor && "Loop with AddRec with no predecessor?");
12382 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12383 if (!ProvedEasily(L1, RAR->getStart()))
12384 return false;
12385 auto *Latch = RLoop->getLoopLatch();
12386 assert(Latch && "Loop with AddRec with no latch?");
12387 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12388 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12389 return false;
12390 } else {
12391 // In all other cases go over inputs of LHS and compare each of them to RHS,
12392 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12393 // At this point RHS is either a non-Phi, or it is a Phi from some block
12394 // different from LBB.
12395 for (const BasicBlock *IncBB : predecessors(LBB)) {
12396 // Check that RHS is available in this block.
12397 if (!dominates(RHS, IncBB))
12398 return false;
12399 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12400 // Make sure L does not refer to a value from a potentially previous
12401 // iteration of a loop.
12402 if (!properlyDominates(L, LBB))
12403 return false;
12404 if (!ProvedEasily(L, RHS))
12405 return false;
12406 }
12407 }
12408 return true;
12409}
12410
12411bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12412 const SCEV *LHS,
12413 const SCEV *RHS,
12414 const SCEV *FoundLHS,
12415 const SCEV *FoundRHS) {
12416 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12417 // sure that we are dealing with same LHS.
12418 if (RHS == FoundRHS) {
12419 std::swap(LHS, RHS);
12420 std::swap(FoundLHS, FoundRHS);
12422 }
12423 if (LHS != FoundLHS)
12424 return false;
12425
12426 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12427 if (!SUFoundRHS)
12428 return false;
12429
12430 Value *Shiftee, *ShiftValue;
12431
12432 using namespace PatternMatch;
12433 if (match(SUFoundRHS->getValue(),
12434 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12435 auto *ShifteeS = getSCEV(Shiftee);
12436 // Prove one of the following:
12437 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12438 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12439 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12440 // ---> LHS <s RHS
12441 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12442 // ---> LHS <=s RHS
12443 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12444 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12445 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12446 if (isKnownNonNegative(ShifteeS))
12447 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12448 }
12449
12450 return false;
12451}
12452
12453bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12454 const SCEV *RHS,
12455 const SCEV *FoundLHS,
12456 const SCEV *FoundRHS,
12457 const Instruction *CtxI) {
12458 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12459 return true;
12460
12461 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12462 return true;
12463
12464 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12465 return true;
12466
12467 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12468 CtxI))
12469 return true;
12470
12471 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12472 FoundLHS, FoundRHS);
12473}
12474
12475/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12476template <typename MinMaxExprType>
12477static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12478 const SCEV *Candidate) {
12479 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12480 if (!MinMaxExpr)
12481 return false;
12482
12483 return is_contained(MinMaxExpr->operands(), Candidate);
12484}
12485
12487 CmpPredicate Pred, const SCEV *LHS,
12488 const SCEV *RHS) {
12489 // If both sides are affine addrecs for the same loop, with equal
12490 // steps, and we know the recurrences don't wrap, then we only
12491 // need to check the predicate on the starting values.
12492
12493 if (!ICmpInst::isRelational(Pred))
12494 return false;
12495
12496 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12497 if (!LAR)
12498 return false;
12499 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12500 if (!RAR)
12501 return false;
12502 if (LAR->getLoop() != RAR->getLoop())
12503 return false;
12504 if (!LAR->isAffine() || !RAR->isAffine())
12505 return false;
12506
12507 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12508 return false;
12509
12512 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12513 return false;
12514
12515 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12516}
12517
12518/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12519/// expression?
12521 const SCEV *LHS, const SCEV *RHS) {
12522 switch (Pred) {
12523 default:
12524 return false;
12525
12526 case ICmpInst::ICMP_SGE:
12527 std::swap(LHS, RHS);
12528 [[fallthrough]];
12529 case ICmpInst::ICMP_SLE:
12530 return
12531 // min(A, ...) <= A
12532 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12533 // A <= max(A, ...)
12534 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12535
12536 case ICmpInst::ICMP_UGE:
12537 std::swap(LHS, RHS);
12538 [[fallthrough]];
12539 case ICmpInst::ICMP_ULE:
12540 return
12541 // min(A, ...) <= A
12542 // FIXME: what about umin_seq?
12543 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12544 // A <= max(A, ...)
12545 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12546 }
12547
12548 llvm_unreachable("covered switch fell through?!");
12549}
12550
12551bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12552 const SCEV *RHS,
12553 const SCEV *FoundLHS,
12554 const SCEV *FoundRHS,
12555 unsigned Depth) {
12558 "LHS and RHS have different sizes?");
12559 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12560 getTypeSizeInBits(FoundRHS->getType()) &&
12561 "FoundLHS and FoundRHS have different sizes?");
12562 // We want to avoid hurting the compile time with analysis of too big trees.
12564 return false;
12565
12566 // We only want to work with GT comparison so far.
12567 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12569 std::swap(LHS, RHS);
12570 std::swap(FoundLHS, FoundRHS);
12571 }
12572
12573 // For unsigned, try to reduce it to corresponding signed comparison.
12574 if (Pred == ICmpInst::ICMP_UGT)
12575 // We can replace unsigned predicate with its signed counterpart if all
12576 // involved values are non-negative.
12577 // TODO: We could have better support for unsigned.
12578 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12579 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12580 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12581 // use this fact to prove that LHS and RHS are non-negative.
12582 const SCEV *MinusOne = getMinusOne(LHS->getType());
12583 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12584 FoundRHS) &&
12585 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12586 FoundRHS))
12587 Pred = ICmpInst::ICMP_SGT;
12588 }
12589
12590 if (Pred != ICmpInst::ICMP_SGT)
12591 return false;
12592
12593 auto GetOpFromSExt = [&](const SCEV *S) {
12594 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12595 return Ext->getOperand();
12596 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12597 // the constant in some cases.
12598 return S;
12599 };
12600
12601 // Acquire values from extensions.
12602 auto *OrigLHS = LHS;
12603 auto *OrigFoundLHS = FoundLHS;
12604 LHS = GetOpFromSExt(LHS);
12605 FoundLHS = GetOpFromSExt(FoundLHS);
12606
12607 // Is the SGT predicate can be proved trivially or using the found context.
12608 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12609 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12610 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12611 FoundRHS, Depth + 1);
12612 };
12613
12614 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12615 // We want to avoid creation of any new non-constant SCEV. Since we are
12616 // going to compare the operands to RHS, we should be certain that we don't
12617 // need any size extensions for this. So let's decline all cases when the
12618 // sizes of types of LHS and RHS do not match.
12619 // TODO: Maybe try to get RHS from sext to catch more cases?
12621 return false;
12622
12623 // Should not overflow.
12624 if (!LHSAddExpr->hasNoSignedWrap())
12625 return false;
12626
12627 auto *LL = LHSAddExpr->getOperand(0);
12628 auto *LR = LHSAddExpr->getOperand(1);
12629 auto *MinusOne = getMinusOne(RHS->getType());
12630
12631 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12632 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12633 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12634 };
12635 // Try to prove the following rule:
12636 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12637 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12638 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12639 return true;
12640 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12641 Value *LL, *LR;
12642 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12643
12644 using namespace llvm::PatternMatch;
12645
12646 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12647 // Rules for division.
12648 // We are going to perform some comparisons with Denominator and its
12649 // derivative expressions. In general case, creating a SCEV for it may
12650 // lead to a complex analysis of the entire graph, and in particular it
12651 // can request trip count recalculation for the same loop. This would
12652 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12653 // this, we only want to create SCEVs that are constants in this section.
12654 // So we bail if Denominator is not a constant.
12655 if (!isa<ConstantInt>(LR))
12656 return false;
12657
12658 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12659
12660 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12661 // then a SCEV for the numerator already exists and matches with FoundLHS.
12662 auto *Numerator = getExistingSCEV(LL);
12663 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12664 return false;
12665
12666 // Make sure that the numerator matches with FoundLHS and the denominator
12667 // is positive.
12668 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12669 return false;
12670
12671 auto *DTy = Denominator->getType();
12672 auto *FRHSTy = FoundRHS->getType();
12673 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12674 // One of types is a pointer and another one is not. We cannot extend
12675 // them properly to a wider type, so let us just reject this case.
12676 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12677 // to avoid this check.
12678 return false;
12679
12680 // Given that:
12681 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12682 auto *WTy = getWiderType(DTy, FRHSTy);
12683 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12684 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12685
12686 // Try to prove the following rule:
12687 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12688 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12689 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12690 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12691 if (isKnownNonPositive(RHS) &&
12692 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12693 return true;
12694
12695 // Try to prove the following rule:
12696 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12697 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12698 // If we divide it by Denominator > 2, then:
12699 // 1. If FoundLHS is negative, then the result is 0.
12700 // 2. If FoundLHS is non-negative, then the result is non-negative.
12701 // Anyways, the result is non-negative.
12702 auto *MinusOne = getMinusOne(WTy);
12703 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12704 if (isKnownNegative(RHS) &&
12705 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12706 return true;
12707 }
12708 }
12709
12710 // If our expression contained SCEVUnknown Phis, and we split it down and now
12711 // need to prove something for them, try to prove the predicate for every
12712 // possible incoming values of those Phis.
12713 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12714 return true;
12715
12716 return false;
12717}
12718
12719static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12720 const SCEV *RHS) {
12721 // zext x u<= sext x, sext x s<= zext x
12722 const SCEV *Op;
12723 switch (Pred) {
12724 case ICmpInst::ICMP_SGE:
12725 std::swap(LHS, RHS);
12726 [[fallthrough]];
12727 case ICmpInst::ICMP_SLE: {
12728 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12729 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12731 }
12732 case ICmpInst::ICMP_UGE:
12733 std::swap(LHS, RHS);
12734 [[fallthrough]];
12735 case ICmpInst::ICMP_ULE: {
12736 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12737 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12739 }
12740 default:
12741 return false;
12742 };
12743 llvm_unreachable("unhandled case");
12744}
12745
12746bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12747 const SCEV *LHS,
12748 const SCEV *RHS) {
12749 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12750 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12751 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12752 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12753 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12754}
12755
12756bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12757 const SCEV *LHS,
12758 const SCEV *RHS,
12759 const SCEV *FoundLHS,
12760 const SCEV *FoundRHS) {
12761 switch (Pred) {
12762 default:
12763 llvm_unreachable("Unexpected CmpPredicate value!");
12764 case ICmpInst::ICMP_EQ:
12765 case ICmpInst::ICMP_NE:
12766 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12767 return true;
12768 break;
12769 case ICmpInst::ICMP_SLT:
12770 case ICmpInst::ICMP_SLE:
12771 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12772 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12773 return true;
12774 break;
12775 case ICmpInst::ICMP_SGT:
12776 case ICmpInst::ICMP_SGE:
12777 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12778 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12779 return true;
12780 break;
12781 case ICmpInst::ICMP_ULT:
12782 case ICmpInst::ICMP_ULE:
12783 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12784 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12785 return true;
12786 break;
12787 case ICmpInst::ICMP_UGT:
12788 case ICmpInst::ICMP_UGE:
12789 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12790 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12791 return true;
12792 break;
12793 }
12794
12795 // Maybe it can be proved via operations?
12796 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12797 return true;
12798
12799 return false;
12800}
12801
12802bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12803 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12804 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12805 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12806 // The restriction on `FoundRHS` be lifted easily -- it exists only to
12807 // reduce the compile time impact of this optimization.
12808 return false;
12809
12810 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12811 if (!Addend)
12812 return false;
12813
12814 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12815
12816 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12817 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12818 ConstantRange FoundLHSRange =
12819 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12820
12821 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12822 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12823
12824 // We can also compute the range of values for `LHS` that satisfy the
12825 // consequent, "`LHS` `Pred` `RHS`":
12826 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12827 // The antecedent implies the consequent if every value of `LHS` that
12828 // satisfies the antecedent also satisfies the consequent.
12829 return LHSRange.icmp(Pred, ConstRHS);
12830}
12831
12832bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12833 bool IsSigned) {
12834 assert(isKnownPositive(Stride) && "Positive stride expected!");
12835
12836 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12837 const SCEV *One = getOne(Stride->getType());
12838
12839 if (IsSigned) {
12840 APInt MaxRHS = getSignedRangeMax(RHS);
12842 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12843
12844 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12845 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12846 }
12847
12848 APInt MaxRHS = getUnsignedRangeMax(RHS);
12849 APInt MaxValue = APInt::getMaxValue(BitWidth);
12850 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12851
12852 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12853 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12854}
12855
12856bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12857 bool IsSigned) {
12858
12859 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12860 const SCEV *One = getOne(Stride->getType());
12861
12862 if (IsSigned) {
12863 APInt MinRHS = getSignedRangeMin(RHS);
12865 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12866
12867 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12868 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12869 }
12870
12871 APInt MinRHS = getUnsignedRangeMin(RHS);
12872 APInt MinValue = APInt::getMinValue(BitWidth);
12873 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12874
12875 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12876 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12877}
12878
12880 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12881 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12882 // expression fixes the case of N=0.
12883 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12884 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12885 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12886}
12887
12888const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12889 const SCEV *Stride,
12890 const SCEV *End,
12891 unsigned BitWidth,
12892 bool IsSigned) {
12893 // The logic in this function assumes we can represent a positive stride.
12894 // If we can't, the backedge-taken count must be zero.
12895 if (IsSigned && BitWidth == 1)
12896 return getZero(Stride->getType());
12897
12898 // This code below only been closely audited for negative strides in the
12899 // unsigned comparison case, it may be correct for signed comparison, but
12900 // that needs to be established.
12901 if (IsSigned && isKnownNegative(Stride))
12902 return getCouldNotCompute();
12903
12904 // Calculate the maximum backedge count based on the range of values
12905 // permitted by Start, End, and Stride.
12906 APInt MinStart =
12907 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12908
12909 APInt MinStride =
12910 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12911
12912 // We assume either the stride is positive, or the backedge-taken count
12913 // is zero. So force StrideForMaxBECount to be at least one.
12914 APInt One(BitWidth, 1);
12915 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12916 : APIntOps::umax(One, MinStride);
12917
12918 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12919 : APInt::getMaxValue(BitWidth);
12920 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12921
12922 // Although End can be a MAX expression we estimate MaxEnd considering only
12923 // the case End = RHS of the loop termination condition. This is safe because
12924 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12925 // taken count.
12926 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12927 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12928
12929 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12930 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12931 : APIntOps::umax(MaxEnd, MinStart);
12932
12933 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12934 getConstant(StrideForMaxBECount) /* Step */);
12935}
12936
12938ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12939 const Loop *L, bool IsSigned,
12940 bool ControlsOnlyExit, bool AllowPredicates) {
12942
12943 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12944 bool PredicatedIV = false;
12945 if (!IV) {
12946 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12947 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12948 if (AR && AR->getLoop() == L && AR->isAffine()) {
12949 auto canProveNUW = [&]() {
12950 // We can use the comparison to infer no-wrap flags only if it fully
12951 // controls the loop exit.
12952 if (!ControlsOnlyExit)
12953 return false;
12954
12955 if (!isLoopInvariant(RHS, L))
12956 return false;
12957
12958 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12959 // We need the sequence defined by AR to strictly increase in the
12960 // unsigned integer domain for the logic below to hold.
12961 return false;
12962
12963 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12964 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12965 // If RHS <=u Limit, then there must exist a value V in the sequence
12966 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12967 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12968 // overflow occurs. This limit also implies that a signed comparison
12969 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12970 // the high bits on both sides must be zero.
12971 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12972 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12973 Limit = Limit.zext(OuterBitWidth);
12974 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12975 };
12976 auto Flags = AR->getNoWrapFlags();
12977 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12978 Flags = setFlags(Flags, SCEV::FlagNUW);
12979
12980 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12981 if (AR->hasNoUnsignedWrap()) {
12982 // Emulate what getZeroExtendExpr would have done during construction
12983 // if we'd been able to infer the fact just above at that time.
12984 const SCEV *Step = AR->getStepRecurrence(*this);
12985 Type *Ty = ZExt->getType();
12986 auto *S = getAddRecExpr(
12987 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12988 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12989 IV = dyn_cast<SCEVAddRecExpr>(S);
12990 }
12991 }
12992 }
12993 }
12994
12995
12996 if (!IV && AllowPredicates) {
12997 // Try to make this an AddRec using runtime tests, in the first X
12998 // iterations of this loop, where X is the SCEV expression found by the
12999 // algorithm below.
13000 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13001 PredicatedIV = true;
13002 }
13003
13004 // Avoid weird loops
13005 if (!IV || IV->getLoop() != L || !IV->isAffine())
13006 return getCouldNotCompute();
13007
13008 // A precondition of this method is that the condition being analyzed
13009 // reaches an exiting branch which dominates the latch. Given that, we can
13010 // assume that an increment which violates the nowrap specification and
13011 // produces poison must cause undefined behavior when the resulting poison
13012 // value is branched upon and thus we can conclude that the backedge is
13013 // taken no more often than would be required to produce that poison value.
13014 // Note that a well defined loop can exit on the iteration which violates
13015 // the nowrap specification if there is another exit (either explicit or
13016 // implicit/exceptional) which causes the loop to execute before the
13017 // exiting instruction we're analyzing would trigger UB.
13018 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13019 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13021
13022 const SCEV *Stride = IV->getStepRecurrence(*this);
13023
13024 bool PositiveStride = isKnownPositive(Stride);
13025
13026 // Avoid negative or zero stride values.
13027 if (!PositiveStride) {
13028 // We can compute the correct backedge taken count for loops with unknown
13029 // strides if we can prove that the loop is not an infinite loop with side
13030 // effects. Here's the loop structure we are trying to handle -
13031 //
13032 // i = start
13033 // do {
13034 // A[i] = i;
13035 // i += s;
13036 // } while (i < end);
13037 //
13038 // The backedge taken count for such loops is evaluated as -
13039 // (max(end, start + stride) - start - 1) /u stride
13040 //
13041 // The additional preconditions that we need to check to prove correctness
13042 // of the above formula is as follows -
13043 //
13044 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13045 // NoWrap flag).
13046 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13047 // no side effects within the loop)
13048 // c) loop has a single static exit (with no abnormal exits)
13049 //
13050 // Precondition a) implies that if the stride is negative, this is a single
13051 // trip loop. The backedge taken count formula reduces to zero in this case.
13052 //
13053 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13054 // then a zero stride means the backedge can't be taken without executing
13055 // undefined behavior.
13056 //
13057 // The positive stride case is the same as isKnownPositive(Stride) returning
13058 // true (original behavior of the function).
13059 //
13060 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13062 return getCouldNotCompute();
13063
13064 if (!isKnownNonZero(Stride)) {
13065 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13066 // if it might eventually be greater than start and if so, on which
13067 // iteration. We can't even produce a useful upper bound.
13068 if (!isLoopInvariant(RHS, L))
13069 return getCouldNotCompute();
13070
13071 // We allow a potentially zero stride, but we need to divide by stride
13072 // below. Since the loop can't be infinite and this check must control
13073 // the sole exit, we can infer the exit must be taken on the first
13074 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13075 // we know the numerator in the divides below must be zero, so we can
13076 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13077 // and produce the right result.
13078 // FIXME: Handle the case where Stride is poison?
13079 auto wouldZeroStrideBeUB = [&]() {
13080 // Proof by contradiction. Suppose the stride were zero. If we can
13081 // prove that the backedge *is* taken on the first iteration, then since
13082 // we know this condition controls the sole exit, we must have an
13083 // infinite loop. We can't have a (well defined) infinite loop per
13084 // check just above.
13085 // Note: The (Start - Stride) term is used to get the start' term from
13086 // (start' + stride,+,stride). Remember that we only care about the
13087 // result of this expression when stride == 0 at runtime.
13088 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13089 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13090 };
13091 if (!wouldZeroStrideBeUB()) {
13092 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13093 }
13094 }
13095 } else if (!NoWrap) {
13096 // Avoid proven overflow cases: this will ensure that the backedge taken
13097 // count will not generate any unsigned overflow.
13098 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13099 return getCouldNotCompute();
13100 }
13101
13102 // On all paths just preceding, we established the following invariant:
13103 // IV can be assumed not to overflow up to and including the exiting
13104 // iteration. We proved this in one of two ways:
13105 // 1) We can show overflow doesn't occur before the exiting iteration
13106 // 1a) canIVOverflowOnLT, and b) step of one
13107 // 2) We can show that if overflow occurs, the loop must execute UB
13108 // before any possible exit.
13109 // Note that we have not yet proved RHS invariant (in general).
13110
13111 const SCEV *Start = IV->getStart();
13112
13113 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13114 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13115 // Use integer-typed versions for actual computation; we can't subtract
13116 // pointers in general.
13117 const SCEV *OrigStart = Start;
13118 const SCEV *OrigRHS = RHS;
13119 if (Start->getType()->isPointerTy()) {
13120 Start = getLosslessPtrToIntExpr(Start);
13121 if (isa<SCEVCouldNotCompute>(Start))
13122 return Start;
13123 }
13124 if (RHS->getType()->isPointerTy()) {
13126 if (isa<SCEVCouldNotCompute>(RHS))
13127 return RHS;
13128 }
13129
13130 const SCEV *End = nullptr, *BECount = nullptr,
13131 *BECountIfBackedgeTaken = nullptr;
13132 if (!isLoopInvariant(RHS, L)) {
13133 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13134 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13135 RHSAddRec->getNoWrapFlags()) {
13136 // The structure of loop we are trying to calculate backedge count of:
13137 //
13138 // left = left_start
13139 // right = right_start
13140 //
13141 // while(left < right){
13142 // ... do something here ...
13143 // left += s1; // stride of left is s1 (s1 > 0)
13144 // right += s2; // stride of right is s2 (s2 < 0)
13145 // }
13146 //
13147
13148 const SCEV *RHSStart = RHSAddRec->getStart();
13149 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13150
13151 // If Stride - RHSStride is positive and does not overflow, we can write
13152 // backedge count as ->
13153 // ceil((End - Start) /u (Stride - RHSStride))
13154 // Where, End = max(RHSStart, Start)
13155
13156 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13157 if (isKnownNegative(RHSStride) &&
13158 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13159 RHSStride)) {
13160
13161 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13162 if (isKnownPositive(Denominator)) {
13163 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13164 : getUMaxExpr(RHSStart, Start);
13165
13166 // We can do this because End >= Start, as End = max(RHSStart, Start)
13167 const SCEV *Delta = getMinusSCEV(End, Start);
13168
13169 BECount = getUDivCeilSCEV(Delta, Denominator);
13170 BECountIfBackedgeTaken =
13171 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13172 }
13173 }
13174 }
13175 if (BECount == nullptr) {
13176 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13177 // given the start, stride and max value for the end bound of the
13178 // loop (RHS), and the fact that IV does not overflow (which is
13179 // checked above).
13180 const SCEV *MaxBECount = computeMaxBECountForLT(
13181 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13182 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13183 MaxBECount, false /*MaxOrZero*/, Predicates);
13184 }
13185 } else {
13186 // We use the expression (max(End,Start)-Start)/Stride to describe the
13187 // backedge count, as if the backedge is taken at least once
13188 // max(End,Start) is End and so the result is as above, and if not
13189 // max(End,Start) is Start so we get a backedge count of zero.
13190 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13191 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13192 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13193 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13194 // Can we prove (max(RHS,Start) > Start - Stride?
13195 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13196 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13197 // In this case, we can use a refined formula for computing backedge
13198 // taken count. The general formula remains:
13199 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13200 // We want to use the alternate formula:
13201 // "((End - 1) - (Start - Stride)) /u Stride"
13202 // Let's do a quick case analysis to show these are equivalent under
13203 // our precondition that max(RHS,Start) > Start - Stride.
13204 // * For RHS <= Start, the backedge-taken count must be zero.
13205 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13206 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13207 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13208 // of Stride. For 0 stride, we've used umax(1,Stride) above,
13209 // reducing this to the stride of 1 case.
13210 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13211 // Stride".
13212 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13213 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13214 // "((RHS - (Start - Stride) - 1) /u Stride".
13215 // Our preconditions trivially imply no overflow in that form.
13216 const SCEV *MinusOne = getMinusOne(Stride->getType());
13217 const SCEV *Numerator =
13218 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13219 BECount = getUDivExpr(Numerator, Stride);
13220 }
13221
13222 if (!BECount) {
13223 auto canProveRHSGreaterThanEqualStart = [&]() {
13224 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13225 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13226 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13227
13228 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13229 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13230 return true;
13231
13232 // (RHS > Start - 1) implies RHS >= Start.
13233 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13234 // "Start - 1" doesn't overflow.
13235 // * For signed comparison, if Start - 1 does overflow, it's equal
13236 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13237 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13238 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13239 //
13240 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13241 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13242 auto *StartMinusOne =
13243 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13244 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13245 };
13246
13247 // If we know that RHS >= Start in the context of loop, then we know
13248 // that max(RHS, Start) = RHS at this point.
13249 if (canProveRHSGreaterThanEqualStart()) {
13250 End = RHS;
13251 } else {
13252 // If RHS < Start, the backedge will be taken zero times. So in
13253 // general, we can write the backedge-taken count as:
13254 //
13255 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13256 //
13257 // We convert it to the following to make it more convenient for SCEV:
13258 //
13259 // ceil(max(RHS, Start) - Start) / Stride
13260 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13261
13262 // See what would happen if we assume the backedge is taken. This is
13263 // used to compute MaxBECount.
13264 BECountIfBackedgeTaken =
13265 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13266 }
13267
13268 // At this point, we know:
13269 //
13270 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13271 // 2. The index variable doesn't overflow.
13272 //
13273 // Therefore, we know N exists such that
13274 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13275 // doesn't overflow.
13276 //
13277 // Using this information, try to prove whether the addition in
13278 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13279 const SCEV *One = getOne(Stride->getType());
13280 bool MayAddOverflow = [&] {
13281 if (isKnownToBeAPowerOfTwo(Stride)) {
13282 // Suppose Stride is a power of two, and Start/End are unsigned
13283 // integers. Let UMAX be the largest representable unsigned
13284 // integer.
13285 //
13286 // By the preconditions of this function, we know
13287 // "(Start + Stride * N) >= End", and this doesn't overflow.
13288 // As a formula:
13289 //
13290 // End <= (Start + Stride * N) <= UMAX
13291 //
13292 // Subtracting Start from all the terms:
13293 //
13294 // End - Start <= Stride * N <= UMAX - Start
13295 //
13296 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13297 //
13298 // End - Start <= Stride * N <= UMAX
13299 //
13300 // Stride * N is a multiple of Stride. Therefore,
13301 //
13302 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13303 //
13304 // Since Stride is a power of two, UMAX + 1 is divisible by
13305 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13306 // write:
13307 //
13308 // End - Start <= Stride * N <= UMAX - Stride - 1
13309 //
13310 // Dropping the middle term:
13311 //
13312 // End - Start <= UMAX - Stride - 1
13313 //
13314 // Adding Stride - 1 to both sides:
13315 //
13316 // (End - Start) + (Stride - 1) <= UMAX
13317 //
13318 // In other words, the addition doesn't have unsigned overflow.
13319 //
13320 // A similar proof works if we treat Start/End as signed values.
13321 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13322 // to use signed max instead of unsigned max. Note that we're
13323 // trying to prove a lack of unsigned overflow in either case.
13324 return false;
13325 }
13326 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13327 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13328 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13329 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13330 // 1 <s End.
13331 //
13332 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13333 // End.
13334 return false;
13335 }
13336 return true;
13337 }();
13338
13339 const SCEV *Delta = getMinusSCEV(End, Start);
13340 if (!MayAddOverflow) {
13341 // floor((D + (S - 1)) / S)
13342 // We prefer this formulation if it's legal because it's fewer
13343 // operations.
13344 BECount =
13345 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13346 } else {
13347 BECount = getUDivCeilSCEV(Delta, Stride);
13348 }
13349 }
13350 }
13351
13352 const SCEV *ConstantMaxBECount;
13353 bool MaxOrZero = false;
13354 if (isa<SCEVConstant>(BECount)) {
13355 ConstantMaxBECount = BECount;
13356 } else if (BECountIfBackedgeTaken &&
13357 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13358 // If we know exactly how many times the backedge will be taken if it's
13359 // taken at least once, then the backedge count will either be that or
13360 // zero.
13361 ConstantMaxBECount = BECountIfBackedgeTaken;
13362 MaxOrZero = true;
13363 } else {
13364 ConstantMaxBECount = computeMaxBECountForLT(
13365 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13366 }
13367
13368 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13369 !isa<SCEVCouldNotCompute>(BECount))
13370 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13371
13372 const SCEV *SymbolicMaxBECount =
13373 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13374 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13375 Predicates);
13376}
13377
13378ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13379 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13380 bool ControlsOnlyExit, bool AllowPredicates) {
13382 // We handle only IV > Invariant
13383 if (!isLoopInvariant(RHS, L))
13384 return getCouldNotCompute();
13385
13386 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13387 if (!IV && AllowPredicates)
13388 // Try to make this an AddRec using runtime tests, in the first X
13389 // iterations of this loop, where X is the SCEV expression found by the
13390 // algorithm below.
13391 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13392
13393 // Avoid weird loops
13394 if (!IV || IV->getLoop() != L || !IV->isAffine())
13395 return getCouldNotCompute();
13396
13397 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13398 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13400
13401 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13402
13403 // Avoid negative or zero stride values
13404 if (!isKnownPositive(Stride))
13405 return getCouldNotCompute();
13406
13407 // Avoid proven overflow cases: this will ensure that the backedge taken count
13408 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13409 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13410 // behaviors like the case of C language.
13411 if (!Stride->isOne() && !NoWrap)
13412 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13413 return getCouldNotCompute();
13414
13415 const SCEV *Start = IV->getStart();
13416 const SCEV *End = RHS;
13417 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13418 // If we know that Start >= RHS in the context of loop, then we know that
13419 // min(RHS, Start) = RHS at this point.
13421 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13422 End = RHS;
13423 else
13424 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13425 }
13426
13427 if (Start->getType()->isPointerTy()) {
13428 Start = getLosslessPtrToIntExpr(Start);
13429 if (isa<SCEVCouldNotCompute>(Start))
13430 return Start;
13431 }
13432 if (End->getType()->isPointerTy()) {
13434 if (isa<SCEVCouldNotCompute>(End))
13435 return End;
13436 }
13437
13438 // Compute ((Start - End) + (Stride - 1)) / Stride.
13439 // FIXME: This can overflow. Holding off on fixing this for now;
13440 // howManyGreaterThans will hopefully be gone soon.
13441 const SCEV *One = getOne(Stride->getType());
13442 const SCEV *BECount = getUDivExpr(
13443 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13444
13445 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13446 : getUnsignedRangeMax(Start);
13447
13448 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13449 : getUnsignedRangeMin(Stride);
13450
13451 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13452 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13453 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13454
13455 // Although End can be a MIN expression we estimate MinEnd considering only
13456 // the case End = RHS. This is safe because in the other case (Start - End)
13457 // is zero, leading to a zero maximum backedge taken count.
13458 APInt MinEnd =
13459 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13460 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13461
13462 const SCEV *ConstantMaxBECount =
13463 isa<SCEVConstant>(BECount)
13464 ? BECount
13465 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13466 getConstant(MinStride));
13467
13468 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13469 ConstantMaxBECount = BECount;
13470 const SCEV *SymbolicMaxBECount =
13471 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13472
13473 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13474 Predicates);
13475}
13476
13478 ScalarEvolution &SE) const {
13479 if (Range.isFullSet()) // Infinite loop.
13480 return SE.getCouldNotCompute();
13481
13482 // If the start is a non-zero constant, shift the range to simplify things.
13483 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13484 if (!SC->getValue()->isZero()) {
13486 Operands[0] = SE.getZero(SC->getType());
13487 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13488 getNoWrapFlags(FlagNW));
13489 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13490 return ShiftedAddRec->getNumIterationsInRange(
13491 Range.subtract(SC->getAPInt()), SE);
13492 // This is strange and shouldn't happen.
13493 return SE.getCouldNotCompute();
13494 }
13495
13496 // The only time we can solve this is when we have all constant indices.
13497 // Otherwise, we cannot determine the overflow conditions.
13498 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13499 return SE.getCouldNotCompute();
13500
13501 // Okay at this point we know that all elements of the chrec are constants and
13502 // that the start element is zero.
13503
13504 // First check to see if the range contains zero. If not, the first
13505 // iteration exits.
13506 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13507 if (!Range.contains(APInt(BitWidth, 0)))
13508 return SE.getZero(getType());
13509
13510 if (isAffine()) {
13511 // If this is an affine expression then we have this situation:
13512 // Solve {0,+,A} in Range === Ax in Range
13513
13514 // We know that zero is in the range. If A is positive then we know that
13515 // the upper value of the range must be the first possible exit value.
13516 // If A is negative then the lower of the range is the last possible loop
13517 // value. Also note that we already checked for a full range.
13518 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13519 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13520
13521 // The exit value should be (End+A)/A.
13522 APInt ExitVal = (End + A).udiv(A);
13523 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13524
13525 // Evaluate at the exit value. If we really did fall out of the valid
13526 // range, then we computed our trip count, otherwise wrap around or other
13527 // things must have happened.
13528 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13529 if (Range.contains(Val->getValue()))
13530 return SE.getCouldNotCompute(); // Something strange happened
13531
13532 // Ensure that the previous value is in the range.
13535 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13536 "Linear scev computation is off in a bad way!");
13537 return SE.getConstant(ExitValue);
13538 }
13539
13540 if (isQuadratic()) {
13541 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13542 return SE.getConstant(*S);
13543 }
13544
13545 return SE.getCouldNotCompute();
13546}
13547
13548const SCEVAddRecExpr *
13550 assert(getNumOperands() > 1 && "AddRec with zero step?");
13551 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13552 // but in this case we cannot guarantee that the value returned will be an
13553 // AddRec because SCEV does not have a fixed point where it stops
13554 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13555 // may happen if we reach arithmetic depth limit while simplifying. So we
13556 // construct the returned value explicitly.
13558 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13559 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13560 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13561 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13562 // We know that the last operand is not a constant zero (otherwise it would
13563 // have been popped out earlier). This guarantees us that if the result has
13564 // the same last operand, then it will also not be popped out, meaning that
13565 // the returned value will be an AddRec.
13566 const SCEV *Last = getOperand(getNumOperands() - 1);
13567 assert(!Last->isZero() && "Recurrency with zero step?");
13568 Ops.push_back(Last);
13569 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13571}
13572
13573// Return true when S contains at least an undef value.
13575 return SCEVExprContains(S, [](const SCEV *S) {
13576 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13577 return isa<UndefValue>(SU->getValue());
13578 return false;
13579 });
13580}
13581
13582// Return true when S contains a value that is a nullptr.
13584 return SCEVExprContains(S, [](const SCEV *S) {
13585 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13586 return SU->getValue() == nullptr;
13587 return false;
13588 });
13589}
13590
13591/// Return the size of an element read or written by Inst.
13593 Type *Ty;
13594 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13595 Ty = Store->getValueOperand()->getType();
13596 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13597 Ty = Load->getType();
13598 else
13599 return nullptr;
13600
13602 return getSizeOfExpr(ETy, Ty);
13603}
13604
13605//===----------------------------------------------------------------------===//
13606// SCEVCallbackVH Class Implementation
13607//===----------------------------------------------------------------------===//
13608
13609void ScalarEvolution::SCEVCallbackVH::deleted() {
13610 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13611 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13612 SE->ConstantEvolutionLoopExitValue.erase(PN);
13613 SE->eraseValueFromMap(getValPtr());
13614 // this now dangles!
13615}
13616
13617void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13618 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13619
13620 // Forget all the expressions associated with users of the old value,
13621 // so that future queries will recompute the expressions using the new
13622 // value.
13623 SE->forgetValue(getValPtr());
13624 // this now dangles!
13625}
13626
13627ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13628 : CallbackVH(V), SE(se) {}
13629
13630//===----------------------------------------------------------------------===//
13631// ScalarEvolution Class Implementation
13632//===----------------------------------------------------------------------===//
13633
13636 LoopInfo &LI)
13637 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13638 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13639 LoopDispositions(64), BlockDispositions(64) {
13640 // To use guards for proving predicates, we need to scan every instruction in
13641 // relevant basic blocks, and not just terminators. Doing this is a waste of
13642 // time if the IR does not actually contain any calls to
13643 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13644 //
13645 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13646 // to _add_ guards to the module when there weren't any before, and wants
13647 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13648 // efficient in lieu of being smart in that rather obscure case.
13649
13650 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13651 F.getParent(), Intrinsic::experimental_guard);
13652 HasGuards = GuardDecl && !GuardDecl->use_empty();
13653}
13654
13656 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13657 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13658 ValueExprMap(std::move(Arg.ValueExprMap)),
13659 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13660 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13661 PendingMerges(std::move(Arg.PendingMerges)),
13662 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13663 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13664 PredicatedBackedgeTakenCounts(
13665 std::move(Arg.PredicatedBackedgeTakenCounts)),
13666 BECountUsers(std::move(Arg.BECountUsers)),
13667 ConstantEvolutionLoopExitValue(
13668 std::move(Arg.ConstantEvolutionLoopExitValue)),
13669 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13670 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13671 LoopDispositions(std::move(Arg.LoopDispositions)),
13672 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13673 BlockDispositions(std::move(Arg.BlockDispositions)),
13674 SCEVUsers(std::move(Arg.SCEVUsers)),
13675 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13676 SignedRanges(std::move(Arg.SignedRanges)),
13677 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13678 UniquePreds(std::move(Arg.UniquePreds)),
13679 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13680 LoopUsers(std::move(Arg.LoopUsers)),
13681 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13682 FirstUnknown(Arg.FirstUnknown) {
13683 Arg.FirstUnknown = nullptr;
13684}
13685
13687 // Iterate through all the SCEVUnknown instances and call their
13688 // destructors, so that they release their references to their values.
13689 for (SCEVUnknown *U = FirstUnknown; U;) {
13690 SCEVUnknown *Tmp = U;
13691 U = U->Next;
13692 Tmp->~SCEVUnknown();
13693 }
13694 FirstUnknown = nullptr;
13695
13696 ExprValueMap.clear();
13697 ValueExprMap.clear();
13698 HasRecMap.clear();
13699 BackedgeTakenCounts.clear();
13700 PredicatedBackedgeTakenCounts.clear();
13701
13702 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13703 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13704 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13705 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13706 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13707}
13708
13710 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13711}
13712
13713/// When printing a top-level SCEV for trip counts, it's helpful to include
13714/// a type for constants which are otherwise hard to disambiguate.
13715static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13716 if (isa<SCEVConstant>(S))
13717 OS << *S->getType() << " ";
13718 OS << *S;
13719}
13720
13722 const Loop *L) {
13723 // Print all inner loops first
13724 for (Loop *I : *L)
13725 PrintLoopInfo(OS, SE, I);
13726
13727 OS << "Loop ";
13728 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13729 OS << ": ";
13730
13731 SmallVector<BasicBlock *, 8> ExitingBlocks;
13732 L->getExitingBlocks(ExitingBlocks);
13733 if (ExitingBlocks.size() != 1)
13734 OS << "<multiple exits> ";
13735
13736 auto *BTC = SE->getBackedgeTakenCount(L);
13737 if (!isa<SCEVCouldNotCompute>(BTC)) {
13738 OS << "backedge-taken count is ";
13740 } else
13741 OS << "Unpredictable backedge-taken count.";
13742 OS << "\n";
13743
13744 if (ExitingBlocks.size() > 1)
13745 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13746 OS << " exit count for " << ExitingBlock->getName() << ": ";
13747 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13749 if (isa<SCEVCouldNotCompute>(EC)) {
13750 // Retry with predicates.
13752 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13753 if (!isa<SCEVCouldNotCompute>(EC)) {
13754 OS << "\n predicated exit count for " << ExitingBlock->getName()
13755 << ": ";
13757 OS << "\n Predicates:\n";
13758 for (const auto *P : Predicates)
13759 P->print(OS, 4);
13760 }
13761 }
13762 OS << "\n";
13763 }
13764
13765 OS << "Loop ";
13766 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13767 OS << ": ";
13768
13769 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13770 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13771 OS << "constant max backedge-taken count is ";
13772 PrintSCEVWithTypeHint(OS, ConstantBTC);
13774 OS << ", actual taken count either this or zero.";
13775 } else {
13776 OS << "Unpredictable constant max backedge-taken count. ";
13777 }
13778
13779 OS << "\n"
13780 "Loop ";
13781 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13782 OS << ": ";
13783
13784 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13785 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13786 OS << "symbolic max backedge-taken count is ";
13787 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13789 OS << ", actual taken count either this or zero.";
13790 } else {
13791 OS << "Unpredictable symbolic max backedge-taken count. ";
13792 }
13793 OS << "\n";
13794
13795 if (ExitingBlocks.size() > 1)
13796 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13797 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13798 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13800 PrintSCEVWithTypeHint(OS, ExitBTC);
13801 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13802 // Retry with predicates.
13804 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13806 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13807 OS << "\n predicated symbolic max exit count for "
13808 << ExitingBlock->getName() << ": ";
13809 PrintSCEVWithTypeHint(OS, ExitBTC);
13810 OS << "\n Predicates:\n";
13811 for (const auto *P : Predicates)
13812 P->print(OS, 4);
13813 }
13814 }
13815 OS << "\n";
13816 }
13817
13819 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13820 if (PBT != BTC) {
13821 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13822 OS << "Loop ";
13823 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13824 OS << ": ";
13825 if (!isa<SCEVCouldNotCompute>(PBT)) {
13826 OS << "Predicated backedge-taken count is ";
13828 } else
13829 OS << "Unpredictable predicated backedge-taken count.";
13830 OS << "\n";
13831 OS << " Predicates:\n";
13832 for (const auto *P : Preds)
13833 P->print(OS, 4);
13834 }
13835 Preds.clear();
13836
13837 auto *PredConstantMax =
13839 if (PredConstantMax != ConstantBTC) {
13840 assert(!Preds.empty() &&
13841 "different predicated constant max BTC but no predicates");
13842 OS << "Loop ";
13843 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13844 OS << ": ";
13845 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13846 OS << "Predicated constant max backedge-taken count is ";
13847 PrintSCEVWithTypeHint(OS, PredConstantMax);
13848 } else
13849 OS << "Unpredictable predicated constant max backedge-taken count.";
13850 OS << "\n";
13851 OS << " Predicates:\n";
13852 for (const auto *P : Preds)
13853 P->print(OS, 4);
13854 }
13855 Preds.clear();
13856
13857 auto *PredSymbolicMax =
13859 if (SymbolicBTC != PredSymbolicMax) {
13860 assert(!Preds.empty() &&
13861 "Different predicated symbolic max BTC, but no predicates");
13862 OS << "Loop ";
13863 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13864 OS << ": ";
13865 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13866 OS << "Predicated symbolic max backedge-taken count is ";
13867 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13868 } else
13869 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13870 OS << "\n";
13871 OS << " Predicates:\n";
13872 for (const auto *P : Preds)
13873 P->print(OS, 4);
13874 }
13875
13877 OS << "Loop ";
13878 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13879 OS << ": ";
13880 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13881 }
13882}
13883
13884namespace llvm {
13886 switch (LD) {
13888 OS << "Variant";
13889 break;
13891 OS << "Invariant";
13892 break;
13894 OS << "Computable";
13895 break;
13896 }
13897 return OS;
13898}
13899
13901 switch (BD) {
13903 OS << "DoesNotDominate";
13904 break;
13906 OS << "Dominates";
13907 break;
13909 OS << "ProperlyDominates";
13910 break;
13911 }
13912 return OS;
13913}
13914} // namespace llvm
13915
13917 // ScalarEvolution's implementation of the print method is to print
13918 // out SCEV values of all instructions that are interesting. Doing
13919 // this potentially causes it to create new SCEV objects though,
13920 // which technically conflicts with the const qualifier. This isn't
13921 // observable from outside the class though, so casting away the
13922 // const isn't dangerous.
13923 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13924
13925 if (ClassifyExpressions) {
13926 OS << "Classifying expressions for: ";
13927 F.printAsOperand(OS, /*PrintType=*/false);
13928 OS << "\n";
13929 for (Instruction &I : instructions(F))
13930 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13931 OS << I << '\n';
13932 OS << " --> ";
13933 const SCEV *SV = SE.getSCEV(&I);
13934 SV->print(OS);
13935 if (!isa<SCEVCouldNotCompute>(SV)) {
13936 OS << " U: ";
13937 SE.getUnsignedRange(SV).print(OS);
13938 OS << " S: ";
13939 SE.getSignedRange(SV).print(OS);
13940 }
13941
13942 const Loop *L = LI.getLoopFor(I.getParent());
13943
13944 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13945 if (AtUse != SV) {
13946 OS << " --> ";
13947 AtUse->print(OS);
13948 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13949 OS << " U: ";
13950 SE.getUnsignedRange(AtUse).print(OS);
13951 OS << " S: ";
13952 SE.getSignedRange(AtUse).print(OS);
13953 }
13954 }
13955
13956 if (L) {
13957 OS << "\t\t" "Exits: ";
13958 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13959 if (!SE.isLoopInvariant(ExitValue, L)) {
13960 OS << "<<Unknown>>";
13961 } else {
13962 OS << *ExitValue;
13963 }
13964
13965 bool First = true;
13966 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13967 if (First) {
13968 OS << "\t\t" "LoopDispositions: { ";
13969 First = false;
13970 } else {
13971 OS << ", ";
13972 }
13973
13974 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13975 OS << ": " << SE.getLoopDisposition(SV, Iter);
13976 }
13977
13978 for (const auto *InnerL : depth_first(L)) {
13979 if (InnerL == L)
13980 continue;
13981 if (First) {
13982 OS << "\t\t" "LoopDispositions: { ";
13983 First = false;
13984 } else {
13985 OS << ", ";
13986 }
13987
13988 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13989 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13990 }
13991
13992 OS << " }";
13993 }
13994
13995 OS << "\n";
13996 }
13997 }
13998
13999 OS << "Determining loop execution counts for: ";
14000 F.printAsOperand(OS, /*PrintType=*/false);
14001 OS << "\n";
14002 for (Loop *I : LI)
14003 PrintLoopInfo(OS, &SE, I);
14004}
14005
14008 auto &Values = LoopDispositions[S];
14009 for (auto &V : Values) {
14010 if (V.getPointer() == L)
14011 return V.getInt();
14012 }
14013 Values.emplace_back(L, LoopVariant);
14014 LoopDisposition D = computeLoopDisposition(S, L);
14015 auto &Values2 = LoopDispositions[S];
14016 for (auto &V : llvm::reverse(Values2)) {
14017 if (V.getPointer() == L) {
14018 V.setInt(D);
14019 break;
14020 }
14021 }
14022 return D;
14023}
14024
14026ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
// Classifies S relative to loop L by switching on the kind of S and returns
// one of LoopInvariant, LoopComputable or LoopVariant. L may be null; the
// AddRec and Unknown cases below treat a null loop as the function body.
14027 switch (S->getSCEVType()) {
14028 case scConstant:
14029 case scVScale:
14030 return LoopInvariant;
14031 case scAddRecExpr: {
14032 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14033
14034 // If L is the addrec's loop, it's computable.
14035 if (AR->getLoop() == L)
14036 return LoopComputable;
14037
14038 // Add recurrences are never invariant in the function-body (null loop).
14039 if (!L)
14040 return LoopVariant;
14041
14042 // Everything that is not defined at loop entry is variant.
14043 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14044 return LoopVariant;
14045 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14046 " dominate the contained loop's header?");
14047
14048 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14049 if (AR->getLoop()->contains(L))
14050 return LoopInvariant;
14051
14052 // This recurrence is variant w.r.t. L if any of its operands
14053 // are variant.
14054 for (const auto *Op : AR->operands())
14055 if (!isLoopInvariant(Op, L))
14056 return LoopVariant;
14057
14058 // Otherwise it's loop-invariant.
14059 return LoopInvariant;
14060 }
// Casts and n-ary expressions: the result is derived from the operands.
14061 case scTruncate:
14062 case scZeroExtend:
14063 case scSignExtend:
14064 case scPtrToInt:
14065 case scAddExpr:
14066 case scMulExpr:
14067 case scUDivExpr:
14068 case scUMaxExpr:
14069 case scSMaxExpr:
14070 case scUMinExpr:
14071 case scSMinExpr:
14072 case scSequentialUMinExpr: {
14073 bool HasVarying = false;
14074 for (const auto *Op : S->operands()) {
// NOTE(review): the declaration of D is not among the visible lines;
// presumably it is the disposition of Op with respect to L -- confirm.
// A single variant operand makes the whole expression variant; a computable
// operand only downgrades the result from invariant to computable.
14076 if (D == LoopVariant)
14077 return LoopVariant;
14078 if (D == LoopComputable)
14079 HasVarying = true;
14080 }
14081 return HasVarying ? LoopComputable : LoopInvariant;
14082 }
14083 case scUnknown:
14084 // All non-instruction values are loop invariant. All instructions are loop
14085 // invariant if they are not contained in the specified loop.
14086 // Instructions are never considered invariant in the function body
14087 // (null loop) because they are defined within the "loop".
14088 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14089 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14090 return LoopInvariant;
14091 case scCouldNotCompute:
14092 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14093 }
// Reached only if the switch above did not return for the kind of S.
14094 llvm_unreachable("Unknown SCEV kind!");
14095}
14096
14098 return getLoopDisposition(S, L) == LoopInvariant;
14099}
14100
14102 return getLoopDisposition(S, L) == LoopComputable;
14103}
14104
14107 auto &Values = BlockDispositions[S];
14108 for (auto &V : Values) {
14109 if (V.getPointer() == BB)
14110 return V.getInt();
14111 }
14112 Values.emplace_back(BB, DoesNotDominateBlock);
14113 BlockDisposition D = computeBlockDisposition(S, BB);
14114 auto &Values2 = BlockDispositions[S];
14115 for (auto &V : llvm::reverse(Values2)) {
14116 if (V.getPointer() == BB) {
14117 V.setInt(D);
14118 break;
14119 }
14120 }
14121 return D;
14122}
14123
14125ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
// Classifies S relative to block BB by switching on the kind of S. The
// visible returns produce DoesNotDominateBlock, DominatesBlock or
// ProperlyDominatesBlock.
14126 switch (S->getSCEVType()) {
14127 case scConstant:
14128 case scVScale:
// NOTE(review): no statement for the constant/vscale cases is visible here
// (the embedded numbering skips a line) -- confirm what they return.
14130 case scAddRecExpr: {
14131 // This uses a "dominates" query instead of "properly dominates" query
14132 // to test for proper dominance too, because the instruction which
14133 // produces the addrec's value is a PHI, and a PHI effectively properly
14134 // dominates its entire containing block.
14135 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14136 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14137 return DoesNotDominateBlock;
14138
14139 // Fall through into SCEVNAryExpr handling.
14140 [[fallthrough]];
14141 }
// Casts and n-ary expressions (and addrecs that passed the check above):
// the result is derived from the operands.
14142 case scTruncate:
14143 case scZeroExtend:
14144 case scSignExtend:
14145 case scPtrToInt:
14146 case scAddExpr:
14147 case scMulExpr:
14148 case scUDivExpr:
14149 case scUMaxExpr:
14150 case scSMaxExpr:
14151 case scUMinExpr:
14152 case scSMinExpr:
14153 case scSequentialUMinExpr: {
14154 bool Proper = true;
14155 for (const SCEV *NAryOp : S->operands()) {
// NOTE(review): the declaration of D is not among the visible lines;
// presumably it is the disposition of NAryOp with respect to BB -- confirm.
// One non-dominating operand decides the result immediately; an operand that
// merely dominates prevents the result from being "properly dominates".
14157 if (D == DoesNotDominateBlock)
14158 return DoesNotDominateBlock;
14159 if (D == DominatesBlock)
14160 Proper = false;
14161 }
14162 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14163 }
14164 case scUnknown:
// For an instruction, the answer depends on where its parent block sits
// relative to BB: the same block yields DominatesBlock.
14165 if (Instruction *I =
14166 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14167 if (I->getParent() == BB)
14168 return DominatesBlock;
14169 if (DT.properlyDominates(I->getParent(), BB))
// NOTE(review): the statement controlled by the properlyDominates check is
// not visible (numbering gap); as listed, the return below would be its
// body -- confirm against the complete file.
14171 return DoesNotDominateBlock;
14172 }
// NOTE(review): the result for non-instruction values is not visible here
// (numbering gap) -- confirm.
14174 case scCouldNotCompute:
14175 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14176 }
// Reached only if the switch above did not return for the kind of S.
14177 llvm_unreachable("Unknown SCEV kind!");
14178}
14179
14180bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14181 return getBlockDisposition(S, BB) >= DominatesBlock;
14182}
14183
14186}
14187
14188bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14189 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14190}
14191
14192void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14193 bool Predicated) {
14194 auto &BECounts =
14195 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14196 auto It = BECounts.find(L);
14197 if (It != BECounts.end()) {
14198 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14199 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14200 if (!isa<SCEVConstant>(S)) {
14201 auto UserIt = BECountUsers.find(S);
14202 assert(UserIt != BECountUsers.end());
14203 UserIt->second.erase({L, Predicated});
14204 }
14205 }
14206 }
14207 BECounts.erase(It);
14208 }
14209}
14210
14211void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14212 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14213 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14214
14215 while (!Worklist.empty()) {
14216 const SCEV *Curr = Worklist.pop_back_val();
14217 auto Users = SCEVUsers.find(Curr);
14218 if (Users != SCEVUsers.end())
14219 for (const auto *User : Users->second)
14220 if (ToForget.insert(User).second)
14221 Worklist.push_back(User);
14222 }
14223
14224 for (const auto *S : ToForget)
14225 forgetMemoizedResultsImpl(S);
14226
14227 for (auto I = PredicatedSCEVRewrites.begin();
14228 I != PredicatedSCEVRewrites.end();) {
14229 std::pair<const SCEV *, const Loop *> Entry = I->first;
14230 if (ToForget.count(Entry.first))
14231 PredicatedSCEVRewrites.erase(I++);
14232 else
14233 ++I;
14234 }
14235}
14236
14237void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14238 LoopDispositions.erase(S);
14239 BlockDispositions.erase(S);
14240 UnsignedRanges.erase(S);
14241 SignedRanges.erase(S);
14242 HasRecMap.erase(S);
14243 ConstantMultipleCache.erase(S);
14244
14245 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14246 UnsignedWrapViaInductionTried.erase(AR);
14247 SignedWrapViaInductionTried.erase(AR);
14248 }
14249
14250 auto ExprIt = ExprValueMap.find(S);
14251 if (ExprIt != ExprValueMap.end()) {
14252 for (Value *V : ExprIt->second) {
14253 auto ValueIt = ValueExprMap.find_as(V);
14254 if (ValueIt != ValueExprMap.end())
14255 ValueExprMap.erase(ValueIt);
14256 }
14257 ExprValueMap.erase(ExprIt);
14258 }
14259
14260 auto ScopeIt = ValuesAtScopes.find(S);
14261 if (ScopeIt != ValuesAtScopes.end()) {
14262 for (const auto &Pair : ScopeIt->second)
14263 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14264 llvm::erase(ValuesAtScopesUsers[Pair.second],
14265 std::make_pair(Pair.first, S));
14266 ValuesAtScopes.erase(ScopeIt);
14267 }
14268
14269 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14270 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14271 for (const auto &Pair : ScopeUserIt->second)
14272 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14273 ValuesAtScopesUsers.erase(ScopeUserIt);
14274 }
14275
14276 auto BEUsersIt = BECountUsers.find(S);
14277 if (BEUsersIt != BECountUsers.end()) {
14278 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14279 auto Copy = BEUsersIt->second;
14280 for (const auto &Pair : Copy)
14281 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14282 BECountUsers.erase(BEUsersIt);
14283 }
14284
14285 auto FoldUser = FoldCacheUser.find(S);
14286 if (FoldUser != FoldCacheUser.end())
14287 for (auto &KV : FoldUser->second)
14288 FoldCache.erase(KV);
14289 FoldCacheUser.erase(S);
14290}
14291
/// Records into \p LoopsUsed the loop of each SCEVAddRecExpr handed to the
/// local FindUsedLoops visitor.
/// NOTE(review): the statement that runs the visitor over \p S is not among
/// the visible lines -- confirm against the complete file.
14292void
14293ScalarEvolution::getUsedLoops(const SCEV *S,
14294 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
// Visitor that collects the loops of the add recurrences it is shown.
14295 struct FindUsedLoops {
14296 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14297 : LoopsUsed(LoopsUsed) {}
// NOTE(review): the declaration of the LoopsUsed member initialised above is
// not visible here (numbering gap).
// Always returns true, after recording the loop when S is an add recurrence.
14299 bool follow(const SCEV *S) {
14300 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14301 LoopsUsed.insert(AR->getLoop());
14302 return true;
14303 }
14304
// Always returns false: this visitor never reports that it is finished.
14305 bool isDone() const { return false; }
14306 };
14307
14308 FindUsedLoops F(LoopsUsed);
14310}
14311
/// Worklist walk starting at the entry block of F that inserts every visited
/// block into Reachable. For a conditional branch whose condition is a
/// ConstantInt, or an icmp decided by isKnownPredicateViaConstantRanges, only
/// the taken successor is followed; otherwise all successors are followed.
/// NOTE(review): the parameter list and the declaration of Worklist are not
/// among the visible lines -- confirm against the complete file.
14312void ScalarEvolution::getReachableBlocks(
14315 Worklist.push_back(&F.getEntryBlock());
14316 while (!Worklist.empty()) {
14317 BasicBlock *BB = Worklist.pop_back_val();
// Skip blocks that were already recorded.
14318 if (!Reachable.insert(BB).second)
14319 continue;
14320
14321 Value *Cond;
14322 BasicBlock *TrueBB, *FalseBB;
14323 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14324 m_BasicBlock(FalseBB)))) {
// Constant condition: only one successor can be taken.
14325 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14326 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14327 continue;
14328 }
14329
// Integer compare: follow a single successor when the predicate, or its
// inverse, is known from constant ranges of the operands' SCEVs.
14330 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14331 const SCEV *L = getSCEV(Cmp->getOperand(0));
14332 const SCEV *R = getSCEV(Cmp->getOperand(1));
14333 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14334 Worklist.push_back(TrueBB);
14335 continue;
14336 }
14337 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14338 R)) {
14339 Worklist.push_back(FalseBB);
14340 continue;
14341 }
14342 }
14343 }
14344
// Nothing could be decided: every successor is potentially reachable.
14345 append_range(Worklist, successors(BB));
14346 }
14347}
14348
14350 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14351 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14352
14353 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14354
14355 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14356 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14357 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14358
14359 const SCEV *visitConstant(const SCEVConstant *Constant) {
14360 return SE.getConstant(Constant->getAPInt());
14361 }
14362
14363 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14364 return SE.getUnknown(Expr->getValue());
14365 }
14366
14367 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14368 return SE.getCouldNotCompute();
14369 }
14370 };
14371
14372 SCEVMapper SCM(SE2);
14373 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14374 SE2.getReachableBlocks(ReachableBlocks, F);
14375
14376 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14377 if (containsUndefs(Old) || containsUndefs(New)) {
14378 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14379 // not propagate undef aggressively). This means we can (and do) fail
14380 // verification in cases where a transform makes a value go from "undef"
14381 // to "undef+1" (say). The transform is fine, since in both cases the
14382 // result is "undef", but SCEV thinks the value increased by 1.
14383 return nullptr;
14384 }
14385
14386 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14387 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14388 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14389 return nullptr;
14390
14391 return Delta;
14392 };
14393
14394 while (!LoopStack.empty()) {
14395 auto *L = LoopStack.pop_back_val();
14396 llvm::append_range(LoopStack, *L);
14397
14398 // Only verify BECounts in reachable loops. For an unreachable loop,
14399 // any BECount is legal.
14400 if (!ReachableBlocks.contains(L->getHeader()))
14401 continue;
14402
14403 // Only verify cached BECounts. Computing new BECounts may change the
14404 // results of subsequent SCEV uses.
14405 auto It = BackedgeTakenCounts.find(L);
14406 if (It == BackedgeTakenCounts.end())
14407 continue;
14408
14409 auto *CurBECount =
14410 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14411 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14412
14413 if (CurBECount == SE2.getCouldNotCompute() ||
14414 NewBECount == SE2.getCouldNotCompute()) {
14415 // NB! This situation is legal, but is very suspicious -- whatever pass
14416 // changed the loop to make a trip count go from could not compute to
14417 // computable or vice-versa *should have* invalidated SCEV. However, we
14418 // choose not to assert here (for now) since we don't want false
14419 // positives.
14420 continue;
14421 }
14422
14423 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14424 SE.getTypeSizeInBits(NewBECount->getType()))
14425 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14426 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14427 SE.getTypeSizeInBits(NewBECount->getType()))
14428 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14429
14430 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14431 if (Delta && !Delta->isZero()) {
14432 dbgs() << "Trip Count for " << *L << " Changed!\n";
14433 dbgs() << "Old: " << *CurBECount << "\n";
14434 dbgs() << "New: " << *NewBECount << "\n";
14435 dbgs() << "Delta: " << *Delta << "\n";
14436 std::abort();
14437 }
14438 }
14439
14440 // Collect all valid loops currently in LoopInfo.
14441 SmallPtrSet<Loop *, 32> ValidLoops;
14442 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14443 while (!Worklist.empty()) {
14444 Loop *L = Worklist.pop_back_val();
14445 if (ValidLoops.insert(L).second)
14446 Worklist.append(L->begin(), L->end());
14447 }
14448 for (const auto &KV : ValueExprMap) {
14449#ifndef NDEBUG
14450 // Check for SCEV expressions referencing invalid/deleted loops.
14451 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14452 assert(ValidLoops.contains(AR->getLoop()) &&
14453 "AddRec references invalid loop");
14454 }
14455#endif
14456
14457 // Check that the value is also part of the reverse map.
14458 auto It = ExprValueMap.find(KV.second);
14459 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14460 dbgs() << "Value " << *KV.first
14461 << " is in ValueExprMap but not in ExprValueMap\n";
14462 std::abort();
14463 }
14464
14465 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14466 if (!ReachableBlocks.contains(I->getParent()))
14467 continue;
14468 const SCEV *OldSCEV = SCM.visit(KV.second);
14469 const SCEV *NewSCEV = SE2.getSCEV(I);
14470 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14471 if (Delta && !Delta->isZero()) {
14472 dbgs() << "SCEV for value " << *I << " changed!\n"
14473 << "Old: " << *OldSCEV << "\n"
14474 << "New: " << *NewSCEV << "\n"
14475 << "Delta: " << *Delta << "\n";
14476 std::abort();
14477 }
14478 }
14479 }
14480
14481 for (const auto &KV : ExprValueMap) {
14482 for (Value *V : KV.second) {
14483 auto It = ValueExprMap.find_as(V);
14484 if (It == ValueExprMap.end()) {
14485 dbgs() << "Value " << *V
14486 << " is in ExprValueMap but not in ValueExprMap\n";
14487 std::abort();
14488 }
14489 if (It->second != KV.first) {
14490 dbgs() << "Value " << *V << " mapped to " << *It->second
14491 << " rather than " << *KV.first << "\n";
14492 std::abort();
14493 }
14494 }
14495 }
14496
14497 // Verify integrity of SCEV users.
14498 for (const auto &S : UniqueSCEVs) {
14499 for (const auto *Op : S.operands()) {
14500 // We do not store dependencies of constants.
14501 if (isa<SCEVConstant>(Op))
14502 continue;
14503 auto It = SCEVUsers.find(Op);
14504 if (It != SCEVUsers.end() && It->second.count(&S))
14505 continue;
14506 dbgs() << "Use of operand " << *Op << " by user " << S
14507 << " is not being tracked!\n";
14508 std::abort();
14509 }
14510 }
14511
14512 // Verify integrity of ValuesAtScopes users.
14513 for (const auto &ValueAndVec : ValuesAtScopes) {
14514 const SCEV *Value = ValueAndVec.first;
14515 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14516 const Loop *L = LoopAndValueAtScope.first;
14517 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14518 if (!isa<SCEVConstant>(ValueAtScope)) {
14519 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14520 if (It != ValuesAtScopesUsers.end() &&
14521 is_contained(It->second, std::make_pair(L, Value)))
14522 continue;
14523 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14524 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14525 std::abort();
14526 }
14527 }
14528 }
14529
14530 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14531 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14532 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14533 const Loop *L = LoopAndValue.first;
14534 const SCEV *Value = LoopAndValue.second;
14535 assert(!isa<SCEVConstant>(Value));
14536 auto It = ValuesAtScopes.find(Value);
14537 if (It != ValuesAtScopes.end() &&
14538 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14539 continue;
14540 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14541 << *ValueAtScope << " missing in ValuesAtScopes\n";
14542 std::abort();
14543 }
14544 }
14545
14546 // Verify integrity of BECountUsers.
14547 auto VerifyBECountUsers = [&](bool Predicated) {
14548 auto &BECounts =
14549 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14550 for (const auto &LoopAndBEInfo : BECounts) {
14551 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14552 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14553 if (!isa<SCEVConstant>(S)) {
14554 auto UserIt = BECountUsers.find(S);
14555 if (UserIt != BECountUsers.end() &&
14556 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14557 continue;
14558 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14559 << " missing from BECountUsers\n";
14560 std::abort();
14561 }
14562 }
14563 }
14564 }
14565 };
14566 VerifyBECountUsers(/* Predicated */ false);
14567 VerifyBECountUsers(/* Predicated */ true);
14568
14569 // Verify integrity of loop disposition cache.
14570 for (auto &[S, Values] : LoopDispositions) {
14571 for (auto [Loop, CachedDisposition] : Values) {
14572 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14573 if (CachedDisposition != RecomputedDisposition) {
14574 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14575 << " is incorrect: cached " << CachedDisposition << ", actual "
14576 << RecomputedDisposition << "\n";
14577 std::abort();
14578 }
14579 }
14580 }
14581
14582 // Verify integrity of the block disposition cache.
14583 for (auto &[S, Values] : BlockDispositions) {
14584 for (auto [BB, CachedDisposition] : Values) {
14585 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14586 if (CachedDisposition != RecomputedDisposition) {
14587 dbgs() << "Cached disposition of " << *S << " for block %"
14588 << BB->getName() << " is incorrect: cached " << CachedDisposition
14589 << ", actual " << RecomputedDisposition << "\n";
14590 std::abort();
14591 }
14592 }
14593 }
14594
14595 // Verify FoldCache/FoldCacheUser caches.
14596 for (auto [FoldID, Expr] : FoldCache) {
14597 auto I = FoldCacheUser.find(Expr);
14598 if (I == FoldCacheUser.end()) {
14599 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14600 << "!\n";
14601 std::abort();
14602 }
14603 if (!is_contained(I->second, FoldID)) {
14604 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14605 std::abort();
14606 }
14607 }
14608 for (auto [Expr, IDs] : FoldCacheUser) {
14609 for (auto &FoldID : IDs) {
14610 auto I = FoldCache.find(FoldID);
14611 if (I == FoldCache.end()) {
14612 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14613 << "!\n";
14614 std::abort();
14615 }
14616 if (I->second != Expr) {
14617 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14618 << *I->second << " != " << *Expr << "!\n";
14619 std::abort();
14620 }
14621 }
14622 }
14623
14624 // Verify that ConstantMultipleCache computations are correct. We check that
14625 // cached multiples and recomputed multiples are multiples of each other to
14626 // verify correctness. It is possible that a recomputed multiple is different
14627 // from the cached multiple due to strengthened no wrap flags or changes in
14628 // KnownBits computations.
14629 for (auto [S, Multiple] : ConstantMultipleCache) {
14630 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14631 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14632 Multiple.urem(RecomputedMultiple) != 0 &&
14633 RecomputedMultiple.urem(Multiple) != 0)) {
14634 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14635 << *S << " : Computed " << RecomputedMultiple
14636 << " but cache contains " << Multiple << "!\n";
14637 std::abort();
14638 }
14639 }
14640}
14641
14643 Function &F, const PreservedAnalyses &PA,
14645 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14646 // of its dependencies is invalidated.
14647 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14648 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14649 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14651 Inv.invalidate<LoopAnalysis>(F, PA);
14652}
14653
14654AnalysisKey ScalarEvolutionAnalysis::Key;
14655
14658 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14659 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14660 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14661 auto &LI = AM.getResult<LoopAnalysis>(F);
14662 return ScalarEvolution(F, TLI, AC, DT, LI);
14663}
14664
14668 return PreservedAnalyses::all();
14669}
14670
14673 // For compatibility with opt's -analyze feature under legacy pass manager
14674 // which was not ported to NPM. This keeps tests using
14675 // update_analyze_test_checks.py working.
14676 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14677 << F.getName() << "':\n";
14679 return PreservedAnalyses::all();
14680}
14681
14683 "Scalar Evolution Analysis", false, true)
14689 "Scalar Evolution Analysis", false, true)
14690
14692
14695}
14696
14698 SE.reset(new ScalarEvolution(
14699 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14700 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14701 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14702 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14703 return false;
14704}
14705
14707
14709 SE->print(OS);
14710}
14711
14713 if (!VerifySCEV)
14714 return;
14715
14716 SE->verify();
14717}
14718
14720 AU.setPreservesAll();
14725}
14726
14728 const SCEV *RHS) {
14730}
14731
14732const SCEVPredicate *
14734 const SCEV *LHS, const SCEV *RHS) {
14736 assert(LHS->getType() == RHS->getType() &&
14737 "Type mismatch between LHS and RHS");
14738 // Unique this node based on the arguments
14739 ID.AddInteger(SCEVPredicate::P_Compare);
14740 ID.AddInteger(Pred);
14741 ID.AddPointer(LHS);
14742 ID.AddPointer(RHS);
14743 void *IP = nullptr;
14744 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14745 return S;
14746 SCEVComparePredicate *Eq = new (SCEVAllocator)
14747 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14748 UniquePreds.InsertNode(Eq, IP);
14749 return Eq;
14750}
14751
14753 const SCEVAddRecExpr *AR,
14756 // Unique this node based on the arguments
14757 ID.AddInteger(SCEVPredicate::P_Wrap);
14758 ID.AddPointer(AR);
14759 ID.AddInteger(AddedFlags);
14760 void *IP = nullptr;
14761 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14762 return S;
14763 auto *OF = new (SCEVAllocator)
14764 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14765 UniquePreds.InsertNode(OF, IP);
14766 return OF;
14767}
14768
14769namespace {
14770
// Rewrite visitor that overrides the handling of SCEVUnknown, zero-extend and
// sign-extend expressions. It substitutes equalities found in Pred and may
// turn extends of affine add recurrences on loop L into add recurrences,
// recording or checking the wrap assumptions this needs.
14771class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14772public:
14773
14774 /// Rewrites \p S in the context of a loop L and the SCEV predication
14775 /// infrastructure.
14776 ///
14777 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14778 /// equivalences present in \p Pred.
14779 ///
14780 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14781 /// \p NewPreds such that the result will be an AddRecExpr.
14782 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
// NOTE(review): the NewPreds parameter line is not visible here (numbering gap).
14784 const SCEVPredicate *Pred) {
14785 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14786 return Rewriter.visit(S);
14787 }
14788
// If Pred (or a member of a union Pred) is an ICMP_EQ compare predicate whose
// LHS is Expr, returns that predicate's RHS; otherwise defers to
// convertToAddRecWithPreds.
14789 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14790 if (Pred) {
14791 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14792 for (const auto *Pred : U->getPredicates())
14793 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14794 if (IPred->getLHS() == Expr &&
14795 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14796 return IPred->getRHS();
14797 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14798 if (IPred->getLHS() == Expr &&
14799 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14800 return IPred->getRHS();
14801 }
14802 }
14803 return convertToAddRecWithPreds(Expr);
14804 }
14805
// Rewrites the operand first. If it is an affine add recurrence on L and the
// IncrementNUSW assumption is accepted, the extend is pushed into the
// recurrence; otherwise a plain zero-extend of the rewritten operand is built.
14806 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14807 const SCEV *Operand = visit(Expr->getOperand());
14808 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14809 if (AR && AR->getLoop() == L && AR->isAffine()) {
14810 // This couldn't be folded because the operand didn't have the nuw
14811 // flag. Add the nusw flag as an assumption that we could make.
14812 const SCEV *Step = AR->getStepRecurrence(SE);
14813 Type *Ty = Expr->getType();
// Note that the start is zero-extended while the step is sign-extended.
// NOTE(review): presumably intentional for the NUSW assumption -- confirm.
14814 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14815 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14816 SE.getSignExtendExpr(Step, Ty), L,
14817 AR->getNoWrapFlags());
14818 }
14819 return SE.getZeroExtendExpr(Operand, Expr->getType());
14820 }
14821
// Signed counterpart of visitZeroExtendExpr, using the IncrementNSSW
// assumption and sign-extending both start and step.
14822 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14823 const SCEV *Operand = visit(Expr->getOperand());
14824 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14825 if (AR && AR->getLoop() == L && AR->isAffine()) {
14826 // This couldn't be folded because the operand didn't have the nsw
14827 // flag. Add the nssw flag as an assumption that we could make.
14828 const SCEV *Step = AR->getStepRecurrence(SE);
14829 Type *Ty = Expr->getType();
14830 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14831 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14832 SE.getSignExtendExpr(Step, Ty), L,
14833 AR->getNoWrapFlags());
14834 }
14835 return SE.getSignExtendExpr(Operand, Expr->getType());
14836 }
14837
14838private:
14839 explicit SCEVPredicateRewriter(
14840 const Loop *L, ScalarEvolution &SE,
// NOTE(review): the NewPreds parameter line is not visible here (numbering gap).
14842 const SCEVPredicate *Pred)
14843 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14844
// With no NewPreds list, an assumption is accepted only when Pred already
// implies it; with a list, P is appended and always accepted.
14845 bool addOverflowAssumption(const SCEVPredicate *P) {
14846 if (!NewPreds) {
14847 // Check if we've already made this assumption.
14848 return Pred && Pred->implies(P, SE);
14849 }
14850 NewPreds->push_back(P);
14851 return true;
14852 }
14853
// Convenience overload: builds the wrap predicate for AR with AddedFlags and
// forwards it to the overload above.
14854 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
// NOTE(review): the AddedFlags parameter line is not visible here (numbering gap).
14856 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14857 return addOverflowAssumption(A);
14858 }
14859
14860 // If \p Expr represents a PHINode, we try to see if it can be represented
14861 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14862 // to add this predicate as a runtime overflow check, we return the AddRec.
14863 // If \p Expr does not meet these conditions (is not a PHI node, or we
14864 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14865 // return \p Expr.
14866 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14867 if (!isa<PHINode>(Expr->getValue()))
14868 return Expr;
14869 std::optional<
14870 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14871 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14872 if (!PredicatedRewrite)
14873 return Expr;
// Every predicate required by the rewrite must be accepted; on the first
// rejection the original expression is returned.
14874 for (const auto *P : PredicatedRewrite->second){
14875 // Wrap predicates from outer loops are not supported.
14876 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14877 if (L != WP->getExpr()->getLoop())
14878 return Expr;
14879 }
14880 if (!addOverflowAssumption(P))
14881 return Expr;
14882 }
14883 return PredicatedRewrite->first;
14884 }
14885
// NOTE(review): the declaration of the NewPreds member is not visible here
// (numbering gap).
// Predicate whose equalities and implications the rewrite respects; may be null.
14887 const SCEVPredicate *Pred;
// Loop on which add recurrences must live for the extend rewrites to apply.
14888 const Loop *L;
14889};
14890
14891} // end anonymous namespace
14892
// Rewrites \p S by delegating to SCEVPredicateRewriter::rewrite with a null
// fourth argument and the address of \p Preds as the fifth.
// NOTE(review): the line carrying the function name and the S/L parameters is
// missing from this listing (embedded numbering jumps from 14893 to 14895).
14893const SCEV *
14895 const SCEVPredicate &Preds) {
14896 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14897}
14898
// Rewrites \p S with SCEVPredicateRewriter, collecting the assumptions made
// along the way in TransformPreds. Returns the result only if it is an AddRec,
// in which case the collected predicates are appended to Preds; otherwise
// returns nullptr and Preds is left untouched.
// NOTE(review): the lines with the function name and the declarations of
// Preds and TransformPreds are missing from this listing.
14900 const SCEV *S, const Loop *L,
14903 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14904 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14905
14906 if (!AddRec)
14907 return nullptr;
14908
14909 // Since the transformation was successful, we can now transfer the SCEV
14910 // predicates.
14911 Preds.append(TransformPreds.begin(), TransformPreds.end());
14912
14913 return AddRec;
14914}
14915
14916/// SCEV predicates
// Base constructor: stores the given ID in FastID and the predicate kind.
// NOTE(review): the first line of the signature is missing from this listing.
14918 SCEVPredicateKind Kind)
14919 : FastID(ID), Kind(Kind) {}
14920
// Constructs a compare predicate (kind P_Compare) relating LHS and RHS through
// Pred. Both operands must have the same type and must be distinct SCEVs; both
// conditions are asserted.
// NOTE(review): the first line of the signature is missing from this listing.
14922 const ICmpInst::Predicate Pred,
14923 const SCEV *LHS, const SCEV *RHS)
14924 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14925 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14926 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14927}
14928
// Returns true only when N is a SCEVComparePredicate, this predicate is an
// equality (ICMP_EQ), and N has exactly the same LHS and RHS SCEVs.
// NOTE(review): N's own comparison predicate (Op->Pred) is never inspected
// here, only this predicate's -- confirm that this is intended.
// NOTE(review): the first line of the signature is missing from this listing.
14930 ScalarEvolution &SE) const {
14931 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14932
14933 if (!Op)
14934 return false;
14935
14936 if (Pred != ICmpInst::ICMP_EQ)
14937 return false;
14938
14939 return Op->LHS == LHS && Op->RHS == RHS;
14940}
14941
14942bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14943
// Prints the predicate at the given indentation: equalities are printed as
// "Equal predicate: LHS == RHS", every other comparison as
// "Compare predicate: ..." with the predicate between the operands.
// NOTE(review): the non-equality format emits a ")" after the predicate that
// has no matching "(" -- looks like a leftover; confirm before changing since
// tests may match this output.
// NOTE(review): the function header is missing from this listing.
14945 if (Pred == ICmpInst::ICMP_EQ)
14946 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14947 else
14948 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14949 << *RHS << "\n";
14950
14951}
14952
// Constructs a wrap predicate (kind P_Wrap) that attaches the given increment
// wrap flags to the add recurrence \p AR.
// NOTE(review): the first line of the signature is missing from this listing.
14954 const SCEVAddRecExpr *AR,
14955 IncrementWrapFlags Flags)
14956 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14957
14958const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14959
// Returns true if this wrap predicate implies \p N.
// N must be a wrap predicate whose flags add nothing to this predicate's flags
// (setFlags(Flags, Op->Flags) == Flags). Predicates on the same AddRec then
// imply each other directly. For different AddRecs the visible code requires
// both steps to be known positive and compares N's step and start against this
// predicate's step and start after extending them to a common type.
// NOTE(review): several lines are missing from this listing: the first line of
// the signature, the second half of the flag check (14970) and the definition
// of the local `Pred` used in the final comparison (14994).
14961 ScalarEvolution &SE) const {
14962 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14963 if (!Op || setFlags(Flags, Op->Flags) != Flags)
14964 return false;
14965
// Same recurrence: nothing more to check.
14966 if (Op->AR == AR)
14967 return true;
14968
14969 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14971 return false;
14972
14973 const SCEV *Start = AR->getStart();
14974 const SCEV *OpStart = Op->AR->getStart();
// Bail out if exactly one of the two start values is a pointer.
14975 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
14976 return false;
14977
14978 const SCEV *Step = AR->getStepRecurrence(SE);
14979 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14980 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14981 return false;
14982
14983 // If both steps are positive, this implies N, if N's start and step are
14984 // ULE/SLE (for NUSW/NSSW) than this'.
14985 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
// Steps are known positive here, so they are zero-extended in both cases.
14986 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14987 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14988
// Start values are zero-extended for NUSW and sign-extended otherwise.
14989 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14990 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
14991 : SE.getNoopOrSignExtend(OpStart, WiderTy);
14992 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
14993 : SE.getNoopOrSignExtend(Start, WiderTy);
14995 return SE.isKnownPredicate(Pred, OpStep, Step) &&
14996 SE.isKnownPredicate(Pred, OpStart, Start);
14997}
14998
// Returns true when no wrap flag is left to be checked: IncrementNSSW is
// dropped from the requested flags if the AddRec already carries FlagNSW, and
// the predicate is trivially true if the remainder equals IncrementAnyWrap.
// NOTE(review): the function header is missing from this listing.
15000 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15001 IncrementWrapFlags IFlags = Flags;
15002
// Setting FlagNSW changes nothing, i.e. the AddRec already has it.
15003 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15004 IFlags = clearFlags(IFlags, IncrementNSSW);
15005
15006 return IFlags == IncrementAnyWrap;
15007}
15008
// Prints the AddRec followed by the wrap flags this predicate adds.
// NOTE(review): the function header and the conditions guarding the "<nusw>"
// and "<nssw>" outputs are missing from this listing (lines 15009, 15011 and
// 15013 of the embedded numbering).
15010 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15012 OS << "<nusw>";
15014 OS << "<nssw>";
15015 OS << "\n";
15016}
15017
// Computes the increment wrap flags that already follow from the static
// no-wrap flags of \p AR: FlagNSW yields IncrementNSSW, and FlagNUW yields
// IncrementNUSW when the step is a constant that is non-negative.
// NOTE(review): the first lines of the signature are missing from this
// listing.
15020 ScalarEvolution &SE) {
15021 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15022 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15023
15024 // We can safely transfer the NSW flag as NSSW.
15025 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15026 ImpliedFlags = IncrementNSSW;
15027
15028 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15029 // If the increment is positive, the SCEV NUW flag will also imply the
15030 // WrapPredicate NUSW flag.
// NOTE(review): the check below uses isNonNegative(), so a zero step is
// accepted as well.
15031 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15032 if (Step->getValue()->getValue().isNonNegative())
15033 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15034 }
15035
15036 return ImpliedFlags;
15037}
15038
15039/// Union predicates don't get cached so create a dummy set ID for it.
// Every predicate in Preds is inserted through add() rather than stored
// directly.
// NOTE(review): the first line of the signature is missing from this listing.
15041 ScalarEvolution &SE)
15042 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15043 for (const auto *P : Preds)
15044 add(P, SE);
15045}
15046
// The union is always true when every member predicate is always true.
// NOTE(review): the function header is missing from this listing.
15048 return all_of(Preds,
15049 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15050}
15051
// Returns true if this union implies \p N. When N is itself a union, each of
// its members must be implied by this union; otherwise it suffices that one
// member of this union implies N.
// NOTE(review): the first line of the signature is missing from this listing.
15053 ScalarEvolution &SE) const {
15054 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15055 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15056 return this->implies(I, SE);
15057 });
15058
15059 return any_of(Preds,
15060 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15061}
15062
// Prints every member predicate in order, each at the same depth.
// NOTE(review): the function header is missing from this listing.
15064 for (const auto *Pred : Preds)
15065 Pred->print(OS, Depth);
15066}
15067
// Adds \p N to this union while keeping the set free of redundant members:
// a nested union is flattened by adding its members one by one, N is dropped
// if the union already implies it, and existing members that N implies are
// removed before N is appended.
15068void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15069 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15070 for (const auto *Pred : Set->Preds)
15071 add(Pred, SE);
15072 return;
15073 }
15074
15075 // Only add predicate if it is not already implied by this union predicate.
15076 if (implies(N, SE))
15077 return;
15078
15079 // Build a new vector containing the current predicates, except the ones that
15080 // are implied by the new predicate N.
// NOTE(review): the declaration of PrunedPreds (line 15081 of the embedded
// numbering) is missing from this listing.
15082 for (auto *P : Preds) {
15083 if (N->implies(P, SE))
15084 continue;
15085 PrunedPreds.push_back(P);
15086 }
15087 Preds = std::move(PrunedPreds);
15088 Preds.push_back(N);
15089}
15090
// Binds the wrapper to SE and L and starts out with a union predicate built
// from `Empty`.
// NOTE(review): the first line of the signature and the declaration of
// `Empty` are missing from this listing.
15092 Loop &L)
15093 : SE(SE), L(L) {
15095 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15096}
15097
// Records User in SCEVUsers for every operand in Ops that is not a
// SCEVConstant.
// NOTE(review): the function header is missing from this listing.
15100 for (const auto *Op : Ops)
15101 // We do not expect that forgetting cached data for SCEVConstants will ever
15102 // open any prospects for sharpening or introduce any correctness issues,
15103 // so we don't bother storing their dependencies.
15104 if (!isa<SCEVConstant>(Op))
15105 SCEVUsers[Op].insert(User);
15106}
15107
// Returns the SCEV of V rewritten under the current predicate. Results are
// cached in RewriteMap, keyed by the unpredicated SCEV and tagged with the
// generation they were computed in; a cached entry from an older generation is
// rewritten again (starting from the cached expression) and re-tagged.
// NOTE(review): the function header is missing from this listing.
15109 const SCEV *Expr = SE.getSCEV(V);
// operator[] creates an empty entry if none exists yet.
15110 RewriteEntry &Entry = RewriteMap[Expr];
15111
15112 // If we already have an entry and the version matches, return it.
15113 if (Entry.second && Generation == Entry.first)
15114 return Entry.second;
15115
15116 // We found an entry but it's stale. Rewrite the stale entry
15117 // according to the current predicate.
15118 if (Entry.second)
15119 Expr = Entry.second;
15120
15121 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15122 Entry = {Generation, NewSCEV};
15123
15124 return NewSCEV;
15125}
15126
// Lazily computes and caches the predicated backedge-taken count of L. The
// predicates produced by that computation are added to this object the first
// time the count is computed.
// NOTE(review): the function header and the declaration of the `Preds` used
// here (line 15129 of the embedded numbering) are missing from this listing;
// presumably it is a local container, not the Preds member.
15128 if (!BackedgeCount) {
15130 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15131 for (const auto *P : Preds)
15132 addPredicate(*P);
15133 }
15134 return BackedgeCount;
15135}
15136
// Lazily computes and caches SymbolicMaxBackedgeCount, adding the predicates
// produced by the computation to this object on first use.
// NOTE(review): the function header, the declaration of `Preds` and the
// right-hand side of the assignment (lines 15137, 15139 and 15141 of the
// embedded numbering) are missing from this listing.
15138 if (!SymbolicMaxBackedgeCount) {
15140 SymbolicMaxBackedgeCount =
15142 for (const auto *P : Preds)
15143 addPredicate(*P);
15144 }
15145 return SymbolicMaxBackedgeCount;
15146}
15147
// Lazily computes and caches the small constant maximum trip count of L,
// adding the predicates produced by the computation to this object on first
// use. The cached value is dereferenced on return.
// NOTE(review): the function header and the declaration of `Preds` (line
// 15150 of the embedded numbering) are missing from this listing.
15149 if (!SmallConstantMaxTripCount) {
15151 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15152 for (const auto *P : Preds)
15153 addPredicate(*P);
15154 }
15155 return *SmallConstantMaxTripCount;
15156}
15157
// Adds Pred to the predicate set unless the current set already implies it.
// A new union is built from the existing predicates plus Pred, and the
// generation is bumped so that cached rewrites are treated as stale.
// NOTE(review): the function header is missing from this listing.
15159 if (Preds->implies(&Pred, SE))
15160 return;
15161
15162 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15163 NewPreds.push_back(&Pred);
15164 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15165 updateGeneration();
15166}
15167
// Returns the current union predicate by reference.
// NOTE(review): the function header is missing from this listing.
15169 return *Preds;
15170}
15171
15172void PredicatedScalarEvolution::updateGeneration() {
15173 // If the generation number wrapped recompute everything.
15174 if (++Generation == 0) {
15175 for (auto &II : RewriteMap) {
15176 const SCEV *Rewritten = II.second.second;
15177 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15178 }
15179 }
15180}
15181
// Adds a wrap predicate for the AddRec of V with the requested flags, minus
// those already implied by the AddRec's static flags, and remembers the flags
// for V in FlagsMap (merged with any flags stored earlier).
// The cast<> requires the predicated SCEV of V to be a SCEVAddRecExpr.
// NOTE(review): the function header is missing from this listing.
15184 const SCEV *Expr = getSCEV(V);
15185 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15186
15187 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15188
15189 // Clear the statically implied flags.
15190 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15191 addPredicate(*SE.getWrapPredicate(AR, Flags));
15192
// If V already had an entry, combine the new flags with the stored ones.
15193 auto II = FlagsMap.insert({V, Flags});
15194 if (!II.second)
15195 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15196}
15197
// Checks the requested wrap flags for V against the flags implied by its
// AddRec and the flags recorded for V in FlagsMap.
// The cast<> requires the predicated SCEV of V to be a SCEVAddRecExpr.
// NOTE(review): the function header, the start of the statement that consumes
// getImpliedFlags (line 15203) and the return statement (line 15211) are
// missing from this listing, so the exact result cannot be read off here.
15200 const SCEV *Expr = getSCEV(V);
15201 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15202
15204 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15205
15206 auto II = FlagsMap.find(V);
15207
// Flags previously recorded for V no longer need to be checked.
15208 if (II != FlagsMap.end())
15209 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15210
15212}
15213
// Tries to turn the predicated SCEV of V into an AddRec, possibly under new
// predicates. On success the new predicates are added to this object, the
// result is cached in RewriteMap under V's unpredicated SCEV, and the AddRec
// is returned; on failure nullptr is returned and nothing is added.
// NOTE(review): the function header and the declaration of NewPreds (line
// 15216 of the embedded numbering) are missing from this listing.
15215 const SCEV *Expr = this->getSCEV(V);
15217 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15218
15219 if (!New)
15220 return nullptr;
15221
15222 for (const auto *P : NewPreds)
15223 addPredicate(*P);
15224
// Cache after adding the predicates so the entry carries the new generation.
15225 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15226 return New;
15227}
15228
// Copy constructor body: copies the rewrite cache, generation and backedge
// count, rebuilds an independent union predicate from Init's predicates, and
// copies FlagsMap entry by entry.
// NOTE(review): of the lazily cached counts only BackedgeCount appears in the
// visible initializer list -- confirm that the other cached counts are meant
// to be recomputed in the copy.
// NOTE(review): the first line of the signature is missing from this listing.
15231 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15232 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15233 SE)),
15234 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15235 for (auto I : Init.FlagsMap)
15236 FlagsMap.insert(I);
15237}
15238
// Prints, for every SCEVable instruction in the loop whose SCEV has a cached
// rewrite that differs from the original, the instruction, its SCEV and the
// rewritten SCEV.
// NOTE(review): the function header is missing from this listing.
15240 // For each block.
15241 for (auto *BB : L.getBlocks())
15242 for (auto &I : *BB) {
15243 if (!SE.isSCEVable(I.getType()))
15244 continue;
15245
15246 auto *Expr = SE.getSCEV(&I);
15247 auto II = RewriteMap.find(Expr);
15248
// Nothing cached for this expression.
15249 if (II == RewriteMap.end())
15250 continue;
15251
15252 // Don't print things that are not interesting.
15253 if (II->second.second == Expr)
15254 continue;
15255
15256 OS.indent(Depth) << "[PSE]" << I << ":\n";
15257 OS.indent(Depth + 2) << *Expr << "\n";
15258 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15259 }
15260}
15261
15262// Match the mathematical pattern A - (A / B) * B, where A and B can be
15263// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15264// for URem with constant power-of-2 second operands.
15265// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15266// 4, A / B becomes X / 8).
// Returns true on a match, with \p LHS set to the dividend and \p RHS to the
// divisor. Returns false for pointer-typed expressions and for anything that
// does not match; note that \p LHS may already have been overwritten when the
// zext/trunc path bails out on the size check.
15267bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15268 const SCEV *&RHS) {
15269 if (Expr->getType()->isPointerTy())
15270 return false;
15271
15272 // Try to match 'zext (trunc A to iB) to iY', which is used
15273 // for URem with constant power-of-2 second operands. Make sure the size of
15274 // the operand A matches the size of the whole expressions.
15275 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15276 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15277 LHS = Trunc->getOperand();
15278 // Bail out if the type of the LHS is larger than the type of the
15279 // expression for now.
15280 if (getTypeSizeInBits(LHS->getType()) >
15281 getTypeSizeInBits(Expr->getType()))
15282 return false;
15283 if (LHS->getType() != Expr->getType())
15284 LHS = getZeroExtendExpr(LHS, Expr->getType());
// NOTE(review): the first half of the statement below (line 15285 of the
// embedded numbering, presumably the assignment to RHS) is missing from this
// listing.
15286 << getTypeSizeInBits(Trunc->getType()));
15287 return true;
15288 }
// General form: a two-operand add whose first operand is a multiplication.
15289 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15290 if (Add == nullptr || Add->getNumOperands() != 2)
15291 return false;
15292
15293 const SCEV *A = Add->getOperand(1);
15294 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15295
15296 if (Mul == nullptr)
15297 return false;
15298
// Verifies a candidate divisor B by rebuilding A urem B and comparing it with
// Expr by pointer; LHS/RHS are only written on success.
15299 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15300 // (SomeExpr + (-(SomeExpr / B) * B)).
15301 if (Expr == getURemExpr(A, B)) {
15302 LHS = A;
15303 RHS = B;
15304 return true;
15305 }
15306 return false;
15307 };
15308
15309 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15310 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15311 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15312 MatchURemWithDivisor(Mul->getOperand(2));
15313
15314 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15315 if (Mul->getNumOperands() == 2)
15316 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15317 MatchURemWithDivisor(Mul->getOperand(0)) ||
15318 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15319 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15320 return false;
15321}
15322
// Collects the guards that hold on entry to L, starting from the loop
// predecessor and the header. Returns an empty LoopGuards object if the loop
// has no single predecessor block.
// NOTE(review): the function header and the declaration of VisitedBlocks
// (line 15330 of the embedded numbering) are missing from this listing.
15325 BasicBlock *Header = L->getHeader();
15326 BasicBlock *Pred = L->getLoopPredecessor();
15327 LoopGuards Guards(SE);
15328 if (!Pred)
15329 return Guards;
15331 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15332 return Guards;
15333}
15334
// Tries to derive a guard for \p Phi from the guards that hold along each of
// its incoming edges. If every incoming value is rewritten to a min/max
// expression of the same kind with a constant first operand, the constants are
// merged and a rewrite Phi -> minmax(MergedConstant, Phi) is recorded in
// Guards.RewriteMap.
// NOTE(review): the signature lines declaring SE, Guards and IncomingGuards
// (lines 15336 and 15338 of the embedded numbering) are missing from this
// listing.
15335void ScalarEvolution::LoopGuards::collectFromPHI(
15337 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15339 unsigned Depth) {
15340 if (!SE.isSCEVable(Phi.getType()))
15341 return;
15342
15343 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
// Returns the constant and min/max kind that the guards of one incoming block
// establish for the corresponding incoming value, or {nullptr,
// scCouldNotCompute} if the block was already visited or no such pattern
// exists. Guards per incoming block are collected once and cached in
// IncomingGuards.
15344 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15345 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15346 if (!VisitedBlocks.insert(InBlock).second)
15347 return {nullptr, scCouldNotCompute};
15348 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15349 if (Inserted)
15350 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15351 Depth + 1);
15352 auto &RewriteMap = G->second.RewriteMap;
15353 if (RewriteMap.empty())
15354 return {nullptr, scCouldNotCompute};
15355 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15356 if (S == RewriteMap.end())
15357 return {nullptr, scCouldNotCompute};
15358 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15359 if (!SM)
15360 return {nullptr, scCouldNotCompute};
// Only operand 0 is inspected for the constant.
15361 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15362 return {C0, SM->getSCEVType()};
15363 return {nullptr, scCouldNotCompute};
15364 };
// Merges two patterns of the same kind: the smaller constant is kept for the
// max kinds and the larger one for the min kinds. Patterns of different kinds
// or with a missing constant cannot be merged.
15365 auto MergeMinMaxConst = [](MinMaxPattern P1,
15366 MinMaxPattern P2) -> MinMaxPattern {
15367 auto [C1, T1] = P1;
15368 auto [C2, T2] = P2;
15369 if (!C1 || !C2 || T1 != T2)
15370 return {nullptr, scCouldNotCompute};
15371 switch (T1) {
15372 case scUMaxExpr:
15373 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15374 case scSMaxExpr:
15375 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15376 case scUMinExpr:
15377 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15378 case scSMinExpr:
15379 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15380 default:
15381 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15382 }
15383 };
// Fold the patterns of all incoming edges, stopping at the first failure.
15384 auto P = GetMinMaxConst(0);
15385 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15386 if (!P.first)
15387 break;
15388 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15389 }
15390 if (P.first) {
15391 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15392 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15393 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15394 Guards.RewriteMap.insert({LHS, RHS});
15395 }
15396}
15397
15398void ScalarEvolution::LoopGuards::collectFromBlock(
15400 const BasicBlock *Block, const BasicBlock *Pred,
15401 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15402 SmallVector<const SCEV *> ExprsToRewrite;
15403 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15404 const SCEV *RHS,
15406 &RewriteMap) {
15407 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15408 // replacement SCEV which isn't directly implied by the structure of that
15409 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15410 // legal. See the scoping rules for flags in the header to understand why.
15411
15412 // If LHS is a constant, apply information to the other expression.
15413 if (isa<SCEVConstant>(LHS)) {
15414 std::swap(LHS, RHS);
15416 }
15417
15418 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15419 // create this form when combining two checks of the form (X u< C2 + C1) and
15420 // (X >=u C1).
15421 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15422 &ExprsToRewrite]() {
15423 const SCEVConstant *C1;
15424 const SCEVUnknown *LHSUnknown;
15425 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15426 if (!match(LHS,
15427 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15428 !C2)
15429 return false;
15430
15431 auto ExactRegion =
15433 .sub(C1->getAPInt());
15434
15435 // Bail out, unless we have a non-wrapping, monotonic range.
15436 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15437 return false;
15438 auto I = RewriteMap.find(LHSUnknown);
15439 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15440 RewriteMap[LHSUnknown] = SE.getUMaxExpr(
15441 SE.getConstant(ExactRegion.getUnsignedMin()),
15442 SE.getUMinExpr(RewrittenLHS,
15443 SE.getConstant(ExactRegion.getUnsignedMax())));
15444 ExprsToRewrite.push_back(LHSUnknown);
15445 return true;
15446 };
15447 if (MatchRangeCheckIdiom())
15448 return;
15449
15450 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15451 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15452 // the non-constant operand and in \p LHS the constant operand.
15453 auto IsMinMaxSCEVWithNonNegativeConstant =
15454 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15455 const SCEV *&RHS) {
15456 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15457 if (MinMax->getNumOperands() != 2)
15458 return false;
15459 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15460 if (C->getAPInt().isNegative())
15461 return false;
15462 SCTy = MinMax->getSCEVType();
15463 LHS = MinMax->getOperand(0);
15464 RHS = MinMax->getOperand(1);
15465 return true;
15466 }
15467 }
15468 return false;
15469 };
15470
15471 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15472 // constant, and returns their APInt in ExprVal and in DivisorVal.
15473 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15474 APInt &ExprVal, APInt &DivisorVal) {
15475 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15476 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15477 if (!ConstExpr || !ConstDivisor)
15478 return false;
15479 ExprVal = ConstExpr->getAPInt();
15480 DivisorVal = ConstDivisor->getAPInt();
15481 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15482 };
15483
15484 // Return a new SCEV that rounds \p Expr up to the closest number that is
15485 // divisible by \p Divisor and greater than or equal to \p Expr.
15486 // For now, only handle constant Expr and Divisor.
15487 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15488 const SCEV *Divisor) {
15489 APInt ExprVal;
15490 APInt DivisorVal;
15491 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15492 return Expr;
15493 APInt Rem = ExprVal.urem(DivisorVal);
15494 if (!Rem.isZero())
15495 // return the SCEV: Expr + Divisor - Expr % Divisor
15496 return SE.getConstant(ExprVal + DivisorVal - Rem);
15497 return Expr;
15498 };
15499
15500 // Return a new SCEV that rounds \p Expr down to the closest number that is
15501 // divisible by \p Divisor and less than or equal to \p Expr.
15502 // For now, only handle constant Expr and Divisor.
15503 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15504 const SCEV *Divisor) {
15505 APInt ExprVal;
15506 APInt DivisorVal;
15507 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15508 return Expr;
15509 APInt Rem = ExprVal.urem(DivisorVal);
15510 // return the SCEV: Expr - Expr % Divisor
15511 return SE.getConstant(ExprVal - Rem);
15512 };
15513
15514 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15515 // recursively. This is done by aligning up/down the constant value to the
15516 // Divisor.
15517 std::function<const SCEV *(const SCEV *, const SCEV *)>
15518 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15519 const SCEV *Divisor) {
15520 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15521 SCEVTypes SCTy;
15522 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15523 MinMaxRHS))
15524 return MinMaxExpr;
15525 auto IsMin =
15526 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15527 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15528 "Expected non-negative operand!");
15529 auto *DivisibleExpr =
15530 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15531 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15533 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15534 return SE.getMinMaxExpr(SCTy, Ops);
15535 };
15536
15537 // If we have LHS == 0, check if LHS is computing a property of some unknown
15538 // SCEV %v which we can rewrite %v to express explicitly.
15539 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15540 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15541 // explicitly express that.
15542 const SCEV *URemLHS = nullptr;
15543 const SCEV *URemRHS = nullptr;
15544 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15545 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15546 auto I = RewriteMap.find(LHSUnknown);
15547 const SCEV *RewrittenLHS =
15548 I != RewriteMap.end() ? I->second : LHSUnknown;
15549 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15550 const auto *Multiple =
15551 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15552 RewriteMap[LHSUnknown] = Multiple;
15553 ExprsToRewrite.push_back(LHSUnknown);
15554 return;
15555 }
15556 }
15557 }
15558
15559 // Do not apply information for constants or if RHS contains an AddRec.
15560 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15561 return;
15562
15563 // If RHS is SCEVUnknown, make sure the information is applied to it.
15564 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15565 std::swap(LHS, RHS);
15567 }
15568
15569 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15570 // and \p FromRewritten are the same (i.e. there has been no rewrite
15571 // registered for \p From), then puts this value in the list of rewritten
15572 // expressions.
15573 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15574 const SCEV *To) {
15575 if (From == FromRewritten)
15576 ExprsToRewrite.push_back(From);
15577 RewriteMap[From] = To;
15578 };
15579
15580 // Checks whether \p S has already been rewritten. In that case returns the
15581 // existing rewrite because we want to chain further rewrites onto the
15582 // already rewritten value. Otherwise returns \p S.
15583 auto GetMaybeRewritten = [&](const SCEV *S) {
15584 auto I = RewriteMap.find(S);
15585 return I != RewriteMap.end() ? I->second : S;
15586 };
15587
15588 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15589 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15590 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15591 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15592 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15593 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15594 // DividesBy.
15595 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15596 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15597 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15598 if (Mul->getNumOperands() != 2)
15599 return false;
15600 auto *MulLHS = Mul->getOperand(0);
15601 auto *MulRHS = Mul->getOperand(1);
15602 if (isa<SCEVConstant>(MulLHS))
15603 std::swap(MulLHS, MulRHS);
15604 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15605 if (Div->getOperand(1) == MulRHS) {
15606 DividesBy = MulRHS;
15607 return true;
15608 }
15609 }
15610 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15611 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15612 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15613 return false;
15614 };
15615
15616 // Return true if \p Expr is known to be divisible by \p DividesBy.
15617 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15618 [&](const SCEV *Expr, const SCEV *DividesBy) {
15619 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15620 return true;
15621 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15622 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15623 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15624 return false;
15625 };
15626
15627 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15628 const SCEV *DividesBy = nullptr;
15629 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15630 // Check that the whole expression is divided by DividesBy
15631 DividesBy =
15632 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15633
15634 // Collect rewrites for LHS and its transitive operands based on the
15635 // condition.
15636 // For min/max expressions, also apply the guard to its operands:
15637 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15638 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15639 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15640 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15641
15642 // We cannot express strict predicates in SCEV, so instead we replace them
15643 // with non-strict ones against plus or minus one of RHS depending on the
15644 // predicate.
15645 const SCEV *One = SE.getOne(RHS->getType());
15646 switch (Predicate) {
15647 case CmpInst::ICMP_ULT:
15648 if (RHS->getType()->isPointerTy())
15649 return;
15650 RHS = SE.getUMaxExpr(RHS, One);
15651 [[fallthrough]];
15652 case CmpInst::ICMP_SLT: {
15653 RHS = SE.getMinusSCEV(RHS, One);
15654 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15655 break;
15656 }
15657 case CmpInst::ICMP_UGT:
15658 case CmpInst::ICMP_SGT:
15659 RHS = SE.getAddExpr(RHS, One);
15660 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15661 break;
15662 case CmpInst::ICMP_ULE:
15663 case CmpInst::ICMP_SLE:
15664 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15665 break;
15666 case CmpInst::ICMP_UGE:
15667 case CmpInst::ICMP_SGE:
15668 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15669 break;
15670 default:
15671 break;
15672 }
15673
15674 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15676
15677 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15678 append_range(Worklist, S->operands());
15679 };
15680
15681 while (!Worklist.empty()) {
15682 const SCEV *From = Worklist.pop_back_val();
15683 if (isa<SCEVConstant>(From))
15684 continue;
15685 if (!Visited.insert(From).second)
15686 continue;
15687 const SCEV *FromRewritten = GetMaybeRewritten(From);
15688 const SCEV *To = nullptr;
15689
15690 switch (Predicate) {
15691 case CmpInst::ICMP_ULT:
15692 case CmpInst::ICMP_ULE:
15693 To = SE.getUMinExpr(FromRewritten, RHS);
15694 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15695 EnqueueOperands(UMax);
15696 break;
15697 case CmpInst::ICMP_SLT:
15698 case CmpInst::ICMP_SLE:
15699 To = SE.getSMinExpr(FromRewritten, RHS);
15700 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15701 EnqueueOperands(SMax);
15702 break;
15703 case CmpInst::ICMP_UGT:
15704 case CmpInst::ICMP_UGE:
15705 To = SE.getUMaxExpr(FromRewritten, RHS);
15706 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15707 EnqueueOperands(UMin);
15708 break;
15709 case CmpInst::ICMP_SGT:
15710 case CmpInst::ICMP_SGE:
15711 To = SE.getSMaxExpr(FromRewritten, RHS);
15712 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15713 EnqueueOperands(SMin);
15714 break;
15715 case CmpInst::ICMP_EQ:
15716 if (isa<SCEVConstant>(RHS))
15717 To = RHS;
15718 break;
15719 case CmpInst::ICMP_NE:
15720 if (match(RHS, m_scev_Zero())) {
15721 const SCEV *OneAlignedUp =
15722 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15723 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15724 }
15725 break;
15726 default:
15727 break;
15728 }
15729
15730 if (To)
15731 AddRewrite(From, FromRewritten, To);
15732 }
15733 };
15734
15736 // First, collect information from assumptions dominating the loop.
15737 for (auto &AssumeVH : SE.AC.assumptions()) {
15738 if (!AssumeVH)
15739 continue;
15740 auto *AssumeI = cast<CallInst>(AssumeVH);
15741 if (!SE.DT.dominates(AssumeI, Block))
15742 continue;
15743 Terms.emplace_back(AssumeI->getOperand(0), true);
15744 }
15745
15746 // Second, collect information from llvm.experimental.guards dominating the loop.
15747 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15748 SE.F.getParent(), Intrinsic::experimental_guard);
15749 if (GuardDecl)
15750 for (const auto *GU : GuardDecl->users())
15751 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15752 if (Guard->getFunction() == Block->getParent() &&
15753 SE.DT.dominates(Guard, Block))
15754 Terms.emplace_back(Guard->getArgOperand(0), true);
15755
15756 // Third, collect conditions from dominating branches. Starting at the loop
15757 // predecessor, climb up the predecessor chain, as long as there are
15758 // predecessors that can be found that have unique successors leading to the
15759 // original header.
15760 // TODO: share this logic with isLoopEntryGuardedByCond.
15761 unsigned NumCollectedConditions = 0;
15762 VisitedBlocks.insert(Block);
15763 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15764 for (; Pair.first;
15765 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15766 VisitedBlocks.insert(Pair.second);
15767 const BranchInst *LoopEntryPredicate =
15768 dyn_cast<BranchInst>(Pair.first->getTerminator());
15769 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15770 continue;
15771
15772 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15773 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15774 NumCollectedConditions++;
15775
15776 // If we are recursively collecting guards stop after 2
15777 // conditions to limit compile-time impact for now.
15778 if (Depth > 0 && NumCollectedConditions == 2)
15779 break;
15780 }
15781 // Finally, if we stopped climbing the predecessor chain because
15782 // there wasn't a unique one to continue, try to collect conditions
15783 // for PHINodes by recursively following all of their incoming
15784 // blocks and try to merge the found conditions to build a new one
15785 // for the Phi.
15786 if (Pair.second->hasNPredecessorsOrMore(2) &&
15789 for (auto &Phi : Pair.second->phis())
15790 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15791 }
15792
15793 // Now apply the information from the collected conditions to
15794 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15795 // earliest condition is processed first. This ensures the SCEVs with the
15796 // shortest dependency chains are constructed first.
15797 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15798 SmallVector<Value *, 8> Worklist;
15800 Worklist.push_back(Term);
15801 while (!Worklist.empty()) {
15802 Value *Cond = Worklist.pop_back_val();
15803 if (!Visited.insert(Cond).second)
15804 continue;
15805
15806 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15807 auto Predicate =
15808 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15809 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15810 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15811 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15812 continue;
15813 }
15814
15815 Value *L, *R;
15816 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15817 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15818 Worklist.push_back(L);
15819 Worklist.push_back(R);
15820 }
15821 }
15822 }
15823
15824 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15825 // the replacement expressions are contained in the ranges of the replaced
15826 // expressions.
15827 Guards.PreserveNUW = true;
15828 Guards.PreserveNSW = true;
15829 for (const SCEV *Expr : ExprsToRewrite) {
15830 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15831 Guards.PreserveNUW &=
15832 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15833 Guards.PreserveNSW &=
15834 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15835 }
15836
15837 // Now that all rewrite information is collected, rewrite the collected
15838 // expressions with the information in the map. This applies information to
15839 // sub-expressions.
15840 if (ExprsToRewrite.size() > 1) {
15841 for (const SCEV *Expr : ExprsToRewrite) {
15842 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15843 Guards.RewriteMap.erase(Expr);
15844 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15845 }
15846 }
15847}
15848
15850 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15851 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15852 /// replacement is loop invariant in the loop of the AddRec.
15853 class SCEVLoopGuardRewriter
15854 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15856
15858
15859 public:
15860 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15861 const ScalarEvolution::LoopGuards &Guards)
15862 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15863 if (Guards.PreserveNUW)
15864 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15865 if (Guards.PreserveNSW)
15866 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15867 }
15868
15869 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15870
15871 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15872 auto I = Map.find(Expr);
15873 if (I == Map.end())
15874 return Expr;
15875 return I->second;
15876 }
15877
15878 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15879 auto I = Map.find(Expr);
15880 if (I == Map.end()) {
15881 // If we didn't find the exact ZExt expr in the map, check if there's
15882 // an entry for a smaller ZExt we can use instead.
15883 Type *Ty = Expr->getType();
15884 const SCEV *Op = Expr->getOperand(0);
15885 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15886 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15887 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15888 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15889 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15890 auto I = Map.find(NarrowExt);
15891 if (I != Map.end())
15892 return SE.getZeroExtendExpr(I->second, Ty);
15893 Bitwidth = Bitwidth / 2;
15894 }
15895
15897 Expr);
15898 }
15899 return I->second;
15900 }
15901
15902 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15903 auto I = Map.find(Expr);
15904 if (I == Map.end())
15906 Expr);
15907 return I->second;
15908 }
15909
15910 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15911 auto I = Map.find(Expr);
15912 if (I == Map.end())
15914 return I->second;
15915 }
15916
15917 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15918 auto I = Map.find(Expr);
15919 if (I == Map.end())
15921 return I->second;
15922 }
15923
15924 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15926 bool Changed = false;
15927 for (const auto *Op : Expr->operands()) {
15928 Operands.push_back(
15930 Changed |= Op != Operands.back();
15931 }
15932 // We are only replacing operands with equivalent values, so transfer the
15933 // flags from the original expression.
15934 return !Changed ? Expr
15935 : SE.getAddExpr(Operands,
15937 Expr->getNoWrapFlags(), FlagMask));
15938 }
15939
15940 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15942 bool Changed = false;
15943 for (const auto *Op : Expr->operands()) {
15944 Operands.push_back(
15946 Changed |= Op != Operands.back();
15947 }
15948 // We are only replacing operands with equivalent values, so transfer the
15949 // flags from the original expression.
15950 return !Changed ? Expr
15951 : SE.getMulExpr(Operands,
15953 Expr->getNoWrapFlags(), FlagMask));
15954 }
15955 };
15956
15957 if (RewriteMap.empty())
15958 return Expr;
15959
15960 SCEVLoopGuardRewriter Rewriter(SE, *this);
15961 return Rewriter.visit(Expr);
15962}
15963
15964const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15965 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
15966}
15967
15969 const LoopGuards &Guards) {
15970 return Guards.rewrite(Expr);
15971}
@ Poison
static const LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
basic Basic Alias true
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:622
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
Generic implementation of equivalence classes through the use Tarjan's efficient union-find algorithm...
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:557
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:39
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:245
Virtual Register Rewriter
Definition: VirtRegMap.cpp:261
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
Class for arbitrary precision integers.
Definition: APInt.h:78
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1945
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1547
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:986
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1520
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1392
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:612
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1007
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1492
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:910
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:206
APInt abs() const
Get the absolute value.
Definition: APInt.h:1773
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition: APInt.h:1201
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:466
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1640
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1468
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:219
unsigned countTrailingZeros() const
Definition: APInt.h:1626
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:356
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:827
APInt multiplicativeInverse() const
Definition: APInt.cpp:1248
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1150
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:959
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:873
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1221
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:49
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:292
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:310
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:231
iterator end() const
Definition: ArrayRef.h:157
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
iterator begin() const
Definition: ArrayRef.h:156
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:461
const Instruction & front() const
Definition: BasicBlock.h:484
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:481
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:220
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:240
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:370
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:946
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:673
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:702
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:703
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:697
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:696
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:700
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:698
@ ICMP_EQ
equal
Definition: InstrTypes.h:694
@ ICMP_NE
not equal
Definition: InstrTypes.h:695
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:701
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:699
bool isSigned() const
Definition: InstrTypes.h:928
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:825
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:940
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:869
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:787
bool isUnsigned() const
Definition: InstrTypes.h:934
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:924
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
Definition: CmpPredicate.h:22
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2632
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2293
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1267
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2638
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2626
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2279
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:208
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:873
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:157
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:148
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:880
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:42
This class represents an Operation in the Expression.
A parsed version of the target data layout string and methods for querying it.
Definition: DataLayout.h:63
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:709
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:851
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:754
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:878
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:617
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition: DenseMap.h:226
bool erase(const KeyT &Val)
Definition: DenseMap.h:321
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:176
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:152
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:147
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:327
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
const BasicBlock & getEntryBlock() const
Definition: Function.h:821
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:731
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:657
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:407
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:404
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
Definition: DerivedTypes.h:42
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
An instruction for reading from memory.
Definition: Instructions.h:176
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:39
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1073
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:32
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:42
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:77
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:110
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:104
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:686
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1878
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:121
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class for min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set of assumed no-overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Does the operation BinOp between LHS and RHS provably not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:458
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
Definition: SmallPtrSet.h:519
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:132
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:181
size_type size() const
Definition: SmallSet.h:170
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void reserve(size_type N)
Definition: SmallVector.h:663
iterator erase(const_iterator CI)
Definition: SmallVector.h:737
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:805
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:567
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:596
TypeSize getSizeInBits() const
Definition: DataLayout.h:576
Class to represent struct types.
Definition: DerivedTypes.h:218
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:264
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:252
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:237
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:35
op_range operands()
Definition: User.h:288
Use & Op()
Definition: User.h:192
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5144
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1094
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2217
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2222
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2227
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2785
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2232
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:771
@ Entry
Definition: COFF.h:844
@ Exit
Definition: COFF.h:845
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
Function * getDeclarationIfExists(Module *M, ID id, ArrayRef< Type * > Tys, FunctionType *FT=nullptr)
This version supports overloaded intrinsics.
Definition: Intrinsics.cpp:747
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:885
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:832
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bind_ty< const SCEVConstant > m_SCEVConstant(const SCEVConstant *&V)
bind_ty< const SCEV > m_SCEV(const SCEV *&V)
Match a SCEV, capturing it if we match.
SCEVBinaryExpr_match< SCEVAddExpr, Op0_t, Op1_t > m_scev_Add(const Op0_t &Op0, const Op1_t &Op1)
bool match(const SCEV *S, const Pattern &P)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:390
void stable_sort(R &&Range)
Definition: STLExtras.h:2037
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:256
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7301
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2115
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant, stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2107
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1162
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1152
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:256
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:261
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:2014
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:303
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:217
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1873
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1945
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1903
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:293
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:100
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:43
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:188
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:137
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:121
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:97
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.