LLVM 21.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
85#include "llvm/Config/llvm-config.h"
86#include "llvm/IR/Argument.h"
87#include "llvm/IR/BasicBlock.h"
88#include "llvm/IR/CFG.h"
89#include "llvm/IR/Constant.h"
91#include "llvm/IR/Constants.h"
92#include "llvm/IR/DataLayout.h"
94#include "llvm/IR/Dominators.h"
95#include "llvm/IR/Function.h"
96#include "llvm/IR/GlobalAlias.h"
97#include "llvm/IR/GlobalValue.h"
99#include "llvm/IR/InstrTypes.h"
100#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Instructions.h"
103#include "llvm/IR/Intrinsics.h"
104#include "llvm/IR/LLVMContext.h"
105#include "llvm/IR/Operator.h"
106#include "llvm/IR/PatternMatch.h"
107#include "llvm/IR/Type.h"
108#include "llvm/IR/Use.h"
109#include "llvm/IR/User.h"
110#include "llvm/IR/Value.h"
111#include "llvm/IR/Verifier.h"
113#include "llvm/Pass.h"
114#include "llvm/Support/Casting.h"
117#include "llvm/Support/Debug.h"
122#include <algorithm>
123#include <cassert>
124#include <climits>
125#include <cstdint>
126#include <cstdlib>
127#include <map>
128#include <memory>
129#include <numeric>
130#include <optional>
131#include <tuple>
132#include <utility>
133#include <vector>
134
135using namespace llvm;
136using namespace PatternMatch;
137using namespace SCEVPatternMatch;
138
139#define DEBUG_TYPE "scalar-evolution"
140
141STATISTIC(NumExitCountsComputed,
142 "Number of loop exits with predictable exit counts");
143STATISTIC(NumExitCountsNotComputed,
144 "Number of loop exits without predictable exit counts");
145STATISTIC(NumBruteForceTripCountsComputed,
146 "Number of loops with trip counts computed by force");
147
148#ifdef EXPENSIVE_CHECKS
149bool llvm::VerifySCEV = true;
150#else
151bool llvm::VerifySCEV = false;
152#endif
153
// Command-line tuning knobs for ScalarEvolution: iteration/recursion-depth
// limits, inlining thresholds and feature toggles.
// NOTE(review): this listing omits the declaration line of most options
// (e.g. listing lines 154, 161, 164, 168, 173, ...), so the C++ names of
// several of them are not visible here.
155 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
156 cl::desc("Maximum number of iterations SCEV will "
157 "symbolically execute a constant "
158 "derived loop"),
159 cl::init(100));
160
// -verify-scev stores straight into llvm::VerifySCEV (cl::location), so the
// flag and the global defined above always agree.
162 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
163 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
// NOTE(review): the description below reads "with -verify-scev is passed";
// "when -verify-scev is passed" was presumably intended. The text is
// user-visible, so it is left unchanged here.
165 "verify-scev-strict", cl::Hidden,
166 cl::desc("Enable stricter verification with -verify-scev is passed"));
167
169 "scev-verify-ir", cl::Hidden,
170 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
171 cl::init(false));
172
// Operand-count thresholds for inlining mul/add operands into a SCEV.
174 "scev-mulops-inline-threshold", cl::Hidden,
175 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
176 cl::init(32));
177
179 "scev-addops-inline-threshold", cl::Hidden,
180 cl::desc("Threshold for inlining addition operands into a SCEV"),
181 cl::init(500));
182
// Recursion-depth limits bounding the cost of the recursive analyses named
// in each description.
184 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
185 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
186 cl::init(32));
187
189 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
190 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
191 cl::init(2));
192
194 "scalar-evolution-max-value-compare-depth", cl::Hidden,
195 cl::desc("Maximum depth of recursive value complexity comparisons"),
196 cl::init(2));
197
199 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
200 cl::desc("Maximum depth of recursive arithmetics"),
201 cl::init(32));
202
204 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
205 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
206
208 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
209 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
210 cl::init(8));
211
213 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
214 cl::desc("Max coefficients in AddRec during evolving"),
215 cl::init(8));
216
// Expression-size threshold used to classify an expression as "huge".
218 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
219 cl::desc("Size of the expression which is considered huge"),
220 cl::init(4096));
221
223 "scev-range-iter-threshold", cl::Hidden,
224 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
225 cl::init(32));
226
228 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
229 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
230
231static cl::opt<bool>
232ClassifyExpressions("scalar-evolution-classify-expressions",
233 cl::Hidden, cl::init(true),
234 cl::desc("When printing analysis, include information on every instruction"));
235
237 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
238 cl::init(false),
239 cl::desc("Use more powerful methods of sharpening expression ranges. May "
240 "be costly in terms of compile time"));
241
243 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
244 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
245 "Phi strongly connected components"),
246 cl::init(8));
247
248static cl::opt<bool>
249 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
250 cl::desc("Handle <= and >= in finite loops"),
251 cl::init(true));
252
// NOTE(review): the flag name spells "strenghening". It is left as is because
// renaming a command-line flag would break existing invocations.
254 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
255 cl::desc("Infer nuw/nsw flags using context where suitable"),
256 cl::init(true));
257
258//===----------------------------------------------------------------------===//
259// SCEV class definitions
260//===----------------------------------------------------------------------===//
261
262//===----------------------------------------------------------------------===//
263// Implementation of the SCEV class.
264//
265
// Debug dump helper: prints this SCEV to dbgs() followed by a newline.
// It is compiled only when NDEBUG is undefined or LLVM_ENABLE_DUMP is defined.
// NOTE(review): listing line 267 (the function header, presumably
// SCEV::dump() const) is missing from this rendering.
266#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
268 print(dbgs());
269 dbgs() << '\n';
270}
271#endif
272
// Body of the textual printer for a SCEV (presumably SCEV::print(raw_ostream
// &OS) const; its header at listing line 273 is missing here). It writes this
// expression to OS, dispatching on getSCEVType(). Nested operands are printed
// recursively through operator<<.
274 switch (getSCEVType()) {
275 case scConstant:
276 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
277 return;
278 case scVScale:
279 OS << "vscale";
280 return;
// Cast kinds all print as "(<op> <src-type> <operand> to <dst-type>)".
281 case scPtrToInt: {
282 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
283 const SCEV *Op = PtrToInt->getOperand();
284 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
285 << *PtrToInt->getType() << ")";
286 return;
287 }
288 case scTruncate: {
289 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
290 const SCEV *Op = Trunc->getOperand();
291 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
292 << *Trunc->getType() << ")";
293 return;
294 }
295 case scZeroExtend: {
296 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
297 const SCEV *Op = ZExt->getOperand();
298 OS << "(zext " << *Op->getType() << " " << *Op << " to "
299 << *ZExt->getType() << ")";
300 return;
301 }
302 case scSignExtend: {
303 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
304 const SCEV *Op = SExt->getOperand();
305 OS << "(sext " << *Op->getType() << " " << *Op << " to "
306 << *SExt->getType() << ")";
307 return;
308 }
// Add recurrences print as "{Start,+,Step,...}", then one "<flag>" group per
// set wrap flag, then "<header>" naming the loop header block. For example:
// {0,+,1}<nuw><nsw><%loop>.
// NOTE(review): listing line 320 (the second half of the "nw" condition) is
// missing from this rendering.
309 case scAddRecExpr: {
310 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
311 OS << "{" << *AR->getOperand(0);
312 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
313 OS << ",+," << *AR->getOperand(i);
314 OS << "}<";
315 if (AR->hasNoUnsignedWrap())
316 OS << "nuw><";
317 if (AR->hasNoSignedWrap())
318 OS << "nsw><";
319 if (AR->hasNoSelfWrap() &&
321 OS << "nw><";
322 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
323 OS << ">";
324 return;
325 }
// N-ary kinds print as "(op1 OP op2 OP ...)"; add and mul additionally
// append "<nuw>"/"<nsw>" when those flags are set.
// NOTE(review): listing lines 332 and 346 are missing; they presumably hold
// the sequential-umin case labels that lead to " umin_seq " below.
326 case scAddExpr:
327 case scMulExpr:
328 case scUMaxExpr:
329 case scSMaxExpr:
330 case scUMinExpr:
331 case scSMinExpr:
333 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
334 const char *OpStr = nullptr;
335 switch (NAry->getSCEVType()) {
336 case scAddExpr: OpStr = " + "; break;
337 case scMulExpr: OpStr = " * "; break;
338 case scUMaxExpr: OpStr = " umax "; break;
339 case scSMaxExpr: OpStr = " smax "; break;
340 case scUMinExpr:
341 OpStr = " umin ";
342 break;
343 case scSMinExpr:
344 OpStr = " smin ";
345 break;
347 OpStr = " umin_seq ";
348 break;
349 default:
350 llvm_unreachable("There are no other nary expression types.");
351 }
352 OS << "(";
// ListSeparator emits OpStr between operands (not before the first one).
353 ListSeparator LS(OpStr);
354 for (const SCEV *Op : NAry->operands())
355 OS << LS << *Op;
356 OS << ")";
357 switch (NAry->getSCEVType()) {
358 case scAddExpr:
359 case scMulExpr:
360 if (NAry->hasNoUnsignedWrap())
361 OS << "<nuw>";
362 if (NAry->hasNoSignedWrap())
363 OS << "<nsw>";
364 break;
365 default:
366 // Nothing to print for other nary expressions.
367 break;
368 }
369 return;
370 }
371 case scUDivExpr: {
372 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
373 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
374 return;
375 }
376 case scUnknown:
377 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
378 return;
// NOTE(review): listing line 379 is missing; it presumably holds the
// scCouldNotCompute case label for the marker printed below.
380 OS << "***COULDNOTCOMPUTE***";
381 return;
382 }
383 llvm_unreachable("Unknown SCEV kind!");
384}
385
// Body of the type accessor for a SCEV (presumably SCEV::getType() const;
// header at listing line 386 is missing here). It returns the type by
// down-casting to the concrete expression class for this SCEV kind.
// NOTE(review): listing lines 406 and 414 are missing; they presumably hold
// the sequential min/max case label and the scCouldNotCompute case label.
387 switch (getSCEVType()) {
388 case scConstant:
389 return cast<SCEVConstant>(this)->getType();
390 case scVScale:
391 return cast<SCEVVScale>(this)->getType();
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
407 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
// A could-not-compute SCEV has no type; querying it is a programming error.
415 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
// Body of the operand-list accessor (presumably SCEV::operands() const;
// header at listing line 420 is missing here).
// - Leaf kinds (constant, vscale, unknown) have no operands and return an
//   empty list.
// - Cast kinds, n-ary kinds and udiv forward to their class's operands().
// NOTE(review): listing lines 438 and 442 are missing; they presumably hold
// the sequential-umin case label and the scCouldNotCompute case label.
421 switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToInt:
427 case scTruncate:
428 case scZeroExtend:
429 case scSignExtend:
430 return cast<SCEVCastExpr>(this)->operands();
431 case scAddRecExpr:
432 case scAddExpr:
433 case scMulExpr:
434 case scUMaxExpr:
435 case scSMaxExpr:
436 case scUMinExpr:
437 case scSMinExpr:
439 return cast<SCEVNAryExpr>(this)->operands();
440 case scUDivExpr:
441 return cast<SCEVUDivExpr>(this)->operands();
443 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
444 }
445 llvm_unreachable("Unknown SCEV kind!");
446}
447
448bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
449
450bool SCEV::isOne() const { return match(this, m_scev_One()); }
451
452bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
453
// Body of a predicate (presumably SCEV::isNonConstantNegative() const; header
// at listing line 454 is missing here). It returns true only when this SCEV
// is a multiplication whose leading operand is a negative constant. Any
// other shape returns false.
455 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
456 if (!Mul) return false;
457
458 // If there is a constant factor, it will be first.
459 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
460 if (!SC) return false;
461
462 // Return true if the value is negative, this matches things like (-42 * V).
463 return SC->getAPInt().isNegative();
464}
465
468
// NOTE(review): listing lines 466-467 and 469 are missing from this
// rendering. The visible body returns true iff S has kind scCouldNotCompute;
// it is presumably SCEVCouldNotCompute::classof(const SCEV *S) — confirm
// upstream.
470 return S->getSCEVType() == scCouldNotCompute;
471}
472
// getConstant overload taking V (presumably a ConstantInt *, since the APInt
// overload below passes ConstantInt::get(...) here). It returns the uniqued
// SCEVConstant for V: the node is looked up in UniqueSCEVs by
// (scConstant, V), and only when absent is a new SCEVConstant allocated in
// SCEVAllocator and inserted at the position the lookup reported.
// NOTE(review): listing lines 473-474 (the header and, presumably, the
// FoldingSetNodeID declaration) are missing from this rendering.
475 ID.AddInteger(scConstant);
476 ID.AddPointer(V);
477 void *IP = nullptr;
478 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
479 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
480 UniqueSCEVs.InsertNode(S, IP);
481 return S;
482}
483
// APInt overload (header at listing line 484 is missing): wraps Val in a
// ConstantInt of this context and defers to the overload above.
485 return getConstant(ConstantInt::get(getContext(), Val));
486}
487
// (Type, value, isSigned) overload (name line 489 is missing): builds the
// constant in the integer type returned by getEffectiveSCEVType(Ty).
// NOTE(review): the cast<IntegerType> asserts that this effective type is an
// integer type.
488const SCEV *
490 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
491 return getConstant(ConstantInt::get(ITy, V, isSigned));
492}
493
// Uniqued vscale node for type Ty (presumably getVScale(Type *Ty); listing
// lines 494-495 are missing). It uses the same find-or-insert pattern as
// getConstant, keyed by (scVScale, Ty).
496 ID.AddInteger(scVScale);
497 ID.AddPointer(Ty);
498 void *IP = nullptr;
499 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
500 return S;
501 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
502 UniqueSCEVs.InsertNode(S, IP);
503 return S;
504}
505
// Element-count helper (header at listing line 506 is missing). It returns
// EC's known minimum value as a Ty constant, multiplied by vscale when EC is
// scalable.
507 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
508 if (EC.isScalable())
509 Res = getMulExpr(Res, getVScale(Ty));
510 return Res;
511}
512
// SCEVCastExpr constructor (its first line, listing line 513, is missing).
// It stores the single operand and the destination type; the node's size is
// computed from the operand with computeExpressionSize.
514 const SCEV *op, Type *ty)
515 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
516
// ptrtoint cast. NOTE(review): the assert condition (listing line 520) is
// missing here; its message says the cast must not change the bit width.
517SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
518 Type *ITy)
519 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
521 "Must be a non-bit-width-changing pointer-to-integer cast!");
522}
523
// Common base for the integral casts; it only forwards to SCEVCastExpr
// (header line 524 is missing).
525 SCEVTypes SCEVTy, const SCEV *op,
526 Type *ty)
527 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
528
// trunc / zext / sext constructors. Each asserts that both the operand type
// and the destination type are integer or pointer types.
// NOTE(review): listing lines 531, 538 and 545 (presumably the base-class
// initializer and the opening brace) are missing.
529SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
530 Type *ty)
532 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
533 "Cannot truncate non-integer value!");
534}
535
536SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
537 const SCEV *op, Type *ty)
539 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
540 "Cannot zero extend non-integer value!");
541}
542
543SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
544 const SCEV *op, Type *ty)
546 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
547 "Cannot sign extend non-integer value!");
548}
549
550void SCEVUnknown::deleted() {
551 // Clear this SCEVUnknown from various maps.
552 SE->forgetMemoizedResults(this);
553
554 // Remove this SCEVUnknown from the uniquing map.
555 SE->UniqueSCEVs.RemoveNode(this);
556
557 // Release the value.
558 setValPtr(nullptr);
559}
560
561void SCEVUnknown::allUsesReplacedWith(Value *New) {
562 // Clear this SCEVUnknown from various maps.
563 SE->forgetMemoizedResults(this);
564
565 // Remove this SCEVUnknown from the uniquing map.
566 SE->UniqueSCEVs.RemoveNode(this);
567
568 // Replace the value pointer in case someone is still using this SCEVUnknown.
569 setValPtr(New);
570}
571
572//===----------------------------------------------------------------------===//
573// SCEV Utilities
574//===----------------------------------------------------------------------===//
575
576/// Compare the two values \p LV and \p RV in terms of their "complexity" where
577/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
578/// operands in SCEV expressions.
///
/// Returns a negative value, zero, or a positive value when LV orders before,
/// equal to, or after RV. Keys are tried in this order:
///   1. integer before pointer;
///   2. getValueID();
///   3. argument position;
///   4. global names, only when both names are semantic;
///   5. for instructions: loop depth, then operand count, then the operands
///      recursively.
/// Values that tie on every key compare as 0.
/// NOTE(review): listing line 581 is missing. It presumably holds a
/// recursion-depth cutoff on \p Depth (cf. the
/// scalar-evolution-max-value-compare-depth option) guarding the early
/// `return 0` below. Line 609 (the rest of IsGVNameSemantic) is also missing.
579static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
580 Value *RV, unsigned Depth) {
582 return 0;
583
584 // Order pointer values after integer values. This helps SCEVExpander form
585 // GEPs.
586 bool LIsPointer = LV->getType()->isPointerTy(),
587 RIsPointer = RV->getType()->isPointerTy();
588 if (LIsPointer != RIsPointer)
589 return (int)LIsPointer - (int)RIsPointer;
590
591 // Compare getValueID values.
592 unsigned LID = LV->getValueID(), RID = RV->getValueID();
593 if (LID != RID)
594 return (int)LID - (int)RID;
595
 // From here on LID == RID, so RV is expected to be the same kind of value
 // as LV; that is what makes the unchecked cast<> calls below valid.
596 // Sort arguments by their position.
597 if (const auto *LA = dyn_cast<Argument>(LV)) {
598 const auto *RA = cast<Argument>(RV);
599 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
600 return (int)LArgNo - (int)RArgNo;
601 }
602
603 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
604 const auto *RGV = cast<GlobalValue>(RV);
605
606 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
607 auto LT = GV->getLinkage();
608 return !(GlobalValue::isPrivateLinkage(LT) ||
610 };
611
612 // Use the names to distinguish the two values, but only if the
613 // names are semantically important.
614 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
615 return LGV->getName().compare(RGV->getName());
616 }
617
618 // For instructions, compare their loop depth, their operand count and then
619 // their operands (recursively). This is pretty loose.
620 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
621 const auto *RInst = cast<Instruction>(RV);
622
623 // Compare loop depths.
 // NOTE(review): LI is dereferenced without a null check; callers are
 // presumably required to pass a valid LoopInfo.
624 const BasicBlock *LParent = LInst->getParent(),
625 *RParent = RInst->getParent();
626 if (LParent != RParent) {
627 unsigned LDepth = LI->getLoopDepth(LParent),
628 RDepth = LI->getLoopDepth(RParent);
629 if (LDepth != RDepth)
630 return (int)LDepth - (int)RDepth;
631 }
632
633 // Compare the number of operands.
634 unsigned LNumOps = LInst->getNumOperands(),
635 RNumOps = RInst->getNumOperands();
636 if (LNumOps != RNumOps)
637 return (int)LNumOps - (int)RNumOps;
638
 // Same operand count: order by the first operand pair that differs.
639 for (unsigned Idx : seq(LNumOps)) {
640 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
641 RInst->getOperand(Idx), Depth + 1);
642 if (Result != 0)
643 return Result;
644 }
645 }
646
647 return 0;
648}
649
650// Return negative, zero, or positive, if LHS is less than, equal to, or greater
651// than RHS, respectively. A three-way result allows recursive comparisons to be
652// more efficient.
653// If the max analysis depth was reached, return std::nullopt, assuming we do
654// not know if they are equivalent for sure.
//
// Pairs proven equal are recorded in EqCacheSCEV (unionSets), so later
// queries on the same pair short-circuit through isEquivalent.
// NOTE(review): listing lines 656 (function name and the EqCacheSCEV
// parameter), 671 (presumably the depth check guarding `return std::nullopt`),
// 740 and 759 (presumably the sequential-umin and scCouldNotCompute case
// labels) are missing from this rendering.
655static std::optional<int>
657 const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (EqCacheSCEV.isEquivalent(LHS, RHS))
669 return 0;
670
672 return std::nullopt;
673
674 // Aside from the getSCEVType() ordering, the particular ordering
675 // isn't very important except that it's beneficial to be consistent,
676 // so that (a + b) and (b + a) don't end up as different expressions.
677 switch (LType) {
678 case scUnknown: {
679 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
680 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
681
682 int X =
683 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
684 if (X == 0)
685 EqCacheSCEV.unionSets(LHS, RHS);
686 return X;
687 }
688
689 case scConstant: {
690 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
691 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
692
693 // Compare constant values.
 // LHS != RHS here (fast path above). Because constants are uniqued, two
 // distinct nodes of equal width presumably never hold equal values, so
 // this never reports 0.
694 const APInt &LA = LC->getAPInt();
695 const APInt &RA = RC->getAPInt();
696 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
697 if (LBitWidth != RBitWidth)
698 return (int)LBitWidth - (int)RBitWidth;
699 return LA.ult(RA) ? -1 : 1;
700 }
701
702 case scVScale: {
 // NOTE(review): unlike the other branches, this subtracts the unsigned bit
 // widths without (int) casts; the unsigned difference is converted to int
 // on return.
703 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
704 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
705 return LTy->getBitWidth() - RTy->getBitWidth();
706 }
707
708 case scAddRecExpr: {
709 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
710 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
711
712 // There is always a dominance between two recs that are used by one SCEV,
713 // so we can safely sort recs by loop header dominance. We require such
714 // order in getAddExpr.
715 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
716 if (LLoop != RLoop) {
717 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
718 assert(LHead != RHead && "Two loops share the same header?");
719 if (DT.dominates(LHead, RHead))
720 return 1;
721 assert(DT.dominates(RHead, LHead) &&
722 "No dominance between recurrences used by one SCEV?");
723 return -1;
724 }
725
 // Same loop: compare the recurrences operand-by-operand like the kinds below.
726 [[fallthrough]];
727 }
728
729 case scTruncate:
730 case scZeroExtend:
731 case scSignExtend:
732 case scPtrToInt:
733 case scAddExpr:
734 case scMulExpr:
735 case scUDivExpr:
736 case scSMaxExpr:
737 case scUMaxExpr:
738 case scSMinExpr:
739 case scUMinExpr:
741 ArrayRef<const SCEV *> LOps = LHS->operands();
742 ArrayRef<const SCEV *> ROps = RHS->operands();
743
744 // Lexicographically compare n-ary-like expressions.
745 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
746 if (LNumOps != RNumOps)
747 return (int)LNumOps - (int)RNumOps;
748
749 for (unsigned i = 0; i != LNumOps; ++i) {
 // X is a std::optional<int>. std::nullopt compares unequal to 0, so a
 // depth cutoff in any sub-comparison propagates out as std::nullopt.
750 auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
751 Depth + 1);
752 if (X != 0)
753 return X;
754 }
755 EqCacheSCEV.unionSets(LHS, RHS);
756 return 0;
757 }
758
760 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
761 }
762 llvm_unreachable("Unknown SCEV kind!");
763}
764
765/// Given a list of SCEV objects, order them by their complexity, and group
766/// objects of the same complexity together by value. When this routine is
767/// finished, we know that any duplicates in the vector are consecutive and that
768/// complexity is monotonically increasing.
769///
770/// Note that we take special precautions to ensure that we get deterministic
771/// results from this routine. In other words, we don't want the results of
772/// this to depend on where the addresses of various SCEV objects happened to
773/// land in memory.
///
/// NOTE(review): listing lines 774 (the function header) and 778 (presumably
/// the declaration of the EqCacheSCEV equivalence cache used by
/// IsLessComplex) are missing from this rendering.
775 LoopInfo *LI, DominatorTree &DT) {
776 if (Ops.size() < 2) return; // Noop
777
779
780 // Whether LHS has provably less complexity than RHS.
 // An inconclusive comparison (std::nullopt, depth limit hit) counts as
 // "not less".
781 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
782 auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
783 return Complexity && *Complexity < 0;
784 };
785 if (Ops.size() == 2) {
786 // This is the common case, which also happens to be trivially simple.
787 // Special case it.
788 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
789 if (IsLessComplex(RHS, LHS))
790 std::swap(LHS, RHS);
791 return;
792 }
793
794 // Do the rough sort by complexity.
795 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
796 return IsLessComplex(LHS, RHS);
797 });
798
799 // Now that we are sorted by complexity, group elements of the same
800 // complexity. Note that this is, at worst, N^2, but the vector is likely to
801 // be extremely short in practice. Note that we take this approach because we
802 // do not want to depend on the addresses of the objects we are grouping.
 // Ops.size() >= 3 here (sizes < 2 and == 2 returned above), so e-2 does not
 // underflow. The final pair needs no scan: it is already adjacent.
803 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
804 const SCEV *S = Ops[i];
805 unsigned Complexity = S->getSCEVType();
806
807 // If there are any objects of the same complexity and same value as this
808 // one, group them.
809 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
810 if (Ops[j] == S) { // Found a duplicate.
811 // Move it to immediately after i'th element.
812 std::swap(Ops[i+1], Ops[j]);
813 ++i; // no need to rescan it.
814 if (i == e-2) return; // Done!
815 }
816 }
817 }
818}
819
820/// Returns true if \p Ops contains a huge SCEV, i.e. some operand S whose
821/// subtree contains at least HugeExprThreshold nodes.
/// NOTE(review): listing lines 822 (the function header) and 824 (the
/// lambda's size test, presumably against HugeExprThreshold) are missing
/// from this rendering.
823 return any_of(Ops, [](const SCEV *S) {
825 });
826}
827
828/// Performs a number of common optimizations on the passed \p Ops. If the
829/// whole expression reduces down to a single operand, it will be returned.
830///
831/// The following optimizations are performed:
832/// * Fold constants using the \p Fold function.
833/// * Remove identity constants satisfying \p IsIdentity.
834/// * If a constant satisfies \p IsAbsorber, return it.
835/// * Sort operands by complexity.
///
/// Returns nullptr when more than one operand remains.
///
/// \p Ops is modified in place:
/// * all constants are removed from it and combined into a single value;
/// * the remaining operands are grouped with GroupByComplexity;
/// * the folded constant is re-inserted at the front unless it is the
///   identity.
/// On the early returns (all operands constant, or absorber constant), Ops is
/// left with its constants removed and is not regrouped.
/// NOTE(review): listing line 838 (the function name and the leading
/// parameters, presumably SE, LI and DT as used in the body) is missing.
836template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
837static const SCEV *
839 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
840 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
841 const SCEVConstant *Folded = nullptr;
 // Pull every constant out of Ops, accumulating them pairwise with Fold.
 // Idx only advances past non-constants because erase() shifts the rest down.
842 for (unsigned Idx = 0; Idx < Ops.size();) {
843 const SCEV *Op = Ops[Idx];
844 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
845 if (!Folded)
846 Folded = C;
847 else
848 Folded = cast<SCEVConstant>(
849 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
850 Ops.erase(Ops.begin() + Idx);
851 continue;
852 }
853 ++Idx;
854 }
855
 // Every operand was a constant: the folded value is the whole result.
856 if (Ops.empty()) {
857 assert(Folded && "Must have folded value");
858 return Folded;
859 }
860
861 if (Folded && IsAbsorber(Folded->getAPInt()))
862 return Folded;
863
864 GroupByComplexity(Ops, &LI, DT);
865 if (Folded && !IsIdentity(Folded->getAPInt()))
866 Ops.insert(Ops.begin(), Folded);
867
868 return Ops.size() == 1 ? Ops[0] : nullptr;
869}
870
871//===----------------------------------------------------------------------===//
872// Simple SCEV method implementations
873//===----------------------------------------------------------------------===//
874
875/// Compute BC(It, K). The result has width W. Assume, K > 0.
876static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
877 ScalarEvolution &SE,
878 Type *ResultTy) {
879 // Handle the simplest case efficiently.
880 if (K == 1)
881 return SE.getTruncateOrZeroExtend(It, ResultTy);
882
883 // We are using the following formula for BC(It, K):
884 //
885 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
886 //
887 // Suppose, W is the bitwidth of the return value. We must be prepared for
888 // overflow. Hence, we must assure that the result of our computation is
889 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
890 // safe in modular arithmetic.
891 //
892 // However, this code doesn't use exactly that formula; the formula it uses
893 // is something like the following, where T is the number of factors of 2 in
894 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
895 // exponentiation:
896 //
897 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
898 //
899 // This formula is trivially equivalent to the previous formula. However,
900 // this formula can be implemented much more efficiently. The trick is that
901 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
902 // arithmetic. To do exact division in modular arithmetic, all we have
903 // to do is multiply by the inverse. Therefore, this step can be done at
904 // width W.
905 //
906 // The next issue is how to safely do the division by 2^T. The way this
907 // is done is by doing the multiplication step at a width of at least W + T
908 // bits. This way, the bottom W+T bits of the product are accurate. Then,
909 // when we perform the division by 2^T (which is equivalent to a right shift
910 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
911 // truncated out after the division by 2^T.
912 //
913 // In comparison to just directly using the first formula, this technique
914 // is much more efficient; using the first formula requires W * K bits,
915 // but this formula less than W + K bits. Also, the first formula requires
916 // a division step, whereas this formula only requires multiplies and shifts.
917 //
918 // It doesn't matter whether the subtraction step is done in the calculation
919 // width or the input iteration count's width; if the subtraction overflows,
920 // the result must be zero anyway. We prefer here to do it in the width of
921 // the induction variable because it helps a lot for certain cases; CodeGen
922 // isn't smart enough to ignore the overflow, which leads to much less
923 // efficient code if the width of the subtraction is wider than the native
924 // register width.
925 //
926 // (It's possible to not widen at all by pulling out factors of 2 before
927 // the multiplication; for example, K=2 can be calculated as
928 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
929 // extra arithmetic, so it's not an obvious win, and it gets
930 // much more complicated for K > 3.)
931
932 // Protection from insane SCEVs; this bound is conservative,
933 // but it probably doesn't matter.
934 if (K > 1000)
935 return SE.getCouldNotCompute();
936
937 unsigned W = SE.getTypeSizeInBits(ResultTy);
938
939 // Calculate K! / 2^T and T; we divide out the factors of two before
940 // multiplying for calculating K! / 2^T to avoid overflow.
941 // Other overflow doesn't matter because we only care about the bottom
942 // W bits of the result.
943 APInt OddFactorial(W, 1);
944 unsigned T = 1;
945 for (unsigned i = 3; i <= K; ++i) {
946 unsigned TwoFactors = countr_zero(i);
947 T += TwoFactors;
948 OddFactorial *= (i >> TwoFactors);
949 }
950
951 // We need at least W + T bits for the multiplication step
952 unsigned CalculationBits = W + T;
953
954 // Calculate 2^T, at width T+W.
955 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
956
957 // Calculate the multiplicative inverse of K! / 2^T;
958 // this multiplication factor will perform the exact division by
959 // K! / 2^T.
960 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
961
962 // Calculate the product, at width T+W
963 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
964 CalculationBits);
965 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
966 for (unsigned i = 1; i != K; ++i) {
967 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
968 Dividend = SE.getMulExpr(Dividend,
969 SE.getTruncateOrZeroExtend(S, CalculationTy));
970 }
971
972 // Divide by 2^T
973 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
974
975 // Truncate the result, and divide by K! / 2^T.
976
977 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
978 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
979}
980
981/// Return the value of this chain of recurrences at the specified iteration
982/// number. We can evaluate this recurrence by multiplying each element in the
983/// chain by the binomial coefficient corresponding to it. In other words, we
984/// can evaluate {A,+,B,+,C,+,D} as:
985///
986/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
987///
988/// where BC(It, k) stands for binomial coefficient.
///
/// This forwards this recurrence's operands to the static overload. The
/// result may be SCEVCouldNotCompute: BinomialCoefficient returns it for
/// K > 1000, and the static overload passes it through.
/// NOTE(review): listing line 989 (the first line of the signature) is
/// missing from this rendering.
990 ScalarEvolution &SE) const {
991 return evaluateAtIteration(operands(), It, SE);
992}
993
/// Evaluate the chain of recurrences whose operands are \p Operands at
/// iteration \p It, i.e. Operands[0] + sum_{i>=1} Operands[i] * BC(It, i).
/// \p Operands must be non-empty (asserted). Returns SCEVCouldNotCompute as
/// soon as any binomial coefficient cannot be computed.
994const SCEV *
996 const SCEV *It, ScalarEvolution &SE) {
997 assert(Operands.size() > 0);
998 const SCEV *Result = Operands[0];
999 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1000 // The computation is correct in the face of overflow provided that the
1001 // multiplication is performed _after_ the evaluation of the binomial
1002 // coefficient.
// BC(It, i) is requested in the type of the running result (the type of
// Operands[0] on the first iteration).
1003 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
// Propagate the failure unchanged instead of returning a partial sum.
1004 if (isa<SCEVCouldNotCompute>(Coeff))
1005 return Coeff;
1006
1007 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1008 }
1009 return Result;
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// SCEV Expression folder implementations
1014//===----------------------------------------------------------------------===//
1015
1017 unsigned Depth) {
// Return an integer-typed SCEV equivalent to the pointer-typed Op. Only
// SCEVUnknown leaves are ever wrapped in a SCEVPtrToIntExpr. For any larger
// pointer expression, the cast is sunk down to its SCEVUnknown leaves by the
// rewriter below. Non-pointer operands are returned unchanged.
// Returns SCEVCouldNotCompute for non-integral pointer types, or when the
// width check against IntPtrTy below fails.
// Depth is 0 for outside callers and 1 when the sinking rewriter calls back
// in for a SCEVUnknown leaf.
1018 assert(Depth <= 1 &&
1019 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1020
1021 // We could be called with an integer-typed operand during SCEV rewrites.
1022 // Since the operand is an integer already, just return it unchanged.
1023 if (!Op->getType()->isPointerTy())
1024 return Op;
1025
1026 // What would be an ID for such a SCEV cast expression?
1028 ID.AddInteger(scPtrToInt);
1029 ID.AddPointer(Op);
1030
1031 void *IP = nullptr;
1032
1033 // Is there already an expression for such a cast?
1034 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1035 return S;
1036
1037 // It isn't legal for optimizations to construct new ptrtoint expressions
1038 // for non-integral pointers.
1039 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1040 return getCouldNotCompute();
1041
1042 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1043
1044 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1045 // is sufficiently wide to represent all possible pointer values.
1046 // We could theoretically teach SCEV to truncate wider pointers, but
1047 // that isn't implemented for now.
1049 getDataLayout().getTypeSizeInBits(IntPtrTy))
1050 return getCouldNotCompute();
1051
1052 // If not, is this expression something we can't reduce any further?
1053 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1054 // Perform some basic constant folding. If the operand of the ptr2int cast
1055 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1056 // left as-is), but produce a zero constant.
1057 // NOTE: We could handle a more general case, but lack motivational cases.
1058 if (isa<ConstantPointerNull>(U->getValue()))
1059 return getZero(IntPtrTy);
1060
1061 // Create an explicit cast node.
1062 // We can reuse the existing insert position since if we get here,
1063 // we won't have made any changes which would invalidate it.
1064 SCEV *S = new (SCEVAllocator)
1065 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1066 UniqueSCEVs.InsertNode(S, IP);
1067 registerUser(S, Op);
1068 return S;
1069 }
1070
1071 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1072 "non-SCEVUnknown's.");
1073
1074 // Otherwise, we've got some expression that is more complex than just a
1075 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1076 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1077 // only, and the expressions must otherwise be integer-typed.
1078 // So sink the cast down to the SCEVUnknown's.
1079
1080 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1081 /// which computes a pointer-typed value, and rewrites the whole expression
1082 /// tree so that *all* the computations are done on integers, and the only
1083 /// pointer-typed operands in the expression are SCEVUnknown.
1084 class SCEVPtrToIntSinkingRewriter
1085 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1087
1088 public:
1089 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1090
1091 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1092 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1093 return Rewriter.visit(Scev);
1094 }
1095
1096 const SCEV *visit(const SCEV *S) {
1097 Type *STy = S->getType();
1098 // If the expression is not pointer-typed, just keep it as-is.
1099 if (!STy->isPointerTy())
1100 return S;
1101 // Else, recursively sink the cast down into it.
1102 return Base::visit(S);
1103 }
1104
// Rebuild the add only if some operand changed, keeping the original
// no-wrap flags.
1105 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1107 bool Changed = false;
1108 for (const auto *Op : Expr->operands()) {
1109 Operands.push_back(visit(Op));
1110 Changed |= Op != Operands.back();
1111 }
1112 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1113 }
1114
// Same as visitAddExpr, for multiplications.
1115 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1117 bool Changed = false;
1118 for (const auto *Op : Expr->operands()) {
1119 Operands.push_back(visit(Op));
1120 Changed |= Op != Operands.back();
1121 }
1122 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1123 }
1124
// Leaves are cast by calling back into getLosslessPtrToIntExpr with
// Depth=1, which is the only permitted self-recursion (see the asserts above).
1125 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1126 assert(Expr->getType()->isPointerTy() &&
1127 "Should only reach pointer-typed SCEVUnknown's.");
1128 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1129 }
1130 };
1131
1132 // And actually perform the cast sinking.
1133 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1134 assert(IntOp->getType()->isIntegerTy() &&
1135 "We must have succeeded in sinking the cast, "
1136 "and ending up with an integer-typed expression!");
1137 return IntOp;
1138}
1139
// Cast Op to the integer type Ty:
//  1. Obtain the lossless ptrtoint form of Op (pointer leaves become
//     SCEVPtrToIntExpr).
//  2. Truncate or zero-extend that result to Ty.
// SCEVCouldNotCompute from step 1 is returned as-is.
1141 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1142
1143 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1144 if (isa<SCEVCouldNotCompute>(IntOp))
1145 return IntOp;
1146
1147 return getTruncateOrZeroExtend(IntOp, Ty);
1148}
1149
1151 unsigned Depth) {
// Build (or fold) the truncation of the non-pointer Op to the strictly
// narrower type Ty.
// Folds, in order:
//  - constants and nested casts;
//  - add/mul operands, when at most one new truncate results;
//  - addrec operands;
//  - operands whose known trailing zeros cover Ty.
// Otherwise creates an explicit, uniqued SCEVTruncateExpr.
// Once Depth exceeds MaxCastDepth, folding stops early.
1152 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1153 "This is not a truncating conversion!");
1154 assert(isSCEVable(Ty) &&
1155 "This is not a conversion to a SCEVable type!");
1156 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1157 Ty = getEffectiveSCEVType(Ty);
1158
1160 ID.AddInteger(scTruncate);
1161 ID.AddPointer(Op);
1162 ID.AddPointer(Ty);
1163 void *IP = nullptr;
1164 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1165
1166 // Fold if the operand is constant.
1167 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1168 return getConstant(
1169 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1170
1171 // trunc(trunc(x)) --> trunc(x)
1172 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1173 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1174
1175 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1176 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1177 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1178
1179 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1180 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1181 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1182
// Recursion budget exhausted: skip the remaining (recursive) folds and emit
// an explicit truncate node.
1183 if (Depth > MaxCastDepth) {
1184 SCEV *S =
1185 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1186 UniqueSCEVs.InsertNode(S, IP);
1187 registerUser(S, Op);
1188 return S;
1189 }
1190
1191 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1192 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1193 // if after transforming we have at most one truncate, not counting truncates
1194 // that replace other casts.
1195 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1196 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1198 unsigned numTruncs = 0;
// The loop stops as soon as a second new truncate appears, since the fold
// is abandoned in that case anyway.
1199 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1200 ++i) {
1201 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1202 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1203 isa<SCEVTruncateExpr>(S))
1204 numTruncs++;
1205 Operands.push_back(S);
1206 }
1207 if (numTruncs < 2) {
1208 if (isa<SCEVAddExpr>(Op))
1209 return getAddExpr(Operands);
1210 if (isa<SCEVMulExpr>(Op))
1211 return getMulExpr(Operands);
1212 llvm_unreachable("Unexpected SCEV type for Op.");
1213 }
1214 // Although we checked in the beginning that ID is not in the cache, it is
1215 // possible that during recursion and different modification ID was inserted
1216 // into the cache. So if we find it, just return it.
// This lookup also recomputes IP after the recursive calls above, so the
// node creation at the end of this function uses a fresh insert position.
1217 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1218 return S;
1219 }
1220
1221 // If the input value is a chrec scev, truncate the chrec's operands.
// The wide recurrence's no-wrap flags are not carried over (FlagAnyWrap).
1222 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1224 for (const SCEV *Op : AddRec->operands())
1225 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1226 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1227 }
1228
1229 // Return zero if truncating to known zeros.
1230 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1231 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1232 return getZero(Ty);
1233
1234 // The cast wasn't folded; create an explicit cast node. We can reuse
1235 // the existing insert position since if we get here, we won't have
1236 // made any changes which would invalidate it.
1237 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1238 Op, Ty);
1239 UniqueSCEVs.InsertNode(S, IP);
1240 registerUser(S, Op);
1241 return S;
1242}
1243
1244// Get the limit of a recurrence such that incrementing by Step cannot cause
1245// signed overflow as long as the value of the recurrence within the
1246// loop does not exceed this limit before incrementing.
//
// On success, *Pred receives the comparison that the recurrence value must
// satisfy against the returned limit:
//  - ICMP_SLT for a known-positive Step;
//  - ICMP_SGT for a known-negative Step.
// Returns nullptr, leaving *Pred untouched, when the sign of Step is unknown.
1247static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1248 ICmpInst::Predicate *Pred,
1249 ScalarEvolution *SE) {
1250 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1251 if (SE->isKnownPositive(Step)) {
1252 *Pred = ICmpInst::ICMP_SLT;
// The limit is computed from the largest possible signed value of Step.
1254 SE->getSignedRangeMax(Step));
1255 }
1256 if (SE->isKnownNegative(Step)) {
1257 *Pred = ICmpInst::ICMP_SGT;
// The limit is computed from the smallest possible signed value of Step.
1259 SE->getSignedRangeMin(Step));
1260 }
1261 return nullptr;
1262}
1263
1264// Get the limit of a recurrence such that incrementing by Step cannot cause
1265// unsigned overflow as long as the value of the recurrence within the loop does
1266// not exceed this limit before incrementing.
//
// *Pred is always set to ICMP_ULT. The limit is computed from the largest
// possible unsigned value of Step (SE->getUnsignedRangeMax(Step)).
1268 ICmpInst::Predicate *Pred,
1269 ScalarEvolution *SE) {
1270 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1271 *Pred = ICmpInst::ICMP_ULT;
1272
1274 SE->getUnsignedRangeMax(Step));
1275}
1276
1277namespace {
1278
// Shared typedef: a pointer to a ScalarEvolution member function that builds
// an extend expression from (operand, target type, recursion depth).
1279struct ExtendOpTraitsBase {
1280 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1281 unsigned);
1282};
1283
1284// Used to make code generic over signed and unsigned overflow.
1285template <typename ExtendOp> struct ExtendOpTraits {
1286 // Members present:
1287 //
1288 // static const SCEV::NoWrapFlags WrapType;
1289 //
1290 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1291 //
1292 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1293 // ICmpInst::Predicate *Pred,
1294 // ScalarEvolution *SE);
1295};
1296
// Sign-extension traits: the relevant wrap flag is NSW, and overflow limits
// come from getSignedOverflowLimitForStep.
1297template <>
1298struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1299 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1300
1301 static const GetExtendExprTy GetExtendExpr;
1302
1303 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1304 ICmpInst::Predicate *Pred,
1305 ScalarEvolution *SE) {
1306 return getSignedOverflowLimitForStep(Step, Pred, SE);
1307 }
1308};
1309
// Out-of-line definition of the sign-extend traits' GetExtendExpr member.
1310const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1312
// Zero-extension traits: the relevant wrap flag is NUW, and overflow limits
// come from getUnsignedOverflowLimitForStep.
1313template <>
1314struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1315 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1316
1317 static const GetExtendExprTy GetExtendExpr;
1318
1319 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1320 ICmpInst::Predicate *Pred,
1321 ScalarEvolution *SE) {
1322 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1323 }
1324};
1325
// Out-of-line definition of the zero-extend traits' GetExtendExpr member.
1326const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1328
1329} // end anonymous namespace
1330
1331// The recurrence AR has been shown to have no signed/unsigned wrap or something
1332// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1333// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1334// allows normalizing the start of a sign/zero extended AddRec, when that start
1335// is (PreStart + Step), as: ext(PreStart + Step) => ext(Step) + ext(PreStart),
1336// where ext is sext/zext. As a result, the expression
1337// "Step + sext/zext(PreIncAR)" is congruent with "sext/zext(PostIncAR)".
//
// Returns PreStart when PreStart + Step can be shown not to wrap by one of the
// three checks below, and nullptr otherwise. PreStart is Start with one
// occurrence of Step removed from its operand list. Only starts that are add
// expressions containing Step as an operand are handled.
1338template <typename ExtendOpTy>
1339static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1340 ScalarEvolution *SE, unsigned Depth) {
1341 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1342 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1343
1344 const Loop *L = AR->getLoop();
1345 const SCEV *Start = AR->getStart();
1346 const SCEV *Step = AR->getStepRecurrence(*SE);
1347
1348 // Check for a simple looking step prior to loop entry.
1349 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1350 if (!SA)
1351 return nullptr;
1352
1353 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1354 // subtraction is expensive. For this purpose, perform a quick and dirty
1355 // difference, by checking for Step in the operand list. Note, that
1356 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1358 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1359 if (*It == Step) {
1360 DiffOps.erase(It);
1361 break;
1362 }
1363
// Nothing was removed: Step is not an operand of Start.
1364 if (DiffOps.size() == SA->getNumOperands())
1365 return nullptr;
1366
1367 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1368 // `Step`:
1369
1370 // 1. NSW/NUW flags on the step increment.
1371 auto PreStartFlags =
1373 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1374 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1375 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1376
1377 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1378 // "S+X does not sign/unsign-overflow".
1379 //
1380
1381 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1382 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1383 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1384 return PreStart;
1385
1386 // 2. Direct overflow check on the step operation's expression.
// Compare ext(Start) against ext(PreStart) + ext(Step), both computed at
// twice the bit width. Equality means PreStart + Step did not wrap.
1387 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1388 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1389 const SCEV *OperandExtendedStart =
1390 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1391 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1392 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1393 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1394 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1395 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1396 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1397 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1398 }
1399 return PreStart;
1400 }
1401
1402 // 3. Loop precondition.
// A guard on loop entry showing that PreStart compares (per Pred) against
// the overflow limit for Step proves PreStart + Step does not wrap.
1404 const SCEV *OverflowLimit =
1405 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1406
1407 if (OverflowLimit &&
1408 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1409 return PreStart;
1410
1411 return nullptr;
1412}
1413
1414// Get the normalized zero or sign extended expression for this AddRec's Start.
1415template <typename ExtendOpTy>
1416static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1417 ScalarEvolution *SE,
1418 unsigned Depth) {
1419 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1420
1421 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1422 if (!PreStart)
1423 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1424
1425 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1426 Depth),
1427 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1428}
1429
1430// Try to prove away overflow by looking at "nearby" add recurrences. A
1431// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1432// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1433//
1434// Formally:
1435//
1436// {S,+,X} == {S-T,+,X} + T
1437// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1438//
1439// If ({S-T,+,X} + T) does not overflow ... (1)
1440//
1441// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1442//
1443// If {S-T,+,X} does not overflow ... (2)
1444//
1445// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1446// == {Ext(S-T)+Ext(T),+,Ext(X)}
1447//
1448// If (S-T)+T does not overflow ... (3)
1449//
1450// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1451// == {Ext(S),+,Ext(X)} == LHS
1452//
1453// Thus, if (1), (2) and (3) are true for some T, then
1454// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1455//
1456// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1457// does not overflow" restricted to the 0th iteration. Therefore we only need
1458// to check for (1) and (2).
1459//
1460// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1461// is `Delta` (defined below).
1462template <typename ExtendOpTy>
1463bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1464 const SCEV *Step,
1465 const Loop *L) {
1466 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1467
1468 // We restrict `Start` to a constant to prevent SCEV from spending too much
1469 // time here. It is correct (but more expensive) to continue with a
1470 // non-constant `Start` and do a general SCEV subtraction to compute
1471 // `PreStart` below.
1472 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1473 if (!StartC)
1474 return false;
1475
1476 APInt StartAI = StartC->getAPInt();
1477
1478 for (unsigned Delta : {-2, -1, 1, 2}) {
1479 const SCEV *PreStart = getConstant(StartAI - Delta);
1480
1482 ID.AddInteger(scAddRecExpr);
1483 ID.AddPointer(PreStart);
1484 ID.AddPointer(Step);
1485 ID.AddPointer(L);
1486 void *IP = nullptr;
1487 const auto *PreAR =
1488 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1489
1490 // Give up if we don't already have the add recurrence we need because
1491 // actually constructing an add recurrence is relatively expensive.
1492 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1493 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1495 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1496 DeltaS, &Pred, this);
1497 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1498 return true;
1499 }
1500 }
1501
1502 return false;
1503}
1504
1505// Finds an integer D for an expression (C + x + y + ...) such that the top
1506// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1507// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1508// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1509// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// ConstantTerm is expected to be operand 0 of WholeAddExpr (as at the zext
// call site), which is why only operands from index 1 onward are scanned.
// Returns 0 when the non-constant terms have no known trailing zeros.
1511 const SCEVConstant *ConstantTerm,
1512 const SCEVAddExpr *WholeAddExpr) {
1513 const APInt &C = ConstantTerm->getAPInt();
1514 const unsigned BitWidth = C.getBitWidth();
1515 // Find number of trailing zeros of (x + y + ...) w/o the C first:
// The scan stops early once TZ reaches 0, since it can only decrease.
1516 uint32_t TZ = BitWidth;
1517 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1518 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1519 if (TZ) {
1520 // Set D to be as many least significant bits of C as possible while still
1521 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
// D = C mod 2^TZ, so (C - D + x + y + ...) is a multiple of 2^TZ and adding
// D (< 2^TZ) back only fills its low TZ bits, with no carry.
1522 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1523 }
1524 return APInt(BitWidth, 0);
1525}
1526
1527// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1528// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1529// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1530// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// D is the low TZ bits of C, where TZ is the known trailing-zero count of
// Step. Returns 0 when Step has no known trailing zeros, and C itself when
// TZ covers the whole bit width.
1532 const APInt &ConstantStart,
1533 const SCEV *Step) {
1534 const unsigned BitWidth = ConstantStart.getBitWidth();
1535 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1536 if (TZ)
1537 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1538 : ConstantStart;
1539 return APInt(BitWidth, 0);
1540}
1541
1543 const ScalarEvolution::FoldID &ID, const SCEV *S,
1546 &FoldCacheUser) {
// Record in FoldCache that folding ID produced S, and record ID in S's
// reverse list in FoldCacheUser. If ID was already cached with an earlier
// result, ID is first detached from that result's reverse list, and the cache
// entry is then overwritten with S.
1547 auto I = FoldCache.insert({ID, S});
1548 if (!I.second) {
1549 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1550 // entry.
1551 auto &UserIDs = FoldCacheUser[I.first->second];
1552 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
// Swap-and-pop removal, so the order of UserIDs is not preserved. This
// relies on the asserted invariant that ID is present: if it were absent,
// pop_back would drop an unrelated entry.
1553 for (unsigned I = 0; I != UserIDs.size(); ++I)
1554 if (UserIDs[I] == ID) {
1555 std::swap(UserIDs[I], UserIDs.back());
1556 break;
1557 }
1558 UserIDs.pop_back();
1559 I.first->second = S;
1560 }
1561 FoldCacheUser[S].push_back(ID);
1562}
1563
/// Return the zero extension of the non-pointer \p Op to the strictly wider
/// type \p Ty. The work is done by getZeroExtendExprImpl; results are
/// memoized in FoldCache keyed on (scZeroExtend, Op, effective Ty).
/// A result that is itself a SCEVZeroExtendExpr is not added to FoldCache.
/// Such nodes are created through UniqueSCEVs in the implementation, so they
/// are presumably found again there.
1564const SCEV *
1566 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1567 "This is not an extending conversion!");
1568 assert(isSCEVable(Ty) &&
1569 "This is not a conversion to a SCEVable type!");
1570 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1571 Ty = getEffectiveSCEVType(Ty);
1572
1573 FoldID ID(scZeroExtend, Op, Ty);
1574 auto Iter = FoldCache.find(ID);
1575 if (Iter != FoldCache.end())
1576 return Iter->second;
1577
1578 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1579 if (!isa<SCEVZeroExtendExpr>(S))
1580 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1581 return S;
1582}
1583
1585 unsigned Depth) {
1586 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1587 "This is not an extending conversion!");
1588 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1589 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1590
1591 // Fold if the operand is constant.
1592 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1593 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1594
1595 // zext(zext(x)) --> zext(x)
1596 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1597 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1598
1599 // Before doing any expensive analysis, check to see if we've already
1600 // computed a SCEV for this Op and Ty.
1602 ID.AddInteger(scZeroExtend);
1603 ID.AddPointer(Op);
1604 ID.AddPointer(Ty);
1605 void *IP = nullptr;
1606 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1607 if (Depth > MaxCastDepth) {
1608 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1609 Op, Ty);
1610 UniqueSCEVs.InsertNode(S, IP);
1611 registerUser(S, Op);
1612 return S;
1613 }
1614
1615 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1616 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1617 // It's possible the bits taken off by the truncate were all zero bits. If
1618 // so, we should be able to simplify this further.
1619 const SCEV *X = ST->getOperand();
1621 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1622 unsigned NewBits = getTypeSizeInBits(Ty);
1623 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1624 CR.zextOrTrunc(NewBits)))
1625 return getTruncateOrZeroExtend(X, Ty, Depth);
1626 }
1627
1628 // If the input value is a chrec scev, and we can prove that the value
1629 // did not overflow the old, smaller, value, we can zero extend all of the
1630 // operands (often constants). This allows analysis of something like
1631 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1632 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1633 if (AR->isAffine()) {
1634 const SCEV *Start = AR->getStart();
1635 const SCEV *Step = AR->getStepRecurrence(*this);
1636 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1637 const Loop *L = AR->getLoop();
1638
1639 // If we have special knowledge that this addrec won't overflow,
1640 // we don't need to do any further analysis.
1641 if (AR->hasNoUnsignedWrap()) {
1642 Start =
1643 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1644 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1645 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1646 }
1647
1648 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1649 // Note that this serves two purposes: It filters out loops that are
1650 // simply not analyzable, and it covers the case where this code is
1651 // being called from within backedge-taken count analysis, such that
1652 // attempting to ask for the backedge-taken count would likely result
1653 // in infinite recursion. In the later case, the analysis code will
1654 // cope with a conservative value, and it will take care to purge
1655 // that value once it has finished.
1656 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1657 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1658 // Manually compute the final value for AR, checking for overflow.
1659
1660 // Check whether the backedge-taken count can be losslessly casted to
1661 // the addrec's type. The count is always unsigned.
1662 const SCEV *CastedMaxBECount =
1663 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1664 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1665 CastedMaxBECount, MaxBECount->getType(), Depth);
1666 if (MaxBECount == RecastedMaxBECount) {
1667 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1668 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1669 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1671 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1673 Depth + 1),
1674 WideTy, Depth + 1);
1675 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1676 const SCEV *WideMaxBECount =
1677 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1678 const SCEV *OperandExtendedAdd =
1679 getAddExpr(WideStart,
1680 getMulExpr(WideMaxBECount,
1681 getZeroExtendExpr(Step, WideTy, Depth + 1),
1684 if (ZAdd == OperandExtendedAdd) {
1685 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1686 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1687 // Return the expression with the addrec on the outside.
1688 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1689 Depth + 1);
1690 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1691 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1692 }
1693 // Similar to above, only this time treat the step value as signed.
1694 // This covers loops that count down.
1695 OperandExtendedAdd =
1696 getAddExpr(WideStart,
1697 getMulExpr(WideMaxBECount,
1698 getSignExtendExpr(Step, WideTy, Depth + 1),
1701 if (ZAdd == OperandExtendedAdd) {
1702 // Cache knowledge of AR NW, which is propagated to this AddRec.
1703 // Negative step causes unsigned wrap, but it still can't self-wrap.
1704 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1705 // Return the expression with the addrec on the outside.
1706 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1707 Depth + 1);
1708 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1709 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1710 }
1711 }
1712 }
1713
1714 // Normally, in the cases we can prove no-overflow via a
1715 // backedge guarding condition, we can also compute a backedge
1716 // taken count for the loop. The exceptions are assumptions and
1717 // guards present in the loop -- SCEV is not great at exploiting
1718 // these to compute max backedge taken counts, but can still use
1719 // these to prove lack of overflow. Use this fact to avoid
1720 // doing extra work that may not pay off.
1721 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1722 !AC.assumptions().empty()) {
1723
1724 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1725 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1726 if (AR->hasNoUnsignedWrap()) {
1727 // Same as nuw case above - duplicated here to avoid a compile time
1728 // issue. It's not clear that the order of checks does matter, but
1729 // it's one of two issue possible causes for a change which was
1730 // reverted. Be conservative for the moment.
1731 Start =
1732 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1733 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1734 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1735 }
1736
1737 // For a negative step, we can extend the operands iff doing so only
1738 // traverses values in the range zext([0,UINT_MAX]).
1739 if (isKnownNegative(Step)) {
1741 getSignedRangeMin(Step));
1744 // Cache knowledge of AR NW, which is propagated to this
1745 // AddRec. Negative step causes unsigned wrap, but it
1746 // still can't self-wrap.
1747 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1748 // Return the expression with the addrec on the outside.
1749 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1750 Depth + 1);
1751 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1752 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1753 }
1754 }
1755 }
1756
1757 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1758 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1759 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1760 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1761 const APInt &C = SC->getAPInt();
1762 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1763 if (D != 0) {
1764 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1765 const SCEV *SResidual =
1766 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1767 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1768 return getAddExpr(SZExtD, SZExtR,
1770 Depth + 1);
1771 }
1772 }
1773
1774 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1775 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1776 Start =
1777 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1778 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1779 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1780 }
1781 }
1782
1783 // zext(A % B) --> zext(A) % zext(B)
1784 {
1785 const SCEV *LHS;
1786 const SCEV *RHS;
1787 if (matchURem(Op, LHS, RHS))
1788 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1789 getZeroExtendExpr(RHS, Ty, Depth + 1));
1790 }
1791
1792 // zext(A / B) --> zext(A) / zext(B).
1793 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1794 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1795 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1796
1797 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1798 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1799 if (SA->hasNoUnsignedWrap()) {
1800 // If the addition does not unsign overflow then we can, by definition,
1801 // commute the zero extension with the addition operation.
1803 for (const auto *Op : SA->operands())
1804 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1805 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1806 }
1807
1808 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1809 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1810 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1811 //
1812 // Often address arithmetics contain expressions like
1813 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1814 // This transformation is useful while proving that such expressions are
1815 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1816 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1817 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1818 if (D != 0) {
1819 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1820 const SCEV *SResidual =
1822 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1823 return getAddExpr(SZExtD, SZExtR,
1825 Depth + 1);
1826 }
1827 }
1828 }
1829
1830 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1831 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1832 if (SM->hasNoUnsignedWrap()) {
1833 // If the multiply does not unsign overflow then we can, by definition,
1834 // commute the zero extension with the multiply operation.
1836 for (const auto *Op : SM->operands())
1837 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1838 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1839 }
1840
1841 // zext(2^K * (trunc X to iN)) to iM ->
1842 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1843 //
1844 // Proof:
1845 //
1846 // zext(2^K * (trunc X to iN)) to iM
1847 // = zext((trunc X to iN) << K) to iM
1848 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1849 // (because shl removes the top K bits)
1850 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1851 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1852 //
1853 if (SM->getNumOperands() == 2)
1854 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1855 if (MulLHS->getAPInt().isPowerOf2())
1856 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1857 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1858 MulLHS->getAPInt().logBase2();
1859 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1860 return getMulExpr(
1861 getZeroExtendExpr(MulLHS, Ty),
1863 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1864 SCEV::FlagNUW, Depth + 1);
1865 }
1866 }
1867
1868 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1869 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1870 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1871 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1873 for (auto *Operand : MinMax->operands())
1874 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1875 if (isa<SCEVUMinExpr>(MinMax))
1876 return getUMinExpr(Operands);
1877 return getUMaxExpr(Operands);
1878 }
1879
1880 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1881 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1882 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1884 for (auto *Operand : MinMax->operands())
1885 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1886 return getUMinExpr(Operands, /*Sequential*/ true);
1887 }
1888
1889 // The cast wasn't folded; create an explicit cast node.
1890 // Recompute the insert position, as it may have been invalidated.
1891 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1892 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1893 Op, Ty);
1894 UniqueSCEVs.InsertNode(S, IP);
1895 registerUser(S, Op);
1896 return S;
1897}
1898
// Sign-extend Op to Ty, memoizing folded results in FoldCache.
//
// Asserts that the conversion strictly widens, that Ty is SCEVable and that
// Op is not a pointer; Ty is then normalized with getEffectiveSCEVType. The
// folding work itself is delegated to getSignExtendExprImpl.
// NOTE(review): the line naming this function and declaring its parameters
// (Op, Ty and Depth are used below) is missing from this listing.
1899const SCEV *
1901 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1902 "This is not an extending conversion!");
1903 assert(isSCEVable(Ty) &&
1904 "This is not a conversion to a SCEVable type!");
1905 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1906 Ty = getEffectiveSCEVType(Ty);
1907
 // The cache key is (scSignExtend, Op, Ty); Depth is not part of it, so a
 // cached entry is reused regardless of the current recursion depth.
1908 FoldID ID(scSignExtend, Op, Ty);
1909 auto Iter = FoldCache.find(ID);
1910 if (Iter != FoldCache.end())
1911 return Iter->second;
1912
1913 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
 // Only results that folded into something other than an explicit
 // SCEVSignExtendExpr node are cached. Presumably this is because an
 // unfolded node is already uniqued by the implementation's UniqueSCEVs
 // lookup.
1914 if (!isa<SCEVSignExtendExpr>(S))
1915 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1916 return S;
1917}
1918
1920 unsigned Depth) {
 // Sign-extend Op to Ty. A series of algebraic folds is tried first; if none
 // applies, an explicit, uniqued SCEVSignExtendExpr node is created or reused.
 // Depth counts nested cast construction (recursive calls pass Depth + 1).
 // Once Depth exceeds MaxCastDepth, only the cheap folds below (constant,
 // sext(sext), sext(zext)) and the uniquing lookup are attempted before an
 // explicit node is built.
 //
 // NOTE(review): this listing is missing the first line of the signature and
 // several interior lines: the declarations of ID, CR, Ops and Operands, which
 // are used below without being declared, and some argument lines of
 // multi-line calls. The code does not compile as shown. Confirm against the
 // upstream source before relying on details.
1921 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1922 "This is not an extending conversion!");
1923 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1924 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1925 Ty = getEffectiveSCEVType(Ty);
1926
1927 // Fold if the operand is constant.
1928 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1929 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1930
1931 // sext(sext(x)) --> sext(x)
1932 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1933 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1934
1935 // sext(zext(x)) --> zext(x)
 // (A zero-extended value has a clear sign bit, so filling the new high bits
 // with sign bits or with zeros gives the same result.)
1936 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1937 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1938
1939 // Before doing any expensive analysis, check to see if we've already
1940 // computed a SCEV for this Op and Ty.
1942 ID.AddInteger(scSignExtend);
1943 ID.AddPointer(Op);
1944 ID.AddPointer(Ty);
1945 void *IP = nullptr;
1946 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1947 // Limit recursion depth.
1948 if (Depth > MaxCastDepth) {
1949 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1950 Op, Ty);
1951 UniqueSCEVs.InsertNode(S, IP);
1952 registerUser(S, Op);
1953 return S;
1954 }
1955
1956 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1957 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1958 // It's possible the bits taken off by the truncate were all sign bits. If
1959 // so, we should be able to simplify this further.
1960 const SCEV *X = ST->getOperand();
 // NOTE(review): CR is declared on a line missing from this listing. It
 // presumably holds X's signed range. The fold applies when truncating CR
 // to the narrow width and sign-extending it to Ty yields a range that
 // contains CR resized to Ty.
1962 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1963 unsigned NewBits = getTypeSizeInBits(Ty);
1964 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1965 CR.sextOrTrunc(NewBits)))
1966 return getTruncateOrSignExtend(X, Ty, Depth);
1967 }
1968
1969 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1970 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1971 if (SA->hasNoSignedWrap()) {
1972 // If the addition does not sign overflow then we can, by definition,
1973 // commute the sign extension with the addition operation.
1975 for (const auto *Op : SA->operands())
1976 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1977 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1978 }
1979
1980 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1981 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1982 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1983 //
1984 // For instance, this will bring two seemingly different expressions:
1985 // 1 + sext(5 + 20 * %x + 24 * %y) and
1986 // sext(6 + 20 * %x + 24 * %y)
1987 // to the same form:
1988 // 2 + sext(4 + 20 * %x + 24 * %y)
1989 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1990 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1991 if (D != 0) {
1992 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1993 const SCEV *SResidual =
1995 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1996 return getAddExpr(SSExtD, SSExtR,
1998 Depth + 1);
1999 }
2000 }
2001 }
2002 // If the input value is a chrec scev, and we can prove that the value
2003 // did not overflow the old, smaller, value, we can sign extend all of the
2004 // operands (often constants). This allows analysis of something like
2005 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2006 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
2007 if (AR->isAffine()) {
2008 const SCEV *Start = AR->getStart();
2009 const SCEV *Step = AR->getStepRecurrence(*this);
2010 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2011 const Loop *L = AR->getLoop();
2012
2013 // If we have special knowledge that this addrec won't overflow,
2014 // we don't need to do any further analysis.
2015 if (AR->hasNoSignedWrap()) {
2016 Start =
2017 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2018 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2019 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2020 }
2021
2022 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2023 // Note that this serves two purposes: It filters out loops that are
2024 // simply not analyzable, and it covers the case where this code is
2025 // being called from within backedge-taken count analysis, such that
2026 // attempting to ask for the backedge-taken count would likely result
2027 // in infinite recursion. In the latter case, the analysis code will
2028 // cope with a conservative value, and it will take care to purge
2029 // that value once it has finished.
2030 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2031 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2032 // Manually compute the final value for AR, checking for
2033 // overflow.
2034
2035 // Check whether the backedge-taken count can be losslessly casted to
2036 // the addrec's type. The count is always unsigned.
2037 const SCEV *CastedMaxBECount =
2038 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2039 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2040 CastedMaxBECount, MaxBECount->getType(), Depth);
2041 if (MaxBECount == RecastedMaxBECount) {
 // Evaluate the final value at twice the width in two ways. SAdd
 // sign-extends the narrow Start + Step * MaxBECount. OperandExtendedAdd
 // uses individually extended operands. Identical SCEVs show the narrow
 // computation did not wrap.
2042 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2043 // Check whether Start+Step*MaxBECount has no signed overflow.
2044 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2046 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2048 Depth + 1),
2049 WideTy, Depth + 1);
2050 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2051 const SCEV *WideMaxBECount =
2052 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2053 const SCEV *OperandExtendedAdd =
2054 getAddExpr(WideStart,
2055 getMulExpr(WideMaxBECount,
2056 getSignExtendExpr(Step, WideTy, Depth + 1),
2059 if (SAdd == OperandExtendedAdd) {
2060 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2061 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2062 // Return the expression with the addrec on the outside.
2063 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2064 Depth + 1);
2065 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2066 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2067 }
2068 // Similar to above, only this time treat the step value as unsigned.
2069 // This covers loops that count up with an unsigned step.
2070 OperandExtendedAdd =
2071 getAddExpr(WideStart,
2072 getMulExpr(WideMaxBECount,
2073 getZeroExtendExpr(Step, WideTy, Depth + 1),
2076 if (SAdd == OperandExtendedAdd) {
2077 // If AR wraps around then
2078 //
2079 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2080 // => SAdd != OperandExtendedAdd
2081 //
2082 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2083 // (SAdd == OperandExtendedAdd => AR is NW)
2084
2085 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2086
2087 // Return the expression with the addrec on the outside.
2088 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2089 Depth + 1);
2090 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2091 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2092 }
2093 }
2094 }
2095
 // Try to prove nsw by induction; any flags found are recorded on AR
 // itself, so later queries benefit too.
2096 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2097 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2098 if (AR->hasNoSignedWrap()) {
2099 // Same as nsw case above - duplicated here to avoid a compile time
2100 // issue. It's not clear that the order of checks does matter, but
2101 // it's one of two possible causes of an issue with a change which was
2102 // reverted. Be conservative for the moment.
2103 Start =
2104 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2105 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2106 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2107 }
2108
2109 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2110 // if D + (C - D + Step * n) could be proven to not signed wrap
2111 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2112 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2113 const APInt &C = SC->getAPInt();
2114 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2115 if (D != 0) {
2116 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2117 const SCEV *SResidual =
2118 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2119 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2120 return getAddExpr(SSExtD, SSExtR,
2122 Depth + 1);
2123 }
2124 }
2125
 // Last resort for affine addrecs: on success, record nsw on AR and
 // return the addrec with the extension pushed into start and step.
2126 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2127 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2128 Start =
2129 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2130 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2131 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2132 }
2133 }
2134
2135 // If the input value is provably positive and we could not simplify
2136 // away the sext build a zext instead.
 // NOTE(review): the condition guarding the return below (the test described
 // by the comment above) is missing from this listing. As shown, the return
 // is unconditional and everything after it is unreachable. Confirm against
 // the upstream source.
2138 return getZeroExtendExpr(Op, Ty, Depth + 1);
2139
2140 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2141 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2142 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2143 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2145 for (auto *Operand : MinMax->operands())
2146 Operands.push_back(getSignExtendExpr(Operand, Ty));
2147 if (isa<SCEVSMinExpr>(MinMax))
2148 return getSMinExpr(Operands);
2149 return getSMaxExpr(Operands);
2150 }
2151
2152 // The cast wasn't folded; create an explicit cast node.
2153 // Recompute the insert position, as it may have been invalidated.
2154 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2155 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2156 Op, Ty);
2157 UniqueSCEVs.InsertNode(S, IP);
2158 registerUser(S, { Op });
2159 return S;
2160}
2161
2163 Type *Ty) {
 // Build the cast of kind Kind from Op to Ty by dispatching to the matching
 // constructor. No explicit Depth argument is passed. Kind must be one of
 // scTruncate, scZeroExtend, scSignExtend or scPtrToInt; anything else is
 // unreachable.
 // NOTE(review): the first line of this signature (which declares Kind and
 // Op) is missing from this listing.
2164 switch (Kind) {
2165 case scTruncate:
2166 return getTruncateExpr(Op, Ty);
2167 case scZeroExtend:
2168 return getZeroExtendExpr(Op, Ty);
2169 case scSignExtend:
2170 return getSignExtendExpr(Op, Ty);
2171 case scPtrToInt:
2172 return getPtrToIntExpr(Op, Ty);
2173 default:
2174 llvm_unreachable("Not a SCEV cast expression!");
2175 }
2176}
2177
2178/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2179/// unspecified bits out to the given type.
///
/// Because the new high bits are unspecified, any extension that folds is
/// acceptable. The strategy, in order:
///  1. sign-extend negative constants;
///  2. look through a truncate;
///  3. use a zext or sext if either folds to something other than an
///     explicit extension node;
///  4. push the extension into the operands of an addrec;
///  5. otherwise prefer the sext for smax expressions and the zext for
///     everything else.
///
/// NOTE(review): the line declaring this function and its Op parameter, and
/// the declaration of the Ops vector used in the addrec case, are missing
/// from this listing.
2181 Type *Ty) {
2182 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2183 "This is not an extending conversion!");
2184 assert(isSCEVable(Ty) &&
2185 "This is not a conversion to a SCEVable type!");
2186 Ty = getEffectiveSCEVType(Ty);
2187
2188 // Sign-extend negative constants.
2189 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2190 if (SC->getAPInt().isNegative())
2191 return getSignExtendExpr(Op, Ty);
2192
2193 // Peel off a truncate cast.
 // The bits the truncate discarded are as good as any "unspecified" bits.
 // Reuse the original wider operand: any-extend it if it is still narrower
 // than Ty, otherwise truncate it (or no-op) to Ty.
2194 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2195 const SCEV *NewOp = T->getOperand();
2196 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2197 return getAnyExtendExpr(NewOp, Ty);
2198 return getTruncateOrNoop(NewOp, Ty);
2199 }
2200
2201 // Next try a zext cast. If the cast is folded, use it.
2202 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2203 if (!isa<SCEVZeroExtendExpr>(ZExt))
2204 return ZExt;
2205
2206 // Next try a sext cast. If the cast is folded, use it.
2207 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2208 if (!isa<SCEVSignExtendExpr>(SExt))
2209 return SExt;
2210
2211 // Force the cast to be folded into the operands of an addrec.
 // Each addrec operand is any-extended independently, and the rebuilt addrec
 // carries only FlagNW. The loop variable below shadows the parameter Op.
2212 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2214 for (const SCEV *Op : AR->operands())
2215 Ops.push_back(getAnyExtendExpr(Op, Ty));
2216 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2217 }
2218
2219 // If the expression is obviously signed, use the sext cast value.
2220 if (isa<SCEVSMaxExpr>(Op))
2221 return SExt;
2222
2223 // Absent any other information, use the zext cast value.
2224 return ZExt;
2225}
2226
2227/// Process the given Ops list, which is a list of operands to be added under
2228/// the given scale, and update the given map. This is a helper function for
2229/// getAddRecExpr. As an example of what it does, given a sequence of operands
2230/// that would form an add expression like this:
2231///
2232/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2233///
2234/// where A and B are constants, update the map with these values:
2235///
2236/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2237///
2238/// and add 13 + A*B*29 to AccumulatedConstant.
2239/// This will allow getAddRecExpr to produce this:
2240///
2241/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2242///
2243/// This form often exposes folding opportunities that are hidden in
2244/// the original operand list.
2245///
2246/// Return true iff it appears that any interesting folding opportunities
2247/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2248/// the common case where no interesting opportunities are present, and
2249/// is also used as a check to avoid infinite recursion.
///
/// M accumulates, for every non-constant term, the total scale applied to
/// it. NewOps receives each term the first time it is inserted into M, so it
/// lists M's keys in first-seen order. Constant operands, which must come
/// first in Ops, are multiplied by Scale and summed into AccumulatedConstant.
/// A constant-times-add operand is flattened recursively with the combined
/// scale.
///
/// NOTE(review): the references to getAddRecExpr above look like they should
/// name the add-expression builder, since the operands described form a
/// plain add; confirm against the caller. The line naming this function and
/// declaring M and NewOps is missing from this listing.
2250static bool
2253 APInt &AccumulatedConstant,
2254 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2255 ScalarEvolution &SE) {
2256 bool Interesting = false;
2257
2258 // Iterate over the add operands. They are sorted, with constants first.
 // This loop has no bounds check. It relies on Ops containing at least one
 // non-constant operand; otherwise it would index past the end of Ops.
2259 unsigned i = 0;
2260 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2261 ++i;
2262 // Pull a buried constant out to the outside.
2263 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2264 Interesting = true;
2265 AccumulatedConstant += Scale * C->getAPInt();
2266 }
2267
2268 // Next comes everything else. We're especially interested in multiplies
2269 // here, but they're in the middle, so just visit the rest with one loop.
2270 for (; i != Ops.size(); ++i) {
2271 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2272 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2273 APInt NewScale =
2274 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2275 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2276 // A multiplication of a constant with another add; recurse.
2277 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2278 Interesting |=
2279 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2280 Add->operands(), NewScale, SE);
2281 } else {
2282 // A multiplication of a constant with some other value. Update
2283 // the map.
 // The key is the product of the non-constant factors; the leading
 // constant has already been folded into NewScale.
2284 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2285 const SCEV *Key = SE.getMulExpr(MulOps);
2286 auto Pair = M.insert({Key, NewScale});
2287 if (Pair.second) {
2288 NewOps.push_back(Pair.first->first);
2289 } else {
2290 Pair.first->second += NewScale;
2291 // The map already had an entry for this value, which may indicate
2292 // a folding opportunity.
2293 Interesting = true;
2294 }
2295 }
2296 } else {
2297 // An ordinary operand. Update the map.
2298 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2299 M.insert({Ops[i], Scale});
2300 if (Pair.second) {
2301 NewOps.push_back(Pair.first->first);
2302 } else {
2303 Pair.first->second += Scale;
2304 // The map already had an entry for this value, which may indicate
2305 // a folding opportunity.
2306 Interesting = true;
2307 }
2308 }
2309 }
2310
2311 return Interesting;
2312}
2313
2315 const SCEV *LHS, const SCEV *RHS,
2316 const Instruction *CtxI) {
 // Return true if `LHS BinOp RHS` (BinOp is add, sub or mul) is proven not to
 // overflow. Operands are treated as signed when Signed is set, otherwise as
 // unsigned. A false result means "not proven", not "overflows".
 //
 // Two strategies are tried:
 //  1. A context-free check comparing the extended result with the operation
 //     on extended operands.
 //  2. For add/sub with a constant RHS and a non-null CtxI only, a check that
 //     LHS is far enough from MIN/MAX at CtxI.
 //
 // NOTE(review): the first signature line (which declares BinOp and Signed),
 // the assignments of Operation in each switch case, the initializer of
 // Extension and the declaration of Pred are missing from this listing.
2317 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2318 SCEV::NoWrapFlags, unsigned);
2319 switch (BinOp) {
2320 default:
2321 llvm_unreachable("Unsupported binary op");
2322 case Instruction::Add:
2324 break;
2325 case Instruction::Sub:
2327 break;
2328 case Instruction::Mul:
2330 break;
2331 }
2332
2333 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2336
2337 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
 // Twice the bit width holds the exact result of any N-bit add, sub or mul.
 // If extending the narrow result gives the same SCEV as operating on the
 // extended operands, the narrow operation cannot have wrapped.
2338 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2339 auto *WideTy =
2340 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2341
2342 const SCEV *A = (this->*Extension)(
2343 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2344 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2345 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2346 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2347 if (A == B)
2348 return true;
2349 // Can we use context to prove the fact we need?
2350 if (!CtxI)
2351 return false;
2352 // TODO: Support mul.
2353 if (BinOp == Instruction::Mul)
2354 return false;
2355 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2356 // TODO: Lift this limitation.
2357 if (!RHSC)
2358 return false;
2359 APInt C = RHSC->getAPInt();
2360 unsigned NumBits = C.getBitWidth();
2361 bool IsSub = (BinOp == Instruction::Sub);
2362 bool IsNegativeConst = (Signed && C.isNegative());
2363 // Compute the direction and magnitude by which we need to check overflow.
 // Subtracting a non-negative constant, or adding a negative one, moves
 // toward MIN. The other two combinations move toward MAX.
2364 bool OverflowDown = IsSub ^ IsNegativeConst;
2365 APInt Magnitude = C;
2366 if (IsNegativeConst) {
2367 if (C == APInt::getSignedMinValue(NumBits))
2368 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2369 // want to deal with that.
2370 return false;
2371 Magnitude = -C;
2372 }
2373
 // NOTE(review): Pred is declared on a line missing from this listing. The
 // comments below imply it is a <= comparison, signed or unsigned according
 // to Signed.
2375 if (OverflowDown) {
2376 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2377 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2378 : APInt::getMinValue(NumBits);
2379 APInt Limit = Min + Magnitude;
2380 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2381 } else {
2382 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2383 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2384 : APInt::getMaxValue(NumBits);
2385 APInt Limit = Max - Magnitude;
2386 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2387 }
2388}
2389
// Try to prove nuw and/or nsw for the add, sub or mul instruction OBO beyond
// the flags it already carries.
//
// Returns std::nullopt when OBO already has both flags, when it is any other
// opcode, or when nothing new was deduced. Otherwise it returns Flags after
// recording what was proven. Flags is declared on a missing line and is
// presumably seeded from OBO's existing flags.
//
// When UseContextForNoWrapFlagInference is set, OBO itself (if it is an
// Instruction) is passed as the context for the proof; otherwise no context
// is used.
//
// NOTE(review): several lines are missing from this listing: the line naming
// this function, the declaration of Flags, the statements guarded by the two
// hasNoUnsignedWrap()/hasNoSignedWrap() checks below, and the call and
// flag-update lines inside the two deduction ifs. As shown, the first `if`
// would guard the second, which is presumably an extraction artifact.
2390std::optional<SCEV::NoWrapFlags>
2392 const OverflowingBinaryOperator *OBO) {
2393 // It cannot be done any better.
2394 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2395 return std::nullopt;
2396
2398
2399 if (OBO->hasNoUnsignedWrap())
2401 if (OBO->hasNoSignedWrap())
2403
 // Tracks whether any flag beyond OBO's own was proven; if none was, the
 // function reports std::nullopt instead of returning an unchanged flag set.
2404 bool Deduced = false;
2405
2406 if (OBO->getOpcode() != Instruction::Add &&
2407 OBO->getOpcode() != Instruction::Sub &&
2408 OBO->getOpcode() != Instruction::Mul)
2409 return std::nullopt;
2410
2411 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2412 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2413
2414 const Instruction *CtxI =
2415 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2416 if (!OBO->hasNoUnsignedWrap() &&
2418 /* Signed */ false, LHS, RHS, CtxI)) {
2420 Deduced = true;
2421 }
2422
2423 if (!OBO->hasNoSignedWrap() &&
2425 /* Signed */ true, LHS, RHS, CtxI)) {
2427 Deduced = true;
2428 }
2429
2430 if (Deduced)
2431 return Flags;
2432 return std::nullopt;
2433}
2434
2435// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2436// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2437// can't-overflow flags for the operation if possible.
//
// Rules applied (as described by the comments below):
//  - nsw with all operands known non-negative also implies nuw;
//  - for a two-operand add/mul whose first operand is a constant C, nsw/nuw
//    are added when the other operand's signed/unsigned range lies entirely
//    in the region where `A <opcode> C` cannot wrap;
//  - <0,+,nonnegative><nw> is also nuw;
//  - (udiv X, Y) * Y and Y * (udiv X, Y) are nuw.
// Returns Flags, possibly with additional flags set.
//
// NOTE(review): several lines are missing from this listing: the line naming
// this function and declaring SE and Type, the rest of the CanAnalyze
// expression, the lines that build NSWRegion/NUWRegion, the statements that
// set the inferred flags, and the opening conditions of the last two rules.
2438static SCEV::NoWrapFlags
2440 const ArrayRef<const SCEV *> Ops,
2441 SCEV::NoWrapFlags Flags) {
 // NOTE(review): no placeholder (_1, _2, ...) appears in the visible lines of
 // this function, so this using-directive looks unused.
2442 using namespace std::placeholders;
2443
2444 using OBO = OverflowingBinaryOperator;
2445
2446 bool CanAnalyze =
2448 (void)CanAnalyze;
2449 assert(CanAnalyze && "don't call from other places!");
2450
2451 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2452 SCEV::NoWrapFlags SignOrUnsignWrap =
2453 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2456 auto IsKnownNonNegative = [&](const SCEV *S) {
2457 return SE->isKnownNonNegative(S);
2458 };
2459
 // SignOrUnsignWrap == FlagNSW means nsw is set and nuw is not.
2460 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2461 Flags =
2462 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2463
 // Re-read the nsw/nuw bits, since the rule above may have added nuw.
2464 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2465
2466 if (SignOrUnsignWrap != SignOrUnsignMask &&
2467 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2468 isa<SCEVConstant>(Ops[0])) {
2469
2470 auto Opcode = [&] {
2471 switch (Type) {
2472 case scAddExpr:
2473 return Instruction::Add;
2474 case scMulExpr:
2475 return Instruction::Mul;
2476 default:
2477 llvm_unreachable("Unexpected SCEV op.");
2478 }
2479 }();
2480
2481 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2482
2483 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2484 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2486 Opcode, C, OBO::NoSignedWrap);
2487 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2489 }
2490
2491 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2492 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2494 Opcode, C, OBO::NoUnsignedWrap);
2495 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2497 }
2498 }
2499
2500 // <0,+,nonnegative><nw> is also nuw
2501 // TODO: Add corresponding nsw case
2503 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2504 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2506
2507 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2509 Ops.size() == 2) {
2510 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2511 if (UDiv->getOperand(1) == Ops[1])
2513 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2514 if (UDiv->getOperand(1) == Ops[0])
2516 }
2517
2518 return Flags;
2519}
2520
// S is available on entry to loop L when it is invariant in L and properly
// dominates L's header.
// NOTE(review): the signature line declaring S and L is missing from this
// listing.
2522 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2523}
2524
2525/// Get a canonical add expression, or something simpler if possible.
2527 SCEV::NoWrapFlags OrigFlags,
2528 unsigned Depth) {
2529 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2530 "only nuw or nsw allowed");
2531 assert(!Ops.empty() && "Cannot get empty add!");
2532 if (Ops.size() == 1) return Ops[0];
2533#ifndef NDEBUG
2534 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2535 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2536 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2537 "SCEVAddExpr operand types don't match!");
2538 unsigned NumPtrs = count_if(
2539 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2540 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2541#endif
2542
2543 const SCEV *Folded = constantFoldAndGroupOps(
2544 *this, LI, DT, Ops,
2545 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2546 [](const APInt &C) { return C.isZero(); }, // identity
2547 [](const APInt &C) { return false; }); // absorber
2548 if (Folded)
2549 return Folded;
2550
2551 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2552
2553 // Delay expensive flag strengthening until necessary.
2554 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2555 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2556 };
2557
2558 // Limit recursion calls depth.
2560 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2561
2562 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2563 // Don't strengthen flags if we have no new information.
2564 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2565 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2566 Add->setNoWrapFlags(ComputeFlags(Ops));
2567 return S;
2568 }
2569
2570 // Okay, check to see if the same value occurs in the operand list more than
2571 // once. If so, merge them together into an multiply expression. Since we
2572 // sorted the list, these values are required to be adjacent.
2573 Type *Ty = Ops[0]->getType();
2574 bool FoundMatch = false;
2575 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2576 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2577 // Scan ahead to count how many equal operands there are.
2578 unsigned Count = 2;
2579 while (i+Count != e && Ops[i+Count] == Ops[i])
2580 ++Count;
2581 // Merge the values into a multiply.
2582 const SCEV *Scale = getConstant(Ty, Count);
2583 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2584 if (Ops.size() == Count)
2585 return Mul;
2586 Ops[i] = Mul;
2587 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2588 --i; e -= Count - 1;
2589 FoundMatch = true;
2590 }
2591 if (FoundMatch)
2592 return getAddExpr(Ops, OrigFlags, Depth + 1);
2593
2594 // Check for truncates. If all the operands are truncated from the same
2595 // type, see if factoring out the truncate would permit the result to be
2596 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2597 // if the contents of the resulting outer trunc fold to something simple.
2598 auto FindTruncSrcType = [&]() -> Type * {
2599 // We're ultimately looking to fold an addrec of truncs and muls of only
2600 // constants and truncs, so if we find any other types of SCEV
2601 // as operands of the addrec then we bail and return nullptr here.
2602 // Otherwise, we return the type of the operand of a trunc that we find.
2603 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2604 return T->getOperand()->getType();
2605 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2606 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2607 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2608 return T->getOperand()->getType();
2609 }
2610 return nullptr;
2611 };
2612 if (auto *SrcType = FindTruncSrcType()) {
2614 bool Ok = true;
2615 // Check all the operands to see if they can be represented in the
2616 // source type of the truncate.
2617 for (const SCEV *Op : Ops) {
2618 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2619 if (T->getOperand()->getType() != SrcType) {
2620 Ok = false;
2621 break;
2622 }
2623 LargeOps.push_back(T->getOperand());
2624 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2625 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2626 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2627 SmallVector<const SCEV *, 8> LargeMulOps;
2628 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2629 if (const SCEVTruncateExpr *T =
2630 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2631 if (T->getOperand()->getType() != SrcType) {
2632 Ok = false;
2633 break;
2634 }
2635 LargeMulOps.push_back(T->getOperand());
2636 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2637 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2638 } else {
2639 Ok = false;
2640 break;
2641 }
2642 }
2643 if (Ok)
2644 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2645 } else {
2646 Ok = false;
2647 break;
2648 }
2649 }
2650 if (Ok) {
2651 // Evaluate the expression in the larger type.
2652 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2653 // If it folds to something simple, use it. Otherwise, don't.
2654 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2655 return getTruncateExpr(Fold, Ty);
2656 }
2657 }
2658
2659 if (Ops.size() == 2) {
2660 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2661 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2662 // C1).
2663 const SCEV *A = Ops[0];
2664 const SCEV *B = Ops[1];
2665 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2666 auto *C = dyn_cast<SCEVConstant>(A);
2667 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2668 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2669 auto C2 = C->getAPInt();
2670 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2671
2672 APInt ConstAdd = C1 + C2;
2673 auto AddFlags = AddExpr->getNoWrapFlags();
2674 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 ConstAdd.ule(C1)) {
2677 PreservedFlags =
2679 }
2680
2681 // Adding a constant with the same sign and small magnitude is NSW, if the
2682 // original AddExpr was NSW.
2684 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2685 ConstAdd.abs().ule(C1.abs())) {
2686 PreservedFlags =
2688 }
2689
2690 if (PreservedFlags != SCEV::FlagAnyWrap) {
2691 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2692 NewOps[0] = getConstant(ConstAdd);
2693 return getAddExpr(NewOps, PreservedFlags);
2694 }
2695 }
2696 }
2697
2698 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2699 if (Ops.size() == 2) {
2700 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2701 if (Mul && Mul->getNumOperands() == 2 &&
2702 Mul->getOperand(0)->isAllOnesValue()) {
2703 const SCEV *X;
2704 const SCEV *Y;
2705 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2706 return getMulExpr(Y, getUDivExpr(X, Y));
2707 }
2708 }
2709 }
2710
2711 // Skip past any other cast SCEVs.
2712 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2713 ++Idx;
2714
2715 // If there are add operands they would be next.
2716 if (Idx < Ops.size()) {
2717 bool DeletedAdd = false;
2718 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2719 // common NUW flag for expression after inlining. Other flags cannot be
2720 // preserved, because they may depend on the original order of operations.
2721 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2722 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2723 if (Ops.size() > AddOpsInlineThreshold ||
2724 Add->getNumOperands() > AddOpsInlineThreshold)
2725 break;
2726 // If we have an add, expand the add operands onto the end of the operands
2727 // list.
2728 Ops.erase(Ops.begin()+Idx);
2729 append_range(Ops, Add->operands());
2730 DeletedAdd = true;
2731 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2732 }
2733
2734 // If we deleted at least one add, we added operands to the end of the list,
2735 // and they are not necessarily sorted. Recurse to resort and resimplify
2736 // any operands we just acquired.
2737 if (DeletedAdd)
2738 return getAddExpr(Ops, CommonFlags, Depth + 1);
2739 }
2740
2741 // Skip over the add expression until we get to a multiply.
2742 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2743 ++Idx;
2744
2745 // Check to see if there are any folding opportunities present with
2746 // operands multiplied by constant values.
2747 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 APInt AccumulatedConstant(BitWidth, 0);
2752 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2753 Ops, APInt(BitWidth, 1), *this)) {
2754 struct APIntCompare {
2755 bool operator()(const APInt &LHS, const APInt &RHS) const {
2756 return LHS.ult(RHS);
2757 }
2758 };
2759
2760 // Some interesting folding opportunity is present, so its worthwhile to
2761 // re-generate the operands list. Group the operands by constant scale,
2762 // to avoid multiplying by the same constant scale multiple times.
2763 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2764 for (const SCEV *NewOp : NewOps)
2765 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2766 // Re-generate the operands list.
2767 Ops.clear();
2768 if (AccumulatedConstant != 0)
2769 Ops.push_back(getConstant(AccumulatedConstant));
2770 for (auto &MulOp : MulOpLists) {
2771 if (MulOp.first == 1) {
2772 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2773 } else if (MulOp.first != 0) {
2775 getConstant(MulOp.first),
2776 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2777 SCEV::FlagAnyWrap, Depth + 1));
2778 }
2779 }
2780 if (Ops.empty())
2781 return getZero(Ty);
2782 if (Ops.size() == 1)
2783 return Ops[0];
2784 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2785 }
2786 }
2787
2788 // If we are adding something to a multiply expression, make sure the
2789 // something is not already an operand of the multiply. If so, merge it into
2790 // the multiply.
2791 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2792 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2793 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2794 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2795 if (isa<SCEVConstant>(MulOpSCEV))
2796 continue;
2797 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2798 if (MulOpSCEV == Ops[AddOp]) {
2799 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2800 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2801 if (Mul->getNumOperands() != 2) {
2802 // If the multiply has more than two operands, we must get the
2803 // Y*Z term.
2805 Mul->operands().take_front(MulOp));
2806 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2807 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2808 }
2809 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2810 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2811 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 if (Ops.size() == 2) return OuterMul;
2814 if (AddOp < Idx) {
2815 Ops.erase(Ops.begin()+AddOp);
2816 Ops.erase(Ops.begin()+Idx-1);
2817 } else {
2818 Ops.erase(Ops.begin()+Idx);
2819 Ops.erase(Ops.begin()+AddOp-1);
2820 }
2821 Ops.push_back(OuterMul);
2822 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2823 }
2824
2825 // Check this multiply against other multiplies being added together.
2826 for (unsigned OtherMulIdx = Idx+1;
2827 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2828 ++OtherMulIdx) {
2829 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2830 // If MulOp occurs in OtherMul, we can fold the two multiplies
2831 // together.
2832 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2833 OMulOp != e; ++OMulOp)
2834 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2835 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2836 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2837 if (Mul->getNumOperands() != 2) {
2839 Mul->operands().take_front(MulOp));
2840 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2841 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2842 }
2843 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2844 if (OtherMul->getNumOperands() != 2) {
2846 OtherMul->operands().take_front(OMulOp));
2847 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2848 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2849 }
2850 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2851 const SCEV *InnerMulSum =
2852 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2853 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 if (Ops.size() == 2) return OuterMul;
2856 Ops.erase(Ops.begin()+Idx);
2857 Ops.erase(Ops.begin()+OtherMulIdx-1);
2858 Ops.push_back(OuterMul);
2859 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2860 }
2861 }
2862 }
2863 }
2864
2865 // If there are any add recurrences in the operands list, see if any other
2866 // added values are loop invariant. If so, we can fold them into the
2867 // recurrence.
2868 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2869 ++Idx;
2870
2871 // Scan over all recurrences, trying to fold loop invariants into them.
2872 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2873 // Scan all of the other operands to this add and add them to the vector if
2874 // they are loop invariant w.r.t. the recurrence.
2876 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2877 const Loop *AddRecLoop = AddRec->getLoop();
2878 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2879 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2880 LIOps.push_back(Ops[i]);
2881 Ops.erase(Ops.begin()+i);
2882 --i; --e;
2883 }
2884
2885 // If we found some loop invariants, fold them into the recurrence.
2886 if (!LIOps.empty()) {
2887 // Compute nowrap flags for the addition of the loop-invariant ops and
2888 // the addrec. Temporarily push it as an operand for that purpose. These
2889 // flags are valid in the scope of the addrec only.
2890 LIOps.push_back(AddRec);
2891 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2892 LIOps.pop_back();
2893
2894 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2895 LIOps.push_back(AddRec->getStart());
2896
2897 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2898
2899 // It is not in general safe to propagate flags valid on an add within
2900 // the addrec scope to one outside it. We must prove that the inner
2901 // scope is guaranteed to execute if the outer one does to be able to
2902 // safely propagate. We know the program is undefined if poison is
2903 // produced on the inner scoped addrec. We also know that *for this use*
2904 // the outer scoped add can't overflow (because of the flags we just
2905 // computed for the inner scoped add) without the program being undefined.
2906 // Proving that entry to the outer scope neccesitates entry to the inner
2907 // scope, thus proves the program undefined if the flags would be violated
2908 // in the outer scope.
2909 SCEV::NoWrapFlags AddFlags = Flags;
2910 if (AddFlags != SCEV::FlagAnyWrap) {
2911 auto *DefI = getDefiningScopeBound(LIOps);
2912 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2913 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2914 AddFlags = SCEV::FlagAnyWrap;
2915 }
2916 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2917
2918 // Build the new addrec. Propagate the NUW and NSW flags if both the
2919 // outer add and the inner addrec are guaranteed to have no overflow.
2920 // Always propagate NW.
2921 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2922 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2923
2924 // If all of the other operands were loop invariant, we are done.
2925 if (Ops.size() == 1) return NewRec;
2926
2927 // Otherwise, add the folded AddRec by the non-invariant parts.
2928 for (unsigned i = 0;; ++i)
2929 if (Ops[i] == AddRec) {
2930 Ops[i] = NewRec;
2931 break;
2932 }
2933 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2934 }
2935
2936 // Okay, if there weren't any loop invariants to be folded, check to see if
2937 // there are multiple AddRec's with the same loop induction variable being
2938 // added together. If so, we can fold them.
2939 for (unsigned OtherIdx = Idx+1;
2940 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2941 ++OtherIdx) {
2942 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2943 // so that the 1st found AddRecExpr is dominated by all others.
2944 assert(DT.dominates(
2945 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2946 AddRec->getLoop()->getHeader()) &&
2947 "AddRecExprs are not sorted in reverse dominance order?");
2948 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2949 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2950 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2951 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 ++OtherIdx) {
2953 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2954 if (OtherAddRec->getLoop() == AddRecLoop) {
2955 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2956 i != e; ++i) {
2957 if (i >= AddRecOps.size()) {
2958 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2959 break;
2960 }
2962 AddRecOps[i], OtherAddRec->getOperand(i)};
2963 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2964 }
2965 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2966 }
2967 }
2968 // Step size has changed, so we cannot guarantee no self-wraparound.
2969 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2970 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 }
2973
2974 // Otherwise couldn't fold anything into this recurrence. Move onto the
2975 // next one.
2976 }
2977
2978 // Okay, it looks like we really DO need an add expr. Check to see if we
2979 // already have one, otherwise create a new one.
2980 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2981}
2982
2983const SCEV *
2984ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2985 SCEV::NoWrapFlags Flags) {
2987 ID.AddInteger(scAddExpr);
2988 for (const SCEV *Op : Ops)
2989 ID.AddPointer(Op);
2990 void *IP = nullptr;
2991 SCEVAddExpr *S =
2992 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2993 if (!S) {
2994 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2995 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2996 S = new (SCEVAllocator)
2997 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2998 UniqueSCEVs.InsertNode(S, IP);
2999 registerUser(S, Ops);
3000 }
3001 S->setNoWrapFlags(Flags);
3002 return S;
3003}
3004
3005const SCEV *
3006ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3007 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 ID.AddInteger(scAddRecExpr);
3010 for (const SCEV *Op : Ops)
3011 ID.AddPointer(Op);
3012 ID.AddPointer(L);
3013 void *IP = nullptr;
3014 SCEVAddRecExpr *S =
3015 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3016 if (!S) {
3017 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3018 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3019 S = new (SCEVAllocator)
3020 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3021 UniqueSCEVs.InsertNode(S, IP);
3022 LoopUsers[L].push_back(S);
3023 registerUser(S, Ops);
3024 }
3025 setNoWrapFlags(S, Flags);
3026 return S;
3027}
3028
3029const SCEV *
3030ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3031 SCEV::NoWrapFlags Flags) {
3033 ID.AddInteger(scMulExpr);
3034 for (const SCEV *Op : Ops)
3035 ID.AddPointer(Op);
3036 void *IP = nullptr;
3037 SCEVMulExpr *S =
3038 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3039 if (!S) {
3040 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3041 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3042 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3043 O, Ops.size());
3044 UniqueSCEVs.InsertNode(S, IP);
3045 registerUser(S, Ops);
3046 }
3047 S->setNoWrapFlags(Flags);
3048 return S;
3049}
3050
/// Multiply two 64-bit unsigned values, detecting unsigned wrap-around.
///
/// Returns i * j modulo 2^64. If the mathematically exact product does not
/// fit in 64 bits, \p Overflow is set to true. It is never cleared, so a
/// caller can accumulate overflow across a sequence of operations.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  // For j > 1 the product wrapped iff dividing it back by j does not recover
  // i. Multiplying by 0 or 1 can never overflow.
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return will
/// be garbage. Overflow is not cleared on absence of overflow.
///
/// Returns 0 when k > n (including n == 0 with k > 0), and 1 when k == 0 or
/// k == n.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At iteration i we multiply by the i-th factor of the numerator, n-(i-1),
  // and then divide by i. Absent overflow, the running value after iteration
  // i is exactly C(n, i), so every division is exact. Interleaving the
  // divisions keeps intermediates small and reduces the chance of overflow.
  // However, we can still overflow even when the final result would fit.

  // Check k > n first: C(0, k) is 0 for every k > 0, not 1.
  if (k > n) return 0;
  if (k == 0 || k == n) return 1;

  // C(n, k) == C(n, n-k); iterate over the smaller of the two.
  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}
3082
3083/// Determine if any of the operands in this SCEV are a constant or if
3084/// any of the add or multiply expressions in this SCEV contain a constant.
3085static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3086 struct FindConstantInAddMulChain {
3087 bool FoundConstant = false;
3088
3089 bool follow(const SCEV *S) {
3090 FoundConstant |= isa<SCEVConstant>(S);
3091 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3092 }
3093
3094 bool isDone() const {
3095 return FoundConstant;
3096 }
3097 };
3098
3099 FindConstantInAddMulChain F;
3101 ST.visitAll(StartExpr);
3102 return F.FoundConstant;
3103}
3104
3105/// Get a canonical multiply expression, or something simpler if possible.
3107 SCEV::NoWrapFlags OrigFlags,
3108 unsigned Depth) {
3109 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3110 "only nuw or nsw allowed");
3111 assert(!Ops.empty() && "Cannot get empty mul!");
3112 if (Ops.size() == 1) return Ops[0];
3113#ifndef NDEBUG
3114 Type *ETy = Ops[0]->getType();
3115 assert(!ETy->isPointerTy());
3116 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3117 assert(Ops[i]->getType() == ETy &&
3118 "SCEVMulExpr operand types don't match!");
3119#endif
3120
3121 const SCEV *Folded = constantFoldAndGroupOps(
3122 *this, LI, DT, Ops,
3123 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3124 [](const APInt &C) { return C.isOne(); }, // identity
3125 [](const APInt &C) { return C.isZero(); }); // absorber
3126 if (Folded)
3127 return Folded;
3128
3129 // Delay expensive flag strengthening until necessary.
3130 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3131 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3132 };
3133
3134 // Limit recursion calls depth.
3136 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3137
3138 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3139 // Don't strengthen flags if we have no new information.
3140 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3141 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3142 Mul->setNoWrapFlags(ComputeFlags(Ops));
3143 return S;
3144 }
3145
3146 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3147 if (Ops.size() == 2) {
3148 // C1*(C2+V) -> C1*C2 + C1*V
3149 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3150 // If any of Add's ops are Adds or Muls with a constant, apply this
3151 // transformation as well.
3152 //
3153 // TODO: There are some cases where this transformation is not
3154 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3155 // this transformation should be narrowed down.
3156 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3157 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3159 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3161 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3162 }
3163
3164 if (Ops[0]->isAllOnesValue()) {
3165 // If we have a mul by -1 of an add, try distributing the -1 among the
3166 // add operands.
3167 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3169 bool AnyFolded = false;
3170 for (const SCEV *AddOp : Add->operands()) {
3171 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3172 Depth + 1);
3173 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3174 NewOps.push_back(Mul);
3175 }
3176 if (AnyFolded)
3177 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3178 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3179 // Negation preserves a recurrence's no self-wrap property.
3181 for (const SCEV *AddRecOp : AddRec->operands())
3182 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3183 Depth + 1));
3184 // Let M be the minimum representable signed value. AddRec with nsw
3185 // multiplied by -1 can have signed overflow if and only if it takes a
3186 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3187 // maximum signed value. In all other cases signed overflow is
3188 // impossible.
3189 auto FlagsMask = SCEV::FlagNW;
3190 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3191 auto MinInt =
3192 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3193 if (getSignedRangeMin(AddRec) != MinInt)
3194 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3195 }
3196 return getAddRecExpr(Operands, AddRec->getLoop(),
3197 AddRec->getNoWrapFlags(FlagsMask));
3198 }
3199 }
3200 }
3201 }
3202
3203 // Skip over the add expression until we get to a multiply.
3204 unsigned Idx = 0;
3205 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3206 ++Idx;
3207
3208 // If there are mul operands inline them all into this expression.
3209 if (Idx < Ops.size()) {
3210 bool DeletedMul = false;
3211 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3212 if (Ops.size() > MulOpsInlineThreshold)
3213 break;
3214 // If we have an mul, expand the mul operands onto the end of the
3215 // operands list.
3216 Ops.erase(Ops.begin()+Idx);
3217 append_range(Ops, Mul->operands());
3218 DeletedMul = true;
3219 }
3220
3221 // If we deleted at least one mul, we added operands to the end of the
3222 // list, and they are not necessarily sorted. Recurse to resort and
3223 // resimplify any operands we just acquired.
3224 if (DeletedMul)
3225 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3226 }
3227
3228 // If there are any add recurrences in the operands list, see if any other
3229 // added values are loop invariant. If so, we can fold them into the
3230 // recurrence.
3231 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3232 ++Idx;
3233
3234 // Scan over all recurrences, trying to fold loop invariants into them.
3235 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3236 // Scan all of the other operands to this mul and add them to the vector
3237 // if they are loop invariant w.r.t. the recurrence.
3239 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3240 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3241 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3242 LIOps.push_back(Ops[i]);
3243 Ops.erase(Ops.begin()+i);
3244 --i; --e;
3245 }
3246
3247 // If we found some loop invariants, fold them into the recurrence.
3248 if (!LIOps.empty()) {
3249 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3251 NewOps.reserve(AddRec->getNumOperands());
3252 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3253
3254 // If both the mul and addrec are nuw, we can preserve nuw.
3255 // If both the mul and addrec are nsw, we can only preserve nsw if either
3256 // a) they are also nuw, or
3257 // b) all multiplications of addrec operands with scale are nsw.
3258 SCEV::NoWrapFlags Flags =
3259 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3260
3261 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3262 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3263 SCEV::FlagAnyWrap, Depth + 1));
3264
3265 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3267 Instruction::Mul, getSignedRange(Scale),
3269 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3270 Flags = clearFlags(Flags, SCEV::FlagNSW);
3271 }
3272 }
3273
3274 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3275
3276 // If all of the other operands were loop invariant, we are done.
3277 if (Ops.size() == 1) return NewRec;
3278
3279 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3280 for (unsigned i = 0;; ++i)
3281 if (Ops[i] == AddRec) {
3282 Ops[i] = NewRec;
3283 break;
3284 }
3285 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3286 }
3287
3288 // Okay, if there weren't any loop invariants to be folded, check to see
3289 // if there are multiple AddRec's with the same loop induction variable
3290 // being multiplied together. If so, we can fold them.
3291
3292 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3293 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3294 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3295 // ]]],+,...up to x=2n}.
3296 // Note that the arguments to choose() are always integers with values
3297 // known at compile time, never SCEV objects.
3298 //
3299 // The implementation avoids pointless extra computations when the two
3300 // addrec's are of different length (mathematically, it's equivalent to
3301 // an infinite stream of zeros on the right).
3302 bool OpsModified = false;
3303 for (unsigned OtherIdx = Idx+1;
3304 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3305 ++OtherIdx) {
3306 const SCEVAddRecExpr *OtherAddRec =
3307 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3308 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3309 continue;
3310
3311 // Limit max number of arguments to avoid creation of unreasonably big
3312 // SCEVAddRecs with very complex operands.
3313 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3314 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3315 continue;
3316
3317 bool Overflow = false;
3318 Type *Ty = AddRec->getType();
3319 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3321 for (int x = 0, xe = AddRec->getNumOperands() +
3322 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3323 SmallVector <const SCEV *, 7> SumOps;
3324 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3325 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3326 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3327 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3328 z < ze && !Overflow; ++z) {
3329 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3330 uint64_t Coeff;
3331 if (LargerThan64Bits)
3332 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3333 else
3334 Coeff = Coeff1*Coeff2;
3335 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3336 const SCEV *Term1 = AddRec->getOperand(y-z);
3337 const SCEV *Term2 = OtherAddRec->getOperand(z);
3338 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3339 SCEV::FlagAnyWrap, Depth + 1));
3340 }
3341 }
3342 if (SumOps.empty())
3343 SumOps.push_back(getZero(Ty));
3344 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3345 }
3346 if (!Overflow) {
3347 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3349 if (Ops.size() == 2) return NewAddRec;
3350 Ops[Idx] = NewAddRec;
3351 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3352 OpsModified = true;
3353 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3354 if (!AddRec)
3355 break;
3356 }
3357 }
3358 if (OpsModified)
3359 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3360
3361 // Otherwise couldn't fold anything into this recurrence. Move onto the
3362 // next one.
3363 }
3364
3365 // Okay, it looks like we really DO need an mul expr. Check to see if we
3366 // already have one, otherwise create a new one.
3367 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3368}
3369
3370/// Represents an unsigned remainder expression based on unsigned division.
3372 const SCEV *RHS) {
3375 "SCEVURemExpr operand types don't match!");
3376
3377 // Short-circuit easy cases
3378 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3379 // If constant is one, the result is trivial
3380 if (RHSC->getValue()->isOne())
3381 return getZero(LHS->getType()); // X urem 1 --> 0
3382
3383 // If constant is a power of two, fold into a zext(trunc(LHS)).
3384 if (RHSC->getAPInt().isPowerOf2()) {
3385 Type *FullTy = LHS->getType();
3386 Type *TruncTy =
3387 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3388 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3389 }
3390 }
3391
3392 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3393 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3394 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3395 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3396}
3397
3398/// Get a canonical unsigned division expression, or something simpler if
3399/// possible.
3401 const SCEV *RHS) {
3402 assert(!LHS->getType()->isPointerTy() &&
3403 "SCEVUDivExpr operand can't be pointer!");
3404 assert(LHS->getType() == RHS->getType() &&
3405 "SCEVUDivExpr operand types don't match!");
3406
3408 ID.AddInteger(scUDivExpr);
3409 ID.AddPointer(LHS);
3410 ID.AddPointer(RHS);
3411 void *IP = nullptr;
3412 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3413 return S;
3414
3415 // 0 udiv Y == 0
3416 if (match(LHS, m_scev_Zero()))
3417 return LHS;
3418
3419 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3420 if (RHSC->getValue()->isOne())
3421 return LHS; // X udiv 1 --> x
3422 // If the denominator is zero, the result of the udiv is undefined. Don't
3423 // try to analyze it, because the resolution chosen here may differ from
3424 // the resolution chosen in other parts of the compiler.
3425 if (!RHSC->getValue()->isZero()) {
3426 // Determine if the division can be folded into the operands of
3427 // its operands.
3428 // TODO: Generalize this to non-constants by using known-bits information.
3429 Type *Ty = LHS->getType();
3430 unsigned LZ = RHSC->getAPInt().countl_zero();
3431 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3432 // For non-power-of-two values, effectively round the value up to the
3433 // nearest power of two.
3434 if (!RHSC->getAPInt().isPowerOf2())
3435 ++MaxShiftAmt;
3436 IntegerType *ExtTy =
3437 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3438 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3439 if (const SCEVConstant *Step =
3440 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3441 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3442 const APInt &StepInt = Step->getAPInt();
3443 const APInt &DivInt = RHSC->getAPInt();
3444 if (!StepInt.urem(DivInt) &&
3445 getZeroExtendExpr(AR, ExtTy) ==
3446 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3447 getZeroExtendExpr(Step, ExtTy),
3448 AR->getLoop(), SCEV::FlagAnyWrap)) {
3450 for (const SCEV *Op : AR->operands())
3451 Operands.push_back(getUDivExpr(Op, RHS));
3452 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3453 }
3454 /// Get a canonical UDivExpr for a recurrence.
3455 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3456 // We can currently only fold X%N if X is constant.
3457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3458 if (StartC && !DivInt.urem(StepInt) &&
3459 getZeroExtendExpr(AR, ExtTy) ==
3460 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3461 getZeroExtendExpr(Step, ExtTy),
3462 AR->getLoop(), SCEV::FlagAnyWrap)) {
3463 const APInt &StartInt = StartC->getAPInt();
3464 const APInt &StartRem = StartInt.urem(StepInt);
3465 if (StartRem != 0) {
3466 const SCEV *NewLHS =
3467 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3468 AR->getLoop(), SCEV::FlagNW);
3469 if (LHS != NewLHS) {
3470 LHS = NewLHS;
3471
3472 // Reset the ID to include the new LHS, and check if it is
3473 // already cached.
3474 ID.clear();
3475 ID.AddInteger(scUDivExpr);
3476 ID.AddPointer(LHS);
3477 ID.AddPointer(RHS);
3478 IP = nullptr;
3479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3480 return S;
3481 }
3482 }
3483 }
3484 }
3485 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3486 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3488 for (const SCEV *Op : M->operands())
3489 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3490 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3491 // Find an operand that's safely divisible.
3492 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3493 const SCEV *Op = M->getOperand(i);
3494 const SCEV *Div = getUDivExpr(Op, RHSC);
3495 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3496 Operands = SmallVector<const SCEV *, 4>(M->operands());
3497 Operands[i] = Div;
3498 return getMulExpr(Operands);
3499 }
3500 }
3501 }
3502
3503 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3504 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3505 if (auto *DivisorConstant =
3506 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3507 bool Overflow = false;
3508 APInt NewRHS =
3509 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3510 if (Overflow) {
3511 return getConstant(RHSC->getType(), 0, false);
3512 }
3513 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3514 }
3515 }
3516
3517 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3518 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3520 for (const SCEV *Op : A->operands())
3521 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3522 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3523 Operands.clear();
3524 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3525 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3526 if (isa<SCEVUDivExpr>(Op) ||
3527 getMulExpr(Op, RHS) != A->getOperand(i))
3528 break;
3529 Operands.push_back(Op);
3530 }
3531 if (Operands.size() == A->getNumOperands())
3532 return getAddExpr(Operands);
3533 }
3534 }
3535
3536 // Fold if both operands are constant.
3537 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3538 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3539 }
3540 }
3541
3542 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3543 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3544 AE && AE->getNumOperands() == 2) {
3545 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3546 const APInt &NegC = VC->getAPInt();
3547 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3548 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3549 if (MME && MME->getNumOperands() == 2 &&
3550 isa<SCEVConstant>(MME->getOperand(0)) &&
3551 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3552 MME->getOperand(1) == RHS)
3553 return getZero(LHS->getType());
3554 }
3555 }
3556 }
3557
3558 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3559 // changes). Make sure we get a new one.
3560 IP = nullptr;
3561 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3562 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3563 LHS, RHS);
3564 UniqueSCEVs.InsertNode(S, IP);
3565 registerUser(S, {LHS, RHS});
3566 return S;
3567}
3568
3569APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3570 APInt A = C1->getAPInt().abs();
3571 APInt B = C2->getAPInt().abs();
3572 uint32_t ABW = A.getBitWidth();
3573 uint32_t BBW = B.getBitWidth();
3574
3575 if (ABW > BBW)
3576 B = B.zext(ABW);
3577 else if (ABW < BBW)
3578 A = A.zext(BBW);
3579
3580 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3581}
3582
/// Get a canonical unsigned division expression, or something simpler if
/// possible. There is no representation for an exact udiv in SCEV IR, but we
/// can attempt to remove factors from the LHS and RHS. We can't do this when
/// it's not exact because the udiv may be clearing bits.
///
/// Exactness is not checked here; the caller is presumably asserting it.
/// Only an LHS that is a nuw multiply is simplified:
///   * (C * X...) /u C      --> X...
///   * (C1 * X...) /u C2    --> both constants are divided by gcd(C1, C2) and
///                              simplification continues on the reduced pair
///   * (... * Y * ...) /u Y --> product of the remaining operands
/// Everything else falls back to getUDivExpr(LHS, RHS).
const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
                                              const SCEV *RHS) {
  // TODO: we could try to find factors in all sorts of things, but for now we
  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
  // end of this file for inspiration.

  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
  if (!Mul || !Mul->hasNoUnsignedWrap())
    return getUDivExpr(LHS, RHS);

  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
    // If the mulexpr multiplies by a constant, then that constant must be the
    // first element of the mulexpr.
    if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
      // (C * X...) /u C: drop the constant factor entirely.
      if (LHSCst == RHSCst) {
        SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
        return getMulExpr(Operands);
      }

      // We can't just assume that LHSCst divides RHSCst cleanly, it could be
      // that there's a factor provided by one of the other terms. We need to
      // check.
      APInt Factor = gcd(LHSCst, RHSCst);
      // Only act when the common factor is greater than one.
      if (!Factor.isIntN(1)) {
        LHSCst =
            cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
        RHSCst =
            cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
        SmallVector<const SCEV *, 2> Operands;
        Operands.push_back(LHSCst);
        append_range(Operands, Mul->operands().drop_front());
        LHS = getMulExpr(Operands);
        RHS = RHSCst;
        // The reduced LHS may have folded to a non-multiply (e.g. when the
        // reduced constant became 1); restart on the simplified pair.
        Mul = dyn_cast<SCEVMulExpr>(LHS);
        if (!Mul)
          return getUDivExactExpr(LHS, RHS);
      }
    }
  }

  // If RHS appears verbatim as a multiplicand, remove that one occurrence.
  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
    if (Mul->getOperand(i) == RHS) {
      SmallVector<const SCEV *, 2> Operands;
      append_range(Operands, Mul->operands().take_front(i));
      append_range(Operands, Mul->operands().drop_front(i + 1));
      return getMulExpr(Operands);
    }
  }

  return getUDivExpr(LHS, RHS);
}
3638
3639/// Get an add recurrence expression for the specified loop. Simplify the
3640/// expression as much as possible.
3641const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3642 const Loop *L,
3643 SCEV::NoWrapFlags Flags) {
3645 Operands.push_back(Start);
3646 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3647 if (StepChrec->getLoop() == L) {
3648 append_range(Operands, StepChrec->operands());
3649 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3650 }
3651
3652 Operands.push_back(Step);
3653 return getAddRecExpr(Operands, L, Flags);
3654}
3655
/// Get an add recurrence expression for the specified loop. Simplify the
/// expression as much as possible.
///
/// \p Operands holds {Start, Step0, Step1, ...} and may be modified in place:
/// a zero highest-order step is popped, and Operands[0] is temporarily
/// replaced while attempting to re-nest recurrences (and restored if that
/// fails). \p Flags is strengthened via StrengthenNoWrapFlags before the node
/// is looked up or created.
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
                               const Loop *L, SCEV::NoWrapFlags Flags) {
  // A recurrence with no steps is just its start value.
  if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
  for (const SCEV *Op : llvm::drop_begin(Operands)) {
    assert(getEffectiveSCEVType(Op->getType()) == ETy &&
           "SCEVAddRecExpr operand types don't match!");
    assert(!Op->getType()->isPointerTy() && "Step must be integer");
  }
  for (const SCEV *Op : Operands)
    assert(isAvailableAtLoopEntry(Op, L) &&
           "SCEVAddRecExpr operand is not available at loop entry!");
#endif

  // A zero highest-order step contributes nothing; drop it and rebuild.
  if (Operands.back()->isZero()) {
    Operands.pop_back();
    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
  }

  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
  // use that information to infer NUW and NSW flags. However, computing a
  // BE count requires calling getAddRecExpr, so we may not yet have a
  // meaningful BE count at this point (and if we don't, we'd be stuck
  // with a SCEVCouldNotCompute as the cached BE count).

  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);

  // Canonicalize nested AddRecs in by nesting them in order of loop depth.
  // Given {{A,+,B}<NestedLoop>,+,C...}<L>, try to rewrite it as
  // {{A,+,C...}<L>,+,B}<NestedLoop>. This applies when NestedLoop is strictly
  // deeper inside L, or when the two loops are unrelated and L's header
  // dominates NestedLoop's header.
  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
    const Loop *NestedLoop = NestedAR->getLoop();
    if (L->contains(NestedLoop)
            ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
            : (!NestedLoop->contains(L) &&
               DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
      Operands[0] = NestedAR->getStart();
      // AddRecs require their operands be loop-invariant with respect to their
      // loops. Don't perform this transformation if it would break this
      // requirement.
      bool AllInvariant = all_of(
          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });

      if (AllInvariant) {
        // Create a recurrence for the outer loop with the same step size.
        //
        // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
        // inner recurrence has the same property.
        SCEV::NoWrapFlags OuterFlags =
          maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());

        NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
          return isLoopInvariant(Op, NestedLoop);
        });

        if (AllInvariant) {
          // Ok, both add recurrences are valid after the transformation.
          //
          // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
          // the outer recurrence has the same property.
          SCEV::NoWrapFlags InnerFlags =
            maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
          return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
        }
      }
      // Reset Operands to its original state.
      Operands[0] = NestedAR;
    }
  }

  // Okay, it looks like we really DO need an addrec expr.  Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateAddRecExpr(Operands, L, Flags);
}
3734
/// Build the SCEV for a GEP as BaseExpr + sum(offsets), given the SCEVs of its
/// indices in \p IndexExprs.
///
/// Each struct index contributes the field's constant offset. Every other
/// index is truncated or sign-extended to the index type and multiplied by the
/// element size. The GEP's nusw/nuw flags become NSW/NUW on the offset
/// arithmetic, but only when isSCEVExprNeverPoison holds for the GEP
/// instruction.
const SCEV *
ScalarEvolution::getGEPExpr(GEPOperator *GEP,
                            const SmallVectorImpl<const SCEV *> &IndexExprs) {
  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
  // getSCEV(Base)->getType() has the same address space as Base->getType()
  // because SCEV::getType() preserves the address space.
  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
  GEPNoWrapFlags NW = GEP->getNoWrapFlags();
  if (NW != GEPNoWrapFlags::none()) {
    // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
    // but to do that, we have to ensure that said flag is valid in the entire
    // defined scope of the SCEV.
    // TODO: non-instructions have global scope. We might be able to prove
    // some global scope cases
    auto *GEPI = dyn_cast<Instruction>(GEP);
    if (!GEPI || !isSCEVExprNeverPoison(GEPI))
      NW = GEPNoWrapFlags::none();
  }

  // Translate the (possibly cleared) GEP flags into SCEV flags for the
  // offset computations.
  SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
  if (NW.hasNoUnsignedSignedWrap())
    OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
  if (NW.hasNoUnsignedWrap())
    OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);

  Type *CurTy = GEP->getType();
  bool FirstIter = true;
  SmallVector<const SCEV *, 4> Offsets;
  for (const SCEV *IndexExpr : IndexExprs) {
    // Compute the (potentially symbolic) offset in bytes for this index.
    if (StructType *STy = dyn_cast<StructType>(CurTy)) {
      // For a struct, add the member offset.
      // Struct indices must be constants, hence the unconditional cast.
      ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
      unsigned FieldNo = Index->getZExtValue();
      const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
      Offsets.push_back(FieldOffset);

      // Update CurTy to the type of the field at Index.
      CurTy = STy->getTypeAtIndex(Index);
    } else {
      // Update CurTy to its element type.
      if (FirstIter) {
        assert(isa<PointerType>(CurTy) &&
               "The first index of a GEP indexes a pointer");
        CurTy = GEP->getSourceElementType();
        FirstIter = false;
      } else {
        CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
      }
      // For an array, add the element offset, explicitly scaled.
      const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
      // Getelementptr indices are signed.
      IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);

      // Multiply the index by the element size to compute the element offset.
      const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
      Offsets.push_back(LocalOffset);
    }
  }

  // Handle degenerate case of GEP without offsets.
  if (Offsets.empty())
    return BaseExpr;

  // Add the offsets together, assuming nsw if inbounds.
  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
  // Add the base address and the offset. We cannot use the nsw flag, as the
  // base address is unsigned. However, if we know that the offset is
  // non-negative, we can use nuw.
  bool NUW = NW.hasNoUnsignedWrap() ||
             (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
  SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
  assert(BaseExpr->getType() == GEPExpr->getType() &&
         "GEP should not change type mid-flight.");
  return GEPExpr;
}
3812
3813SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3816 ID.AddInteger(SCEVType);
3817 for (const SCEV *Op : Ops)
3818 ID.AddPointer(Op);
3819 void *IP = nullptr;
3820 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3821}
3822
3823const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3825 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3826}
3827
/// Build a canonical (commutative) smax/umax/smin/umin expression of kind
/// \p Kind over \p Ops, simplifying where possible. \p Ops is modified in
/// place.
///
/// Simplifications performed here:
///   * constant folding together with identity/absorber elimination (via
///     constantFoldAndGroupOps);
///   * flattening nested expressions of the same kind;
///   * removing adjacent duplicates and operands whose order against their
///     neighbour is provable without recursion.
const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
                                           SmallVectorImpl<const SCEV *> &Ops) {
  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "Operand types don't match!");
    assert(Ops[0]->getType()->isPointerTy() ==
               Ops[i]->getType()->isPointerTy() &&
           "min/max should be consistently pointerish");
  }
#endif

  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;

  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [&](const APInt &C1, const APInt &C2) {
        switch (Kind) {
        case scSMaxExpr:
          return APIntOps::smax(C1, C2);
        case scSMinExpr:
          return APIntOps::smin(C1, C2);
        case scUMaxExpr:
          return APIntOps::umax(C1, C2);
        case scUMinExpr:
          return APIntOps::umin(C1, C2);
        default:
          llvm_unreachable("Unknown SCEV min/max opcode");
        }
      },
      [&](const APInt &C) {
        // identity
        // The constant that leaves the other operand unchanged (e.g. the
        // minimum value for a max).
        if (IsMax)
          return IsSigned ? C.isMinSignedValue() : C.isMinValue();
        else
          return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
      },
      [&](const APInt &C) {
        // absorber
        // The constant that decides the result by itself (e.g. the maximum
        // value for a max).
        if (IsMax)
          return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
        else
          return IsSigned ? C.isMinSignedValue() : C.isMinValue();
      });
  if (Folded)
    return Folded;

  // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
    return S;
  }

  // Find the first operation of the same kind
  // (relies on Ops being ordered by SCEV type after the grouping above).
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
    ++Idx;

  // Check to see if one of the operands is of the same kind. If so, expand its
  // operands onto our operand list, and recurse to simplify.
  if (Idx < Ops.size()) {
    bool DeletedAny = false;
    while (Ops[Idx]->getSCEVType() == Kind) {
      const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, SMME->operands());
      DeletedAny = true;
    }

    if (DeletedAny)
      return getMinMaxExpr(Kind, Ops);
  }

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, delete one.  Since we sorted the list, these values are required to
  // be adjacent.
  // Beyond exact duplicates, an operand is also dropped when a
  // non-recursive proof orders it against its neighbour.
  llvm::CmpInst::Predicate GEPred =
      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
  llvm::CmpInst::Predicate LEPred =
      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
  // Note: after an erase, --i may wrap the unsigned index; the loop's ++i
  // brings it back, so the same position is re-examined.
  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
    if (Ops[i] == Ops[i + 1] ||
        isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
      //  X op Y op Y  -->  X op Y
      //  X op Y       -->  X, if we know X, Y are ordered appropriately
      Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
      --i;
      --e;
    } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
                                               Ops[i + 1])) {
      //  X op Y       -->  Y, if we know X, Y are ordered appropriately
      Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
      --i;
      --e;
    }
  }

  if (Ops.size() == 1) return Ops[0];

  assert(!Ops.empty() && "Reduced smax down to nothing!");

  // Okay, it looks like we really DO need an expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(Kind);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
  if (ExistingSCEV)
    return ExistingSCEV;
  // The operand array is copied into SCEVAllocator storage owned by the node.
  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
  SCEV *S = new (SCEVAllocator)
      SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());

  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Ops);
  return S;
}
3954
3955namespace {
3956
3957class SCEVSequentialMinMaxDeduplicatingVisitor final
3958 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3959 std::optional<const SCEV *>> {
3960 using RetVal = std::optional<const SCEV *>;
3962
3963 ScalarEvolution &SE;
3964 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3965 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3967
3968 bool canRecurseInto(SCEVTypes Kind) const {
3969 // We can only recurse into the SCEV expression of the same effective type
3970 // as the type of our root SCEV expression.
3971 return RootKind == Kind || NonSequentialRootKind == Kind;
3972 };
3973
3974 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3975 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3976 "Only for min/max expressions.");
3977 SCEVTypes Kind = S->getSCEVType();
3978
3979 if (!canRecurseInto(Kind))
3980 return S;
3981
3982 auto *NAry = cast<SCEVNAryExpr>(S);
3984 bool Changed = visit(Kind, NAry->operands(), NewOps);
3985
3986 if (!Changed)
3987 return S;
3988 if (NewOps.empty())
3989 return std::nullopt;
3990
3991 return isa<SCEVSequentialMinMaxExpr>(S)
3992 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3993 : SE.getMinMaxExpr(Kind, NewOps);
3994 }
3995
3996 RetVal visit(const SCEV *S) {
3997 // Has the whole operand been seen already?
3998 if (!SeenOps.insert(S).second)
3999 return std::nullopt;
4000 return Base::visit(S);
4001 }
4002
4003public:
4004 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4005 SCEVTypes RootKind)
4006 : SE(SE), RootKind(RootKind),
4007 NonSequentialRootKind(
4008 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4009 RootKind)) {}
4010
4011 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4013 bool Changed = false;
4015 Ops.reserve(OrigOps.size());
4016
4017 for (const SCEV *Op : OrigOps) {
4018 RetVal NewOp = visit(Op);
4019 if (NewOp != Op)
4020 Changed = true;
4021 if (NewOp)
4022 Ops.emplace_back(*NewOp);
4023 }
4024
4025 if (Changed)
4026 NewOps = std::move(Ops);
4027 return Changed;
4028 }
4029
4030 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4031
4032 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4033
4034 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4035
4036 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4037
4038 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4039
4040 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4041
4042 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4043
4044 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4045
4046 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4047
4048 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4049
4050 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4051 return visitAnyMinMaxExpr(Expr);
4052 }
4053
4054 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4055 return visitAnyMinMaxExpr(Expr);
4056 }
4057
4058 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4059 return visitAnyMinMaxExpr(Expr);
4060 }
4061
4062 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4063 return visitAnyMinMaxExpr(Expr);
4064 }
4065
4066 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4067 return visitAnyMinMaxExpr(Expr);
4068 }
4069
4070 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4071
4072 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4073};
4074
4075} // namespace
4076
4078 switch (Kind) {
4079 case scConstant:
4080 case scVScale:
4081 case scTruncate:
4082 case scZeroExtend:
4083 case scSignExtend:
4084 case scPtrToInt:
4085 case scAddExpr:
4086 case scMulExpr:
4087 case scUDivExpr:
4088 case scAddRecExpr:
4089 case scUMaxExpr:
4090 case scSMaxExpr:
4091 case scUMinExpr:
4092 case scSMinExpr:
4093 case scUnknown:
4094 // If any operand is poison, the whole expression is poison.
4095 return true;
4097 // FIXME: if the *first* operand is poison, the whole expression is poison.
4098 return false; // Pessimistically, say that it does not propagate poison.
4099 case scCouldNotCompute:
4100 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4101 }
4102 llvm_unreachable("Unknown SCEV kind!");
4103}
4104
4105namespace {
4106// The only way poison may be introduced in a SCEV expression is from a
4107// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4108// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4109// introduce poison -- they encode guaranteed, non-speculated knowledge.
4110//
4111// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4112// with the notable exception of umin_seq, where only poison from the first
4113// operand is (unconditionally) propagated.
4114struct SCEVPoisonCollector {
4115 bool LookThroughMaybePoisonBlocking;
4117 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4118 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4119
4120 bool follow(const SCEV *S) {
4121 if (!LookThroughMaybePoisonBlocking &&
4123 return false;
4124
4125 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4126 if (!isGuaranteedNotToBePoison(SU->getValue()))
4127 MaybePoison.insert(SU);
4128 }
4129 return true;
4130 }
4131 bool isDone() const { return false; }
4132};
4133} // namespace
4134
4135/// Return true if V is poison given that AssumedPoison is already poison.
4136static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4137 // First collect all SCEVs that might result in AssumedPoison to be poison.
4138 // We need to look through potentially poison-blocking operations here,
4139 // because we want to find all SCEVs that *might* result in poison, not only
4140 // those that are *required* to.
4141 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4142 visitAll(AssumedPoison, PC1);
4143
4144 // AssumedPoison is never poison. As the assumption is false, the implication
4145 // is true. Don't bother walking the other SCEV in this case.
4146 if (PC1.MaybePoison.empty())
4147 return true;
4148
4149 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4150 // as well. We cannot look through potentially poison-blocking operations
4151 // here, as their arguments only *may* make the result poison.
4152 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4153 visitAll(S, PC2);
4154
4155 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4156 // it will also make S poison by being part of PC2.MaybePoison.
4157 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4158}
4159
4161 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4162 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4163 visitAll(S, PC);
4164 for (const SCEVUnknown *SU : PC.MaybePoison)
4165 Result.insert(SU->getValue());
4166}
4167
4169 const SCEV *S, Instruction *I,
4170 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4171 // If the instruction cannot be poison, it's always safe to reuse.
4173 return true;
4174
4175 // Otherwise, it is possible that I is more poisonous that S. Collect the
4176 // poison-contributors of S, and then check whether I has any additional
4177 // poison-contributors. Poison that is contributed through poison-generating
4178 // flags is handled by dropping those flags instead.
4180 getPoisonGeneratingValues(PoisonVals, S);
4181
4182 SmallVector<Value *> Worklist;
4184 Worklist.push_back(I);
4185 while (!Worklist.empty()) {
4186 Value *V = Worklist.pop_back_val();
4187 if (!Visited.insert(V).second)
4188 continue;
4189
4190 // Avoid walking large instruction graphs.
4191 if (Visited.size() > 16)
4192 return false;
4193
4194 // Either the value can't be poison, or the S would also be poison if it
4195 // is.
4196 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4197 continue;
4198
4199 auto *I = dyn_cast<Instruction>(V);
4200 if (!I)
4201 return false;
4202
4203 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4204 // can't replace an arbitrary add with disjoint or, even if we drop the
4205 // flag. We would need to convert the or into an add.
4206 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4207 if (PDI->isDisjoint())
4208 return false;
4209
4210 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4211 // because SCEV currently assumes it can't be poison. Remove this special
4212 // case once we proper model when vscale can be poison.
4213 if (auto *II = dyn_cast<IntrinsicInst>(I);
4214 II && II->getIntrinsicID() == Intrinsic::vscale)
4215 continue;
4216
4217 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4218 return false;
4219
4220 // If the instruction can't create poison, we can recurse to its operands.
4221 if (I->hasPoisonGeneratingAnnotations())
4222 DropPoisonGeneratingInsts.push_back(I);
4223
4224 for (Value *Op : I->operands())
4225 Worklist.push_back(Op);
4226 }
4227 return true;
4228}
4229
/// Build a canonical sequential (non-commutative) min/max expression of kind
/// \p Kind over \p Ops, simplifying where possible. \p Ops is modified in
/// place. Only scSequentialUMinExpr is currently handled (see the switch
/// below).
///
/// Simplifications performed here:
///   * keep only the first occurrence of each operand;
///   * flatten nested expressions of the same kind;
///   * demote a umin_seq pair to a plain umin where that is poison-safe;
///   * drop an operand provably not smaller than its predecessor.
const SCEV *
ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
                                         SmallVectorImpl<const SCEV *> &Ops) {
  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
         "Not a SCEVSequentialMinMaxExpr!");
  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  if (Ops.size() == 1)
    return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "Operand types don't match!");
    assert(Ops[0]->getType()->isPointerTy() ==
               Ops[i]->getType()->isPointerTy() &&
           "min/max should be consistently pointerish");
  }
#endif

  // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
  // so we can *NOT* do any kind of sorting of the expressions!

  // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
    return S;

  // FIXME: there are *some* simplifications that we can do here.

  // Keep only the first instance of an operand.
  // Ops is passed as both input and output; the visitor tolerates aliasing.
  {
    SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
    bool Changed = Deduplicator.visit(Kind, Ops, Ops);
    if (Changed)
      return getSequentialMinMaxExpr(Kind, Ops);
  }

  // Check to see if one of the operands is of the same kind. If so, expand its
  // operands onto our operand list, and recurse to simplify.
  // Unlike the commutative case, nested operands are spliced in at the same
  // position to preserve evaluation order.
  {
    unsigned Idx = 0;
    bool DeletedAny = false;
    while (Idx < Ops.size()) {
      if (Ops[Idx]->getSCEVType() != Kind) {
        ++Idx;
        continue;
      }
      const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
      Ops.erase(Ops.begin() + Idx);
      Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
                 SMME->operands().end());
      DeletedAny = true;
    }

    if (DeletedAny)
      return getSequentialMinMaxExpr(Kind, Ops);
  }

  // SaturationPoint: the value at which later operands stop mattering.
  // Pred: the ordering that makes a later operand redundant.
  const SCEV *SaturationPoint;
  ICmpInst::Predicate Pred;
  switch (Kind) {
  case scSequentialUMinExpr:
    SaturationPoint = getZero(Ops[0]->getType());
    Pred = ICmpInst::ICMP_ULE;
    break;
  default:
    llvm_unreachable("Not a sequential min/max type.");
  }

  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    // Evaluating Ops[i] unconditionally must not introduce UB.
    if (!isGuaranteedNotToCauseUB(Ops[i]))
      continue;
    // We can replace %x umin_seq %y with %x umin %y if either:
    //  * %y being poison implies %x is also poison.
    //  * %x cannot be the saturating value (e.g. zero for umin).
    if (::impliesPoison(Ops[i], Ops[i - 1]) ||
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
                                        SaturationPoint)) {
      SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
      Ops[i - 1] = getMinMaxExpr(
          SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
          SeqOps);
      Ops.erase(Ops.begin() + i);
      return getSequentialMinMaxExpr(Kind, Ops);
    }
    // Fold %x umin_seq %y to %x if %x ule %y.
    // TODO: We might be able to prove the predicate for a later operand.
    if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
      Ops.erase(Ops.begin() + i);
      return getSequentialMinMaxExpr(Kind, Ops);
    }
  }

  // Okay, it looks like we really DO need an expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(Kind);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
  if (ExistingSCEV)
    return ExistingSCEV;

  // The operand array is copied into SCEVAllocator storage owned by the node.
  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
  SCEV *S = new (SCEVAllocator)
      SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());

  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Ops);
  return S;
}
4342
4343const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4345 return getSMaxExpr(Ops);
4346}
4347
4349 return getMinMaxExpr(scSMaxExpr, Ops);
4350}
4351
4352const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4354 return getUMaxExpr(Ops);
4355}
4356
4358 return getMinMaxExpr(scUMaxExpr, Ops);
4359}
4360
4362 const SCEV *RHS) {
4364 return getSMinExpr(Ops);
4365}
4366
4368 return getMinMaxExpr(scSMinExpr, Ops);
4369}
4370
4371const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4372 bool Sequential) {
4374 return getUMinExpr(Ops, Sequential);
4375}
4376
4378 bool Sequential) {
4379 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4380 : getMinMaxExpr(scUMinExpr, Ops);
4381}
4382
4383const SCEV *
4385 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4386 if (Size.isScalable())
4387 Res = getMulExpr(Res, getVScale(IntTy));
4388 return Res;
4389}
4390
4392 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4393}
4394
4396 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4397}
4398
4400 StructType *STy,
4401 unsigned FieldNo) {
4402 // We can bypass creating a target-independent constant expression and then
4403 // folding it back into a ConstantInt. This is just a compile-time
4404 // optimization.
4405 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4406 assert(!SL->getSizeInBits().isScalable() &&
4407 "Cannot get offset for structure containing scalable vector types");
4408 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4409}
4410
4412 // Don't attempt to do anything other than create a SCEVUnknown object
4413 // here. createSCEV only calls getUnknown after checking for all other
4414 // interesting possibilities, and any other code that calls getUnknown
4415 // is doing so in order to hide a value from SCEV canonicalization.
4416
4418 ID.AddInteger(scUnknown);
4419 ID.AddPointer(V);
4420 void *IP = nullptr;
4421 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4422 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4423 "Stale SCEVUnknown in uniquing map!");
4424 return S;
4425 }
4426 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4427 FirstUnknown);
4428 FirstUnknown = cast<SCEVUnknown>(S);
4429 UniqueSCEVs.InsertNode(S, IP);
4430 return S;
4431}
4432
4433//===----------------------------------------------------------------------===//
4434// Basic SCEV Analysis and PHI Idiom Recognition Code
4435//
4436
4437/// Test if values of the given type are analyzable within the SCEV
4438/// framework. This primarily includes integer types, and it can optionally
4439/// include pointer types if the ScalarEvolution class has access to
4440/// target-specific information.
4442 // Integers and pointers are always SCEVable.
4443 return Ty->isIntOrPtrTy();
4444}
4445
4446/// Return the size in bits of the specified type, for which isSCEVable must
4447/// return true.
4449 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4450 if (Ty->isPointerTy())
4452 return getDataLayout().getTypeSizeInBits(Ty);
4453}
4454
4455/// Return a type with the same bitwidth as the given type and which represents
4456/// how SCEV will treat the given type, for which isSCEVable must return
4457/// true. For pointer types, this is the pointer index sized integer type.
4459 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4460
4461 if (Ty->isIntegerTy())
4462 return Ty;
4463
4464 // The only other support type is pointer.
4465 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4466 return getDataLayout().getIndexType(Ty);
4467}
4468
4470 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4471}
4472
4474 const SCEV *B) {
4475 /// For a valid use point to exist, the defining scope of one operand
4476 /// must dominate the other.
4477 bool PreciseA, PreciseB;
4478 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4479 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4480 if (!PreciseA || !PreciseB)
4481 // Can't tell.
4482 return false;
4483 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4484 DT.dominates(ScopeB, ScopeA);
4485}
4486
4488 return CouldNotCompute.get();
4489}
4490
4491bool ScalarEvolution::checkValidity(const SCEV *S) const {
4492 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4493 auto *SU = dyn_cast<SCEVUnknown>(S);
4494 return SU && SU->getValue() == nullptr;
4495 });
4496
4497 return !ContainsNulls;
4498}
4499
4501 HasRecMapType::iterator I = HasRecMap.find(S);
4502 if (I != HasRecMap.end())
4503 return I->second;
4504
4505 bool FoundAddRec =
4506 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4507 HasRecMap.insert({S, FoundAddRec});
4508 return FoundAddRec;
4509}
4510
4511/// Return the ValueOffsetPair set for \p S. \p S can be represented
4512/// by the value and offset from any ValueOffsetPair in the set.
4513ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4514 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4515 if (SI == ExprValueMap.end())
4516 return {};
4517 return SI->second.getArrayRef();
4518}
4519
4520/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4521/// cannot be used separately. eraseValueFromMap should be used to remove
4522/// V from ValueExprMap and ExprValueMap at the same time.
4523void ScalarEvolution::eraseValueFromMap(Value *V) {
4524 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4525 if (I != ValueExprMap.end()) {
4526 auto EVIt = ExprValueMap.find(I->second);
4527 bool Removed = EVIt->second.remove(V);
4528 (void) Removed;
4529 assert(Removed && "Value not in ExprValueMap?");
4530 ValueExprMap.erase(I);
4531 }
4532}
4533
4534void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4535 // A recursive query may have already computed the SCEV. It should be
4536 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4537 // inferred nowrap flags.
4538 auto It = ValueExprMap.find_as(V);
4539 if (It == ValueExprMap.end()) {
4540 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4541 ExprValueMap[S].insert(V);
4542 }
4543}
4544
4545/// Return an existing SCEV if it exists, otherwise analyze the expression and
4546/// create a new one.
4548 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4549
4550 if (const SCEV *S = getExistingSCEV(V))
4551 return S;
4552 return createSCEVIter(V);
4553}
4554
4556 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4557
4558 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4559 if (I != ValueExprMap.end()) {
4560 const SCEV *S = I->second;
4561 assert(checkValidity(S) &&
4562 "existing SCEV has not been properly invalidated");
4563 return S;
4564 }
4565 return nullptr;
4566}
4567
4568/// Return a SCEV corresponding to -V = -1*V
4570 SCEV::NoWrapFlags Flags) {
4571 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4572 return getConstant(
4573 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4574
4575 Type *Ty = V->getType();
4576 Ty = getEffectiveSCEVType(Ty);
4577 return getMulExpr(V, getMinusOne(Ty), Flags);
4578}
4579
4580/// If Expr computes ~A, return A else return nullptr
4581static const SCEV *MatchNotExpr(const SCEV *Expr) {
4582 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4583 if (!Add || Add->getNumOperands() != 2 ||
4584 !Add->getOperand(0)->isAllOnesValue())
4585 return nullptr;
4586
4587 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4588 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4589 !AddRHS->getOperand(0)->isAllOnesValue())
4590 return nullptr;
4591
4592 return AddRHS->getOperand(1);
4593}
4594
4595/// Return a SCEV corresponding to ~V = -1-V
4597 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4598
4599 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4600 return getConstant(
4601 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4602
4603 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4604 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4605 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4606 SmallVector<const SCEV *, 2> MatchedOperands;
4607 for (const SCEV *Operand : MME->operands()) {
4608 const SCEV *Matched = MatchNotExpr(Operand);
4609 if (!Matched)
4610 return (const SCEV *)nullptr;
4611 MatchedOperands.push_back(Matched);
4612 }
4613 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4614 MatchedOperands);
4615 };
4616 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4617 return Replaced;
4618 }
4619
4620 Type *Ty = V->getType();
4621 Ty = getEffectiveSCEVType(Ty);
4622 return getMinusSCEV(getMinusOne(Ty), V);
4623}
4624
4626 assert(P->getType()->isPointerTy());
4627
4628 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4629 // The base of an AddRec is the first operand.
4630 SmallVector<const SCEV *> Ops{AddRec->operands()};
4631 Ops[0] = removePointerBase(Ops[0]);
4632 // Don't try to transfer nowrap flags for now. We could in some cases
4633 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4634 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4635 }
4636 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4637 // The base of an Add is the pointer operand.
4638 SmallVector<const SCEV *> Ops{Add->operands()};
4639 const SCEV **PtrOp = nullptr;
4640 for (const SCEV *&AddOp : Ops) {
4641 if (AddOp->getType()->isPointerTy()) {
4642 assert(!PtrOp && "Cannot have multiple pointer ops");
4643 PtrOp = &AddOp;
4644 }
4645 }
4646 *PtrOp = removePointerBase(*PtrOp);
4647 // Don't try to transfer nowrap flags for now. We could in some cases
4648 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4649 return getAddExpr(Ops);
4650 }
4651 // Any other expression must be a pointer base.
4652 return getZero(P->getType());
4653}
4654
4655const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4656 SCEV::NoWrapFlags Flags,
4657 unsigned Depth) {
4658 // Fast path: X - X --> 0.
4659 if (LHS == RHS)
4660 return getZero(LHS->getType());
4661
4662 // If we subtract two pointers with different pointer bases, bail.
4663 // Eventually, we're going to add an assertion to getMulExpr that we
4664 // can't multiply by a pointer.
4665 if (RHS->getType()->isPointerTy()) {
4666 if (!LHS->getType()->isPointerTy() ||
4668 return getCouldNotCompute();
4671 }
4672
4673 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4674 // makes it so that we cannot make much use of NUW.
4675 auto AddFlags = SCEV::FlagAnyWrap;
4676 const bool RHSIsNotMinSigned =
4678 if (hasFlags(Flags, SCEV::FlagNSW)) {
4679 // Let M be the minimum representable signed value. Then (-1)*RHS
4680 // signed-wraps if and only if RHS is M. That can happen even for
4681 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4682 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4683 // (-1)*RHS, we need to prove that RHS != M.
4684 //
4685 // If LHS is non-negative and we know that LHS - RHS does not
4686 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4687 // either by proving that RHS > M or that LHS >= 0.
4688 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4689 AddFlags = SCEV::FlagNSW;
4690 }
4691 }
4692
4693 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4694 // RHS is NSW and LHS >= 0.
4695 //
4696 // The difficulty here is that the NSW flag may have been proven
4697 // relative to a loop that is to be found in a recurrence in LHS and
4698 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4699 // larger scope than intended.
4700 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4701
4702 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4703}
4704
4706 unsigned Depth) {
4707 Type *SrcTy = V->getType();
4708 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4709 "Cannot truncate or zero extend with non-integer arguments!");
4710 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4711 return V; // No conversion
4712 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4713 return getTruncateExpr(V, Ty, Depth);
4714 return getZeroExtendExpr(V, Ty, Depth);
4715}
4716
4718 unsigned Depth) {
4719 Type *SrcTy = V->getType();
4720 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4721 "Cannot truncate or zero extend with non-integer arguments!");
4722 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4723 return V; // No conversion
4724 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4725 return getTruncateExpr(V, Ty, Depth);
4726 return getSignExtendExpr(V, Ty, Depth);
4727}
4728
4729const SCEV *
4731 Type *SrcTy = V->getType();
4732 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4733 "Cannot noop or zero extend with non-integer arguments!");
4735 "getNoopOrZeroExtend cannot truncate!");
4736 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4737 return V; // No conversion
4738 return getZeroExtendExpr(V, Ty);
4739}
4740
4741const SCEV *
4743 Type *SrcTy = V->getType();
4744 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4745 "Cannot noop or sign extend with non-integer arguments!");
4747 "getNoopOrSignExtend cannot truncate!");
4748 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4749 return V; // No conversion
4750 return getSignExtendExpr(V, Ty);
4751}
4752
4753const SCEV *
4755 Type *SrcTy = V->getType();
4756 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4757 "Cannot noop or any extend with non-integer arguments!");
4759 "getNoopOrAnyExtend cannot truncate!");
4760 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4761 return V; // No conversion
4762 return getAnyExtendExpr(V, Ty);
4763}
4764
4765const SCEV *
4767 Type *SrcTy = V->getType();
4768 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4769 "Cannot truncate or noop with non-integer arguments!");
4771 "getTruncateOrNoop cannot extend!");
4772 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4773 return V; // No conversion
4774 return getTruncateExpr(V, Ty);
4775}
4776
4778 const SCEV *RHS) {
4779 const SCEV *PromotedLHS = LHS;
4780 const SCEV *PromotedRHS = RHS;
4781
4783 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4784 else
4785 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4786
4787 return getUMaxExpr(PromotedLHS, PromotedRHS);
4788}
4789
4791 const SCEV *RHS,
4792 bool Sequential) {
4794 return getUMinFromMismatchedTypes(Ops, Sequential);
4795}
4796
4797const SCEV *
4799 bool Sequential) {
4800 assert(!Ops.empty() && "At least one operand must be!");
4801 // Trivial case.
4802 if (Ops.size() == 1)
4803 return Ops[0];
4804
4805 // Find the max type first.
4806 Type *MaxType = nullptr;
4807 for (const auto *S : Ops)
4808 if (MaxType)
4809 MaxType = getWiderType(MaxType, S->getType());
4810 else
4811 MaxType = S->getType();
4812 assert(MaxType && "Failed to find maximum type!");
4813
4814 // Extend all ops to max type.
4815 SmallVector<const SCEV *, 2> PromotedOps;
4816 for (const auto *S : Ops)
4817 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4818
4819 // Generate umin.
4820 return getUMinExpr(PromotedOps, Sequential);
4821}
4822
4824 // A pointer operand may evaluate to a nonpointer expression, such as null.
4825 if (!V->getType()->isPointerTy())
4826 return V;
4827
4828 while (true) {
4829 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4830 V = AddRec->getStart();
4831 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4832 const SCEV *PtrOp = nullptr;
4833 for (const SCEV *AddOp : Add->operands()) {
4834 if (AddOp->getType()->isPointerTy()) {
4835 assert(!PtrOp && "Cannot have multiple pointer ops");
4836 PtrOp = AddOp;
4837 }
4838 }
4839 assert(PtrOp && "Must have pointer op");
4840 V = PtrOp;
4841 } else // Not something we can look further into.
4842 return V;
4843 }
4844}
4845
4846/// Push users of the given Instruction onto the given Worklist.
4850 // Push the def-use children onto the Worklist stack.
4851 for (User *U : I->users()) {
4852 auto *UserInsn = cast<Instruction>(U);
4853 if (Visited.insert(UserInsn).second)
4854 Worklist.push_back(UserInsn);
4855 }
4856}
4857
4858namespace {
4859
4860/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4861/// expression in case its Loop is L. If it is not L then
4862/// if IgnoreOtherLoops is true then use AddRec itself
4863/// otherwise rewrite cannot be done.
4864/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4865class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4866public:
4867 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4868 bool IgnoreOtherLoops = true) {
4869 SCEVInitRewriter Rewriter(L, SE);
4870 const SCEV *Result = Rewriter.visit(S);
4871 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4872 return SE.getCouldNotCompute();
4873 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4874 ? SE.getCouldNotCompute()
4875 : Result;
4876 }
4877
4878 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4879 if (!SE.isLoopInvariant(Expr, L))
4880 SeenLoopVariantSCEVUnknown = true;
4881 return Expr;
4882 }
4883
4884 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4885 // Only re-write AddRecExprs for this loop.
4886 if (Expr->getLoop() == L)
4887 return Expr->getStart();
4888 SeenOtherLoops = true;
4889 return Expr;
4890 }
4891
4892 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4893
4894 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4895
4896private:
4897 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4898 : SCEVRewriteVisitor(SE), L(L) {}
4899
4900 const Loop *L;
4901 bool SeenLoopVariantSCEVUnknown = false;
4902 bool SeenOtherLoops = false;
4903};
4904
4905/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4906/// increment expression in case its Loop is L. If it is not L then
4907/// use AddRec itself.
4908/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4909class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4910public:
4911 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4912 SCEVPostIncRewriter Rewriter(L, SE);
4913 const SCEV *Result = Rewriter.visit(S);
4914 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4915 ? SE.getCouldNotCompute()
4916 : Result;
4917 }
4918
4919 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4920 if (!SE.isLoopInvariant(Expr, L))
4921 SeenLoopVariantSCEVUnknown = true;
4922 return Expr;
4923 }
4924
4925 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4926 // Only re-write AddRecExprs for this loop.
4927 if (Expr->getLoop() == L)
4928 return Expr->getPostIncExpr(SE);
4929 SeenOtherLoops = true;
4930 return Expr;
4931 }
4932
4933 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4934
4935 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4936
4937private:
4938 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4939 : SCEVRewriteVisitor(SE), L(L) {}
4940
4941 const Loop *L;
4942 bool SeenLoopVariantSCEVUnknown = false;
4943 bool SeenOtherLoops = false;
4944};
4945
4946/// This class evaluates the compare condition by matching it against the
4947/// condition of loop latch. If there is a match we assume a true value
4948/// for the condition while building SCEV nodes.
4949class SCEVBackedgeConditionFolder
4950 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4951public:
4952 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4953 ScalarEvolution &SE) {
4954 bool IsPosBECond = false;
4955 Value *BECond = nullptr;
4956 if (BasicBlock *Latch = L->getLoopLatch()) {
4957 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4958 if (BI && BI->isConditional()) {
4959 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4960 "Both outgoing branches should not target same header!");
4961 BECond = BI->getCondition();
4962 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4963 } else {
4964 return S;
4965 }
4966 }
4967 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4968 return Rewriter.visit(S);
4969 }
4970
4971 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4972 const SCEV *Result = Expr;
4973 bool InvariantF = SE.isLoopInvariant(Expr, L);
4974
4975 if (!InvariantF) {
4976 Instruction *I = cast<Instruction>(Expr->getValue());
4977 switch (I->getOpcode()) {
4978 case Instruction::Select: {
4979 SelectInst *SI = cast<SelectInst>(I);
4980 std::optional<const SCEV *> Res =
4981 compareWithBackedgeCondition(SI->getCondition());
4982 if (Res) {
4983 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
4984 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4985 }
4986 break;
4987 }
4988 default: {
4989 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4990 if (Res)
4991 Result = *Res;
4992 break;
4993 }
4994 }
4995 }
4996 return Result;
4997 }
4998
4999private:
5000 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5001 bool IsPosBECond, ScalarEvolution &SE)
5002 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5003 IsPositiveBECond(IsPosBECond) {}
5004
5005 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5006
5007 const Loop *L;
5008 /// Loop back condition.
5009 Value *BackedgeCond = nullptr;
5010 /// Set to true if loop back is on positive branch condition.
5011 bool IsPositiveBECond;
5012};
5013
5014std::optional<const SCEV *>
5015SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5016
5017 // If value matches the backedge condition for loop latch,
5018 // then return a constant evolution node based on loopback
5019 // branch taken.
5020 if (BackedgeCond == IC)
5021 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5023 return std::nullopt;
5024}
5025
5026class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5027public:
5028 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5029 ScalarEvolution &SE) {
5030 SCEVShiftRewriter Rewriter(L, SE);
5031 const SCEV *Result = Rewriter.visit(S);
5032 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5033 }
5034
5035 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5036 // Only allow AddRecExprs for this loop.
5037 if (!SE.isLoopInvariant(Expr, L))
5038 Valid = false;
5039 return Expr;
5040 }
5041
5042 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5043 if (Expr->getLoop() == L && Expr->isAffine())
5044 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5045 Valid = false;
5046 return Expr;
5047 }
5048
5049 bool isValid() { return Valid; }
5050
5051private:
5052 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5053 : SCEVRewriteVisitor(SE), L(L) {}
5054
5055 const Loop *L;
5056 bool Valid = true;
5057};
5058
5059} // end anonymous namespace
5060
5062ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5063 if (!AR->isAffine())
5064 return SCEV::FlagAnyWrap;
5065
5066 using OBO = OverflowingBinaryOperator;
5067
5069
5070 if (!AR->hasNoSelfWrap()) {
5071 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5072 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5073 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5074 const APInt &BECountAP = BECountMax->getAPInt();
5075 unsigned NoOverflowBitWidth =
5076 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5077 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5079 }
5080 }
5081
5082 if (!AR->hasNoSignedWrap()) {
5083 ConstantRange AddRecRange = getSignedRange(AR);
5084 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5085
5087 Instruction::Add, IncRange, OBO::NoSignedWrap);
5088 if (NSWRegion.contains(AddRecRange))
5090 }
5091
5092 if (!AR->hasNoUnsignedWrap()) {
5093 ConstantRange AddRecRange = getUnsignedRange(AR);
5094 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5095
5097 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5098 if (NUWRegion.contains(AddRecRange))
5100 }
5101
5102 return Result;
5103}
5104
5106ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5108
5109 if (AR->hasNoSignedWrap())
5110 return Result;
5111
5112 if (!AR->isAffine())
5113 return Result;
5114
5115 // This function can be expensive, only try to prove NSW once per AddRec.
5116 if (!SignedWrapViaInductionTried.insert(AR).second)
5117 return Result;
5118
5119 const SCEV *Step = AR->getStepRecurrence(*this);
5120 const Loop *L = AR->getLoop();
5121
5122 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5123 // Note that this serves two purposes: It filters out loops that are
5124 // simply not analyzable, and it covers the case where this code is
5125 // being called from within backedge-taken count analysis, such that
5126 // attempting to ask for the backedge-taken count would likely result
5127 // in infinite recursion. In the later case, the analysis code will
5128 // cope with a conservative value, and it will take care to purge
5129 // that value once it has finished.
5130 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5131
5132 // Normally, in the cases we can prove no-overflow via a
5133 // backedge guarding condition, we can also compute a backedge
5134 // taken count for the loop. The exceptions are assumptions and
5135 // guards present in the loop -- SCEV is not great at exploiting
5136 // these to compute max backedge taken counts, but can still use
5137 // these to prove lack of overflow. Use this fact to avoid
5138 // doing extra work that may not pay off.
5139
5140 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5141 AC.assumptions().empty())
5142 return Result;
5143
5144 // If the backedge is guarded by a comparison with the pre-inc value the
5145 // addrec is safe. Also, if the entry is guarded by a comparison with the
5146 // start value and the backedge is guarded by a comparison with the post-inc
5147 // value, the addrec is safe.
5149 const SCEV *OverflowLimit =
5150 getSignedOverflowLimitForStep(Step, &Pred, this);
5151 if (OverflowLimit &&
5152 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5153 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5154 Result = setFlags(Result, SCEV::FlagNSW);
5155 }
5156 return Result;
5157}
5159ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5161
5162 if (AR->hasNoUnsignedWrap())
5163 return Result;
5164
5165 if (!AR->isAffine())
5166 return Result;
5167
5168 // This function can be expensive, only try to prove NUW once per AddRec.
5169 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5170 return Result;
5171
5172 const SCEV *Step = AR->getStepRecurrence(*this);
5173 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5174 const Loop *L = AR->getLoop();
5175
5176 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5177 // Note that this serves two purposes: It filters out loops that are
5178 // simply not analyzable, and it covers the case where this code is
5179 // being called from within backedge-taken count analysis, such that
5180 // attempting to ask for the backedge-taken count would likely result
5181 // in infinite recursion. In the later case, the analysis code will
5182 // cope with a conservative value, and it will take care to purge
5183 // that value once it has finished.
5184 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5185
5186 // Normally, in the cases we can prove no-overflow via a
5187 // backedge guarding condition, we can also compute a backedge
5188 // taken count for the loop. The exceptions are assumptions and
5189 // guards present in the loop -- SCEV is not great at exploiting
5190 // these to compute max backedge taken counts, but can still use
5191 // these to prove lack of overflow. Use this fact to avoid
5192 // doing extra work that may not pay off.
5193
5194 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5195 AC.assumptions().empty())
5196 return Result;
5197
5198 // If the backedge is guarded by a comparison with the pre-inc value the
5199 // addrec is safe. Also, if the entry is guarded by a comparison with the
5200 // start value and the backedge is guarded by a comparison with the post-inc
5201 // value, the addrec is safe.
5202 if (isKnownPositive(Step)) {
5204 getUnsignedRangeMax(Step));
5207 Result = setFlags(Result, SCEV::FlagNUW);
5208 }
5209 }
5210
5211 return Result;
5212}
5213
5214namespace {
5215
5216/// Represents an abstract binary operation. This may exist as a
5217/// normal instruction or constant expression, or may have been
5218/// derived from an expression tree.
5219struct BinaryOp {
5220 unsigned Opcode;
5221 Value *LHS;
5222 Value *RHS;
5223 bool IsNSW = false;
5224 bool IsNUW = false;
5225
5226 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5227 /// constant expression.
5228 Operator *Op = nullptr;
5229
5230 explicit BinaryOp(Operator *Op)
5231 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5232 Op(Op) {
5233 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5234 IsNSW = OBO->hasNoSignedWrap();
5235 IsNUW = OBO->hasNoUnsignedWrap();
5236 }
5237 }
5238
5239 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5240 bool IsNUW = false)
5241 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5242};
5243
5244} // end anonymous namespace
5245
5246/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5247static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5248 AssumptionCache &AC,
5249 const DominatorTree &DT,
5250 const Instruction *CxtI) {
5251 auto *Op = dyn_cast<Operator>(V);
5252 if (!Op)
5253 return std::nullopt;
5254
5255 // Implementation detail: all the cleverness here should happen without
5256 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5257 // SCEV expressions when possible, and we should not break that.
5258
5259 switch (Op->getOpcode()) {
5260 case Instruction::Add:
5261 case Instruction::Sub:
5262 case Instruction::Mul:
5263 case Instruction::UDiv:
5264 case Instruction::URem:
5265 case Instruction::And:
5266 case Instruction::AShr:
5267 case Instruction::Shl:
5268 return BinaryOp(Op);
5269
5270 case Instruction::Or: {
5271 // Convert or disjoint into add nuw nsw.
5272 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5273 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5274 /*IsNSW=*/true, /*IsNUW=*/true);
5275 return BinaryOp(Op);
5276 }
5277
5278 case Instruction::Xor:
5279 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5280 // If the RHS of the xor is a signmask, then this is just an add.
5281 // Instcombine turns add of signmask into xor as a strength reduction step.
5282 if (RHSC->getValue().isSignMask())
5283 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5284 // Binary `xor` is a bit-wise `add`.
5285 if (V->getType()->isIntegerTy(1))
5286 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5287 return BinaryOp(Op);
5288
5289 case Instruction::LShr:
5290 // Turn logical shift right of a constant into a unsigned divide.
5291 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5292 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5293
5294 // If the shift count is not less than the bitwidth, the result of
5295 // the shift is undefined. Don't try to analyze it, because the
5296 // resolution chosen here may differ from the resolution chosen in
5297 // other parts of the compiler.
5298 if (SA->getValue().ult(BitWidth)) {
5299 Constant *X =
5300 ConstantInt::get(SA->getContext(),
5301 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5302 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5303 }
5304 }
5305 return BinaryOp(Op);
5306
5307 case Instruction::ExtractValue: {
5308 auto *EVI = cast<ExtractValueInst>(Op);
5309 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5310 break;
5311
5312 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5313 if (!WO)
5314 break;
5315
5316 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5317 bool Signed = WO->isSigned();
5318 // TODO: Should add nuw/nsw flags for mul as well.
5319 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5320 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5321
5322 // Now that we know that all uses of the arithmetic-result component of
5323 // CI are guarded by the overflow check, we can go ahead and pretend
5324 // that the arithmetic is non-overflowing.
5325 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5326 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5327 }
5328
5329 default:
5330 break;
5331 }
5332
5333 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5334 // semantics as a Sub, return a binary sub expression.
5335 if (auto *II = dyn_cast<IntrinsicInst>(V))
5336 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5337 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5338
5339 return std::nullopt;
5340}
5341
5342/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5343/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5344/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5345/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5346/// follows one of the following patterns:
5347/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5348/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5349/// If the SCEV expression of \p Op conforms with one of the expected patterns
5350/// we return the type of the truncation operation, and indicate whether the
5351/// truncated type should be treated as signed/unsigned by setting
5352/// \p Signed to true/false, respectively.
5353static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5354 bool &Signed, ScalarEvolution &SE) {
5355 // The case where Op == SymbolicPHI (that is, with no type conversions on
5356 // the way) is handled by the regular add recurrence creating logic and
5357 // would have already been triggered in createAddRecForPHI. Reaching it here
5358 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5359 // because one of the other operands of the SCEVAddExpr updating this PHI is
5360 // not invariant).
5361 //
5362 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5363 // this case predicates that allow us to prove that Op == SymbolicPHI will
5364 // be added.
5365 if (Op == SymbolicPHI)
5366 return nullptr;
5367
5368 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5369 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5370 if (SourceBits != NewBits)
5371 return nullptr;
5372
5373 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5374 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5375 if (!SExt && !ZExt)
5376 return nullptr;
5377 const SCEVTruncateExpr *Trunc =
5378 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5379 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5380 if (!Trunc)
5381 return nullptr;
5382 const SCEV *X = Trunc->getOperand();
5383 if (X != SymbolicPHI)
5384 return nullptr;
5385 Signed = SExt != nullptr;
5386 return Trunc->getType();
5387}
5388
5389static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5390 if (!PN->getType()->isIntegerTy())
5391 return nullptr;
5392 const Loop *L = LI.getLoopFor(PN->getParent());
5393 if (!L || L->getHeader() != PN->getParent())
5394 return nullptr;
5395 return L;
5396}
5397
5398// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5399// computation that updates the phi follows the following pattern:
5400// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5401// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5402// If so, try to see if it can be rewritten as an AddRecExpr under some
5403// Predicates. If successful, return them as a pair. Also cache the results
5404// of the analysis.
5405//
5406// Example usage scenario:
5407// Say the Rewriter is called for the following SCEV:
5408// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5409// where:
5410// %X = phi i64 (%Start, %BEValue)
5411// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5412// and call this function with %SymbolicPHI = %X.
5413//
5414// The analysis will find that the value coming around the backedge has
5415// the following SCEV:
5416// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5417// Upon concluding that this matches the desired pattern, the function
5418// will return the pair {NewAddRec, SmallPredsVec} where:
5419// NewAddRec = {%Start,+,%Step}
5420// SmallPredsVec = {P1, P2, P3} as follows:
5421// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5422// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5423// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5424// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5425// under the predicates {P1,P2,P3}.
5426// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5427// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5428//
5429// TODO's:
5430//
5431// 1) Extend the Induction descriptor to also support inductions that involve
5432// casts: When needed (namely, when we are called in the context of the
5433// vectorizer induction analysis), a Set of cast instructions will be
5434// populated by this method, and provided back to isInductionPHI. This is
5435// needed to allow the vectorizer to properly record them to be ignored by
5436// the cost model and to avoid vectorizing them (otherwise these casts,
5437// which are redundant under the runtime overflow checks, will be
5438// vectorized, which can be costly).
5439//
5440// 2) Support additional induction/PHISCEV patterns: We also want to support
5441// inductions where the sext-trunc / zext-trunc operations (partly) occur
5442// after the induction update operation (the induction increment):
5443//
5444// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5445// which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5446//
5447// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5448// which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5449//
5450// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5451std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5452ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
 // NOTE(review): this listing skips line 5453; `Predicates`, which is
 // appended to and returned below, is presumably declared there — confirm
 // against the full source.
5454
5455 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5456 // return an AddRec expression under some predicate.
5457
5458 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5459 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5460 assert(L && "Expecting an integer loop header phi");
5461
5462 // The loop may have multiple entrances or multiple exits; we can analyze
5463 // this phi as an addrec if it has a unique entry value and a unique
5464 // backedge value.
5465 Value *BEValueV = nullptr, *StartValueV = nullptr;
5466 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5467 Value *V = PN->getIncomingValue(i);
5468 if (L->contains(PN->getIncomingBlock(i))) {
5469 if (!BEValueV) {
5470 BEValueV = V;
5471 } else if (BEValueV != V) {
5472 BEValueV = nullptr;
5473 break;
5474 }
5475 } else if (!StartValueV) {
5476 StartValueV = V;
5477 } else if (StartValueV != V) {
5478 StartValueV = nullptr;
5479 break;
5480 }
5481 }
5482 if (!BEValueV || !StartValueV)
5483 return std::nullopt;
5484
5485 const SCEV *BEValue = getSCEV(BEValueV);
5486
5487 // If the value coming around the backedge is an add with the symbolic
5488 // value we just inserted, possibly with casts that we can ignore under
5489 // an appropriate runtime guard, then we found a simple induction variable!
5490 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5491 if (!Add)
5492 return std::nullopt;
5493
5494 // If there is a single occurrence of the symbolic value, possibly
5495 // casted, replace it with a recurrence.
5496 unsigned FoundIndex = Add->getNumOperands();
5497 Type *TruncTy = nullptr;
 // Written by isSimpleCastedPHI only when it matches; it is read further
 // down only after a match has been found (FoundIndex check below).
5498 bool Signed;
5499 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5500 if ((TruncTy =
5501 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5502 if (FoundIndex == e) {
5503 FoundIndex = i;
5504 break;
5505 }
5506
5507 if (FoundIndex == Add->getNumOperands())
5508 return std::nullopt;
5509
5510 // Create an add with everything but the specified operand.
 // NOTE(review): listing line 5511 is missing; it presumably declares the
 // `Ops` vector filled below.
5512 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5513 if (i != FoundIndex)
5514 Ops.push_back(Add->getOperand(i));
5515 const SCEV *Accum = getAddExpr(Ops);
5516
5517 // The runtime checks will not be valid if the step amount is
5518 // varying inside the loop.
5519 if (!isLoopInvariant(Accum, L))
5520 return std::nullopt;
5521
5522 // *** Part2: Create the predicates
5523
5524 // Analysis was successful: we have a phi-with-cast pattern for which we
5525 // can return an AddRec expression under the following predicates:
5526 //
5527 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5528 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5529 // P2: An Equal predicate that guarantees that
5530 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5531 // P3: An Equal predicate that guarantees that
5532 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5533 //
5534 // As we next prove, the above predicates guarantee that:
5535 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5536 //
5537 //
5538 // More formally, we want to prove that:
5539 // Expr(i+1) = Start + (i+1) * Accum
5540 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5541 //
5542 // Given that:
5543 // 1) Expr(0) = Start
5544 // 2) Expr(1) = Start + Accum
5545 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5546 // 3) Induction hypothesis (step i):
5547 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5548 //
5549 // Proof:
5550 // Expr(i+1) =
5551 // = Start + (i+1)*Accum
5552 // = (Start + i*Accum) + Accum
5553 // = Expr(i) + Accum
5554 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5555 // :: from step i
5556 //
5557 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5558 //
5559 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5560 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5561 // + Accum :: from P3
5562 //
5563 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5564 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5565 //
5566 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5567 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5568 //
5569 // By induction, the same applies to all iterations 1<=i<n:
5570 //
5571
5572 // Create a truncated addrec for which we will add a no overflow check (P1).
5573 const SCEV *StartVal = getSCEV(StartValueV);
5574 const SCEV *PHISCEV =
5575 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5576 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5577
5578 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5579 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5580 // will be constant.
5581 //
5582 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5583 // add P1.
 // NOTE(review): listing lines 5585-5587 are missing from the body below;
 // they presumably compute `AddedFlags` (NSSW vs. NUSW depending on
 // `Signed`, cf. the <nssw> example in the header comment) — confirm.
5584 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5588 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5589 Predicates.push_back(AddRecPred);
5590 }
5591
5592 // Create the Equal Predicates P2,P3:
5593
5594 // It is possible that the predicates P2 and/or P3 are computable at
5595 // compile time due to StartVal and/or Accum being constants.
5596 // If either one is, then we can check that now and escape if either P2
5597 // or P3 is false.
5598
5599 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5600 // for each of StartVal and Accum
5601 auto getExtendedExpr = [&](const SCEV *Expr,
5602 bool CreateSignExtend) -> const SCEV * {
5603 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5604 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5605 const SCEV *ExtendedExpr =
5606 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5607 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5608 return ExtendedExpr;
5609 };
5610
5611 // Given:
5612 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5613 // = getExtendedExpr(Expr)
5614 // Determine whether the predicate P: Expr == ExtendedExpr
5615 // is known to be false at compile time
5616 auto PredIsKnownFalse = [&](const SCEV *Expr,
5617 const SCEV *ExtendedExpr) -> bool {
5618 return Expr != ExtendedExpr &&
5619 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5620 };
5621
5622 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5623 if (PredIsKnownFalse(StartVal, StartExtended)) {
5624 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5625 return std::nullopt;
5626 }
5627
5628 // The Step is always Signed (because the overflow checks are either
5629 // NSSW or NUSW)
5630 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5631 if (PredIsKnownFalse(Accum, AccumExtended)) {
5632 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5633 return std::nullopt;
5634 }
5635
 // Only record an equality predicate when it is neither trivially true
 // (identical SCEVs) nor already provable at compile time.
5636 auto AppendPredicate = [&](const SCEV *Expr,
5637 const SCEV *ExtendedExpr) -> void {
5638 if (Expr != ExtendedExpr &&
5639 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5640 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5641 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5642 Predicates.push_back(Pred);
5643 }
5644 };
5645
5646 AppendPredicate(StartVal, StartExtended);
5647 AppendPredicate(Accum, AccumExtended);
5648
5649 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5650 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5651 // into NewAR if it will also add the runtime overflow checks specified in
5652 // Predicates.
5653 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5654
5655 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5656 std::make_pair(NewAR, Predicates);
5657 // Remember the result of the analysis for this SCEV at this location.
5658 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5659 return PredRewrite;
5660}
5661
// Memoizing front end for createAddRecFromPHIWithCastsImpl. Returns:
//  - std::nullopt if SymbolicPHI is not an integer loop-header phi;
//  - the cached {AddRec, Predicates} rewrite if one was computed before;
//  - std::nullopt if an earlier analysis of this phi failed (the cache then
//    maps the phi to itself);
//  - otherwise the result of running the analysis now. A failure is cached
//    as a {SymbolicPHI, Predicates} sentinel.
// NOTE(review): listing line 5663 (the qualified name and parameter line)
// is missing below; the body uses a `const SCEVUnknown *SymbolicPHI`.
5662std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5664 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5665 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5666 if (!L)
5667 return std::nullopt;
5668
5669 // Check to see if we already analyzed this PHI.
5670 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5671 if (I != PredicatedSCEVRewrites.end()) {
5672 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5673 I->second;
5674 // Analysis was done before and failed to create an AddRec:
5675 if (Rewrite.first == SymbolicPHI)
5676 return std::nullopt;
5677 // Analysis was done before and succeeded to create an AddRec under
5678 // a predicate:
5679 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5680 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5681 return Rewrite;
5682 }
5683
5684 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5685 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5686
5687 // Record in the cache that the analysis failed
5688 if (!Rewrite) {
 // NOTE(review): listing line 5689 is missing; it presumably declares the
 // (empty) `Predicates` vector stored in the failure sentinel below.
5690 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5691 return std::nullopt;
5692 }
5693
 // On success the Impl has already stored the rewrite in
 // PredicatedSCEVRewrites, so there is nothing more to cache here.
5694 return Rewrite;
5695}
5696
5697// FIXME: This utility is currently required because the Rewriter currently
5698// does not rewrite this expression:
5699// {0, +, (sext ix (trunc iy %step to ix) to iy)}
5700// into {0, +, %step},
5701// even when the following Equal predicate exists:
5702// "%step == (sext ix (trunc iy %step to ix) to iy)".
//
// Returns true when AR1 and AR2 are the same expression, or when their start
// values and their step recurrences each either match exactly or are proven
// equal by an equal-predicate implied by `Preds`. The implication is checked
// in both operand orders.
// NOTE(review): listing line 5703 (the return type and qualified name) is
// missing below.
5704 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5705 if (AR1 == AR2)
5706 return true;
5707
5708 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709 if (Expr1 != Expr2 &&
5710 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5712 return false;
5713 return true;
5714 };
5715
5716 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5717 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5718 return false;
5719 return true;
5720}
5721
5722/// A helper function for createAddRecFromPHI to handle simple cases.
5723///
5724/// This function tries to find an AddRec expression for the simplest (yet most
5725/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5726/// If it fails, createAddRecFromPHI will use a more general, but slow,
5727/// technique for finding the AddRec expression.
///
/// Only an `add` back-edge value is accepted. PN may be either operand, as
/// long as the other operand is loop invariant. On success, the new SCEV is
/// also recorded for PN in the value map and then returned. Otherwise
/// nullptr is returned.
5728const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5729 Value *BEValueV,
5730 Value *StartValueV) {
5731 const Loop *L = LI.getLoopFor(PN->getParent());
5732 assert(L && L->getHeader() == PN->getParent());
5733 assert(BEValueV && StartValueV);
5734
5735 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5736 if (!BO)
5737 return nullptr;
5738
5739 if (BO->Opcode != Instruction::Add)
5740 return nullptr;
5741
 // The step is whichever add operand is not PN; it must be loop invariant.
5742 const SCEV *Accum = nullptr;
5743 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5744 Accum = getSCEV(BO->RHS);
5745 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5746 Accum = getSCEV(BO->LHS);
5747
5748 if (!Accum)
5749 return nullptr;
5750
 // NOTE(review): listing line 5751 is missing; it presumably declares
 // `Flags` (used below) — confirm its initial value against the full source.
 // Carry the add's own nuw/nsw flags over to the recurrence.
5752 if (BO->IsNUW)
5753 Flags = setFlags(Flags, SCEV::FlagNUW);
5754 if (BO->IsNSW)
5755 Flags = setFlags(Flags, SCEV::FlagNSW);
5756
5757 const SCEV *StartVal = getSCEV(StartValueV);
5758 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5759 insertValueToMap(PN, PHISCEV);
5760
 // If the result is still an AddRec, add any no-wrap flags that
 // proveNoWrapViaConstantRanges can establish.
 // NOTE(review): listing line 5763, in the middle of this call, is missing.
5761 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5762 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5764 proveNoWrapViaConstantRanges(AR)));
5765 }
5766
5767 // We can add Flags to the post-inc expression only if we
5768 // know that it is *undefined behavior* for BEValueV to
5769 // overflow.
5770 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5771 assert(isLoopInvariant(Accum, L) &&
5772 "Accum is defined outside L, but is not invariant?");
 // The result is discarded. The call is made only for its effect of
 // creating the post-inc AddRec with Flags attached (presumably so later
 // lookups of that uniqued expression see the flags).
5773 if (isAddRecNeverPoison(BEInst, L))
5774 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5775 }
5776
5777 return PHISCEV;
5778}
5779
// Try to express the loop-header phi PN as an add recurrence (or, via the
// shift/init rewriters, as a shifted recurrence). PN must be in its loop's
// header and have exactly one distinct incoming value from outside the loop
// (the start) and one from inside it (the back-edge value); otherwise
// nullptr is returned. On success the resulting SCEV is recorded for PN in
// the value map and returned. On failure, the temporary symbolic mapping
// for PN is erased and nullptr is returned.
5780const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5781 const Loop *L = LI.getLoopFor(PN->getParent());
5782 if (!L || L->getHeader() != PN->getParent())
5783 return nullptr;
5784
5785 // The loop may have multiple entrances or multiple exits; we can analyze
5786 // this phi as an addrec if it has a unique entry value and a unique
5787 // backedge value.
5788 Value *BEValueV = nullptr, *StartValueV = nullptr;
5789 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5790 Value *V = PN->getIncomingValue(i);
5791 if (L->contains(PN->getIncomingBlock(i))) {
5792 if (!BEValueV) {
5793 BEValueV = V;
5794 } else if (BEValueV != V) {
5795 BEValueV = nullptr;
5796 break;
5797 }
5798 } else if (!StartValueV) {
5799 StartValueV = V;
5800 } else if (StartValueV != V) {
5801 StartValueV = nullptr;
5802 break;
5803 }
5804 }
5805 if (!BEValueV || !StartValueV)
5806 return nullptr;
5807
5808 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5809 "PHI node already processed?");
5810
5811 // First, try to find AddRec expression without creating a fictitious symbolic
5812 // value for PN.
5813 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5814 return S;
5815
5816 // Handle PHI node value symbolically.
5817 const SCEV *SymbolicName = getUnknown(PN);
5818 insertValueToMap(PN, SymbolicName);
5819
5820 // Using this symbolic name for the PHI, analyze the value coming around
5821 // the back-edge.
5822 const SCEV *BEValue = getSCEV(BEValueV);
5823
5824 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5825 // has a special value for the first iteration of the loop.
5826
5827 // If the value coming around the backedge is an add with the symbolic
5828 // value we just inserted, then we found a simple induction variable!
5829 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5830 // If there is a single occurrence of the symbolic value, replace it
5831 // with a recurrence.
5832 unsigned FoundIndex = Add->getNumOperands();
5833 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5834 if (Add->getOperand(i) == SymbolicName)
5835 if (FoundIndex == e) {
5836 FoundIndex = i;
5837 break;
5838 }
5839
5840 if (FoundIndex != Add->getNumOperands()) {
5841 // Create an add with everything but the specified operand.
 // NOTE(review): listing line 5842 is missing; it presumably declares the
 // `Ops` vector filled below.
5843 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5844 if (i != FoundIndex)
5845 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5846 L, *this));
5847 const SCEV *Accum = getAddExpr(Ops);
5848
5849 // This is not a valid addrec if the step amount is varying each
5850 // loop iteration, but is not itself an addrec in this loop.
5851 if (isLoopInvariant(Accum, L) ||
5852 (isa<SCEVAddRecExpr>(Accum) &&
5853 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
 // NOTE(review): listing line 5854 is missing; it presumably declares
 // `Flags` (updated below) — confirm its initial value.
5855
5856 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5857 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5858 if (BO->IsNUW)
5859 Flags = setFlags(Flags, SCEV::FlagNUW);
5860 if (BO->IsNSW)
5861 Flags = setFlags(Flags, SCEV::FlagNSW);
5862 }
5863 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5864 if (GEP->getOperand(0) == PN) {
5865 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5866 // If the increment has any nowrap flags, then we know the address
5867 // space cannot be wrapped around.
5868 if (NW != GEPNoWrapFlags::none())
5869 Flags = setFlags(Flags, SCEV::FlagNW);
5870 // If the GEP is nuw or nusw with non-negative offset, we know that
5871 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5872 // offset is treated as signed, while the base is unsigned.
 // NOTE(review): listing line 5874, the second half of the following
 // condition (per the comment above, presumably the nusw-with-
 // non-negative-offset case), is missing.
5873 if (NW.hasNoUnsignedWrap() ||
5875 Flags = setFlags(Flags, SCEV::FlagNUW);
5876 }
5877
5878 // We cannot transfer nuw and nsw flags from subtraction
5879 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5880 // for instance.
5881 }
5882
5883 const SCEV *StartVal = getSCEV(StartValueV);
5884 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5885
5886 // Okay, for the entire analysis of this edge we assumed the PHI
5887 // to be symbolic. We now need to go back and purge all of the
5888 // entries for the scalars that use the symbolic expression.
5889 forgetMemoizedResults(SymbolicName);
5890 insertValueToMap(PN, PHISCEV);
5891
 // Strengthen the recurrence's flags with what constant ranges prove.
 // NOTE(review): listing line 5894, in the middle of this call, is missing.
5892 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5893 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5895 proveNoWrapViaConstantRanges(AR)));
5896 }
5897
5898 // We can add Flags to the post-inc expression only if we
5899 // know that it is *undefined behavior* for BEValueV to
5900 // overflow.
5901 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5902 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5903 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5904
5905 return PHISCEV;
5906 }
5907 }
5908 } else {
5909 // Otherwise, this could be a loop like this:
5910 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5911 // In this case, j = {1,+,1} and BEValue is j.
5912 // Because the other in-value of i (0) fits the evolution of BEValue
5913 // i really is an addrec evolution.
5914 //
5915 // We can generalize this saying that i is the shifted value of BEValue
5916 // by one iteration:
5917 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5918
5919 // Do not allow refinement in rewriting of BEValue.
5920 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5921 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5922 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5923 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5924 const SCEV *StartVal = getSCEV(StartValueV);
5925 if (Start == StartVal) {
5926 // Okay, for the entire analysis of this edge we assumed the PHI
5927 // to be symbolic. We now need to go back and purge all of the
5928 // entries for the scalars that use the symbolic expression.
5929 forgetMemoizedResults(SymbolicName);
5930 insertValueToMap(PN, Shifted);
5931 return Shifted;
5932 }
5933 }
5934 }
5935
5936 // Remove the temporary PHI node SCEV that has been inserted while intending
5937 // to create an AddRecExpr for this PHI node. We can not keep this temporary
5938 // as it will prevent later (possibly simpler) SCEV expressions to be added
5939 // to the ValueExprMap.
5940 eraseValueFromMap(PN);
5941
5942 return nullptr;
5943}
5944
5945// Try to match a control flow sequence that branches out at BI and merges back
5946// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5947// match.
//
// C is always set to BI's condition. Operands 0 and 1 of Merge are each
// attributed to whichever of BI's two outgoing edges dominates that use.
// LHS receives the value reaching Merge along the successor-0 edge and RHS
// the one along the successor-1 edge. On failure, LHS and RHS are not
// written.
// NOTE(review): listing line 5948 (return type, name and leading
// parameters; the body uses DT, BI and Merge) is missing below.
5949 Value *&C, Value *&LHS, Value *&RHS) {
5950 C = BI->getCondition();
5951
5952 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5953 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5954
 // Presumably this fails when both successors are the same block (two
 // parallel edges), in which case the uses cannot be told apart.
5955 if (!LeftEdge.isSingleEdge())
5956 return false;
5957
5958 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5959
5960 Use &LeftUse = Merge->getOperandUse(0);
5961 Use &RightUse = Merge->getOperandUse(1);
5962
5963 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5964 LHS = LeftUse;
5965 RHS = RightUse;
5966 return true;
5967 }
5968
 // The phi lists its incoming values in the opposite order to the branch.
5969 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5970 LHS = RightUse;
5971 RHS = LeftUse;
5972 return true;
5973 }
5974
5975 return false;
5976}
5977
5978const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5979 auto IsReachable =
5980 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5981 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5982 // Try to match
5983 //
5984 // br %cond, label %left, label %right
5985 // left:
5986 // br label %merge
5987 // right:
5988 // br label %merge
5989 // merge:
5990 // V = phi [ %x, %left ], [ %y, %right ]
5991 //
5992 // as "select %cond, %x, %y"
5993
5994 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5995 assert(IDom && "At least the entry block should dominate PN");
5996
5997 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5998 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5999
6000 if (BI && BI->isConditional() &&
6001 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6002 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6003 properlyDominates(getSCEV(RHS), PN->getParent()))
6004 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6005 }
6006
6007 return nullptr;
6008}
6009
6010/// Returns SCEV for the first operand of a phi if all phi operands have
6011/// identical opcodes and operands
///
/// Concretely: every incoming value must be a BinaryOperator that
/// isIdenticalToWhenDefined the first one, and every incoming value must
/// also map to the same SCEV as the first. Otherwise nullptr is returned.
6012/// e.g.
6013/// a: %add = %a + %b
6014/// br %c
6015/// b: %add1 = %a + %b
6016/// br %c
6017/// c: %phi = phi [%add, a], [%add1, b]
6018/// scev(%phi) => scev(%add)
6019const SCEV *
6020ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6021 BinaryOperator *CommonInst = nullptr;
6022 // Check if instructions are identical.
6023 for (Value *Incoming : PN->incoming_values()) {
6024 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6025 if (!IncomingInst)
6026 return nullptr;
6027 if (CommonInst) {
6028 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6029 return nullptr; // Not identical, give up
6030 } else {
6031 // Remember binary operator
6032 CommonInst = IncomingInst;
6033 }
6034 }
6035 if (!CommonInst)
6036 return nullptr;
6037
6038 // Check if SCEV exprs for instructions are identical.
 // NOTE(review): listing line 6041 (the start of the range predicate over
 // the remaining incoming values) is missing below.
6039 const SCEV *CommonSCEV = getSCEV(CommonInst);
6040 bool SCEVExprsIdentical =
6042 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6043 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6044}
6045
6046const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6047 if (const SCEV *S = createAddRecFromPHI(PN))
6048 return S;
6049
6050 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6051 // phi node for X.
6052 if (Value *V = simplifyInstruction(
6053 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6054 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6055 return getSCEV(V);
6056
6057 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6058 return S;
6059
6060 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6061 return S;
6062
6063 // If it's not a loop phi, we can't handle it yet.
6064 return getUnknown(PN);
6065}
6066
// Returns true if OperandToFind occurs somewhere inside Root. The search
// only descends through nodes whose kind is RootKind, the non-sequential
// counterpart of RootKind, or a zero-extension. RootKind is expected to be a
// sequential min/max kind (see the member comment below).
// NOTE(review): listing line 6086 (the call that maps RootKind to its
// non-sequential variant) is missing below.
6067bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6068 SCEVTypes RootKind) {
6069 struct FindClosure {
6070 const SCEV *OperandToFind;
6071 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6072 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6073
6074 bool Found = false;
6075
6076 bool canRecurseInto(SCEVTypes Kind) const {
6077 // We can only recurse into the SCEV expression of the same effective type
6078 // as the type of our root SCEV expression, and into zero-extensions.
6079 return RootKind == Kind || NonSequentialRootKind == Kind ||
6080 scZeroExtend == Kind;
6081 };
6082
6083 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6084 : OperandToFind(OperandToFind), RootKind(RootKind),
6085 NonSequentialRootKind(
6087 RootKind)) {}
6088
 // Found is overwritten on every visit. Once it becomes true, isDone()
 // reports completion (presumably stopping visitAll), so it stays true.
6089 bool follow(const SCEV *S) {
6090 Found = S == OperandToFind;
6091
6092 return !isDone() && canRecurseInto(S->getSCEVType());
6093 }
6094
6095 bool isDone() const { return Found; }
6096 };
6097
6098 FindClosure FC(OperandToFind, RootKind);
6099 visitAll(Root, FC);
6100 return FC.Found;
6101}
6102
// Try to fold "Cond ? TrueVal : FalseVal", whose condition is an integer
// compare, into a min/max-based SCEV of type Ty. Shapes handled (see the
// case comments below):
//  - ordered compares: a > b ? a+x : b+x -> max(a, b)+x, and the min form;
//  - eq/ne against zero: x == 0 ? C+y : x+y -> umax(x, C)+y when C u<= 1,
//    and x == 0 ? 0 : <umin containing x> -> umin_seq(x, ...).
// Returns std::nullopt when no pattern applies.
6103std::optional<const SCEV *>
6104ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6105 ICmpInst *Cond,
6106 Value *TrueVal,
6107 Value *FalseVal) {
6108 // Try to match some simple smax or umax patterns.
6109 auto *ICI = Cond;
6110
6111 Value *LHS = ICI->getOperand(0);
6112 Value *RHS = ICI->getOperand(1);
6113
6114 switch (ICI->getPredicate()) {
6115 case ICmpInst::ICMP_SLT:
6116 case ICmpInst::ICMP_SLE:
6117 case ICmpInst::ICMP_ULT:
6118 case ICmpInst::ICMP_ULE:
6119 std::swap(LHS, RHS);
6120 [[fallthrough]];
6121 case ICmpInst::ICMP_SGT:
6122 case ICmpInst::ICMP_SGE:
6123 case ICmpInst::ICMP_UGT:
6124 case ICmpInst::ICMP_UGE:
6125 // a > b ? a+x : b+x -> max(a, b)+x
6126 // a > b ? b+x : a+x -> min(a, b)+x
 // NOTE(review): listing line 6127 is missing; presumably it is the `{`
 // that scopes the declarations below (matched by the `}` before `break`).
6128 bool Signed = ICI->isSigned();
6129 const SCEV *LA = getSCEV(TrueVal);
6130 const SCEV *RA = getSCEV(FalseVal);
6131 const SCEV *LS = getSCEV(LHS);
6132 const SCEV *RS = getSCEV(RHS);
 // Pointer-typed hands: only the exact max/min forms are produced.
6133 if (LA->getType()->isPointerTy()) {
6134 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6135 // Need to make sure we can't produce weird expressions involving
6136 // negated pointers.
6137 if (LA == LS && RA == RS)
6138 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6139 if (LA == RS && RA == LS)
6140 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6141 }
6142 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6143 if (Op->getType()->isPointerTy()) {
 // NOTE(review): listing line 6144 is missing; presumably it converts the
 // pointer Op to an integer SCEV, which may yield SCEVCouldNotCompute.
6145 if (isa<SCEVCouldNotCompute>(Op))
6146 return Op;
6147 }
6148 if (Signed)
6149 Op = getNoopOrSignExtend(Op, Ty);
6150 else
6151 Op = getNoopOrZeroExtend(Op, Ty);
6152 return Op;
6153 };
6154 LS = CoerceOperand(LS);
6155 RS = CoerceOperand(RS);
6156 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6157 break;
6158 const SCEV *LDiff = getMinusSCEV(LA, LS);
6159 const SCEV *RDiff = getMinusSCEV(RA, RS);
6160 if (LDiff == RDiff)
6161 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6162 LDiff);
6163 LDiff = getMinusSCEV(LA, RS);
6164 RDiff = getMinusSCEV(RA, LS);
6165 if (LDiff == RDiff)
6166 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6167 LDiff);
6168 }
6169 break;
6170 case ICmpInst::ICMP_NE:
6171 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6172 std::swap(TrueVal, FalseVal);
6173 [[fallthrough]];
6174 case ICmpInst::ICMP_EQ:
6175 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
 // NOTE(review): listing line 6176, the first half of the following `if`
 // condition, is missing.
6177 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6178 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6179 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6180 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6181 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6182 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6183 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6184 return getAddExpr(getUMaxExpr(X, C), Y);
6185 }
6186 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6187 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6188 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6189 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6190 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6191 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6192 const SCEV *X = getSCEV(LHS);
 // Look through zero-extensions so x can be matched at its narrowest type.
6193 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6194 X = ZExt->getOperand();
6195 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6196 const SCEV *FalseValExpr = getSCEV(FalseVal);
6197 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6198 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6199 /*Sequential=*/true);
6200 }
6201 }
6202 break;
6203 default:
6204 break;
6205 }
6206
6207 return std::nullopt;
6208}
6209
// Encode an i1 select of SCEV hands as C + umin_seq(cond', X - C), where C
// is whichever hand is a SCEVConstant and cond' is the condition, negated
// when the constant is the true hand. Returns std::nullopt if neither hand
// is constant.
// NOTE(review): listing line 6211 (the function name and leading
// parameters; the body uses SE and CondExpr) is missing below.
6210static std::optional<const SCEV *>
6212 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6213 assert(CondExpr->getType()->isIntegerTy(1) &&
6214 TrueExpr->getType() == FalseExpr->getType() &&
6215 TrueExpr->getType()->isIntegerTy(1) &&
6216 "Unexpected operands of a select.");
6217
6218 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6219 // --> C + (umin_seq cond, x - C)
6220 //
6221 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6222 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6223 // --> C + (umin_seq ~cond, x - C)
6224
6225 // FIXME: while we can't legally model the case where both of the hands
6226 // are fully variable, we only require that the *difference* is constant.
6227 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6228 return std::nullopt;
6229
6230 const SCEV *X, *C;
6231 if (isa<SCEVConstant>(TrueExpr)) {
6232 CondExpr = SE->getNotSCEV(CondExpr);
6233 X = FalseExpr;
6234 C = TrueExpr;
6235 } else {
6236 X = TrueExpr;
6237 C = FalseExpr;
6238 }
6239 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6240 /*Sequential=*/true));
6241}
6242
6243static std::optional<const SCEV *>
6245 Value *FalseVal) {
6246 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6247 return std::nullopt;
6248
6249 const auto *SECond = SE->getSCEV(Cond);
6250 const auto *SETrue = SE->getSCEV(TrueVal);
6251 const auto *SEFalse = SE->getSCEV(FalseVal);
6252 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6253}
6254
6255const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6256 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6257 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6258 assert(TrueVal->getType() == FalseVal->getType() &&
6259 V->getType() == TrueVal->getType() &&
6260 "Types of select hands and of the result must match.");
6261
6262 // For now, only deal with i1-typed `select`s.
6263 if (!V->getType()->isIntegerTy(1))
6264 return getUnknown(V);
6265
6266 if (std::optional<const SCEV *> S =
6267 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6268 return *S;
6269
6270 return getUnknown(V);
6271}
6272
6273const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6274 Value *TrueVal,
6275 Value *FalseVal) {
6276 // Handle "constant" branch or select. This can occur for instance when a
6277 // loop pass transforms an inner loop and moves on to process the outer loop.
6278 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6279 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6280
6281 if (auto *I = dyn_cast<Instruction>(V)) {
6282 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6283 if (std::optional<const SCEV *> S =
6284 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6285 TrueVal, FalseVal))
6286 return *S;
6287 }
6288 }
6289
6290 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6291}
6292
6293/// Expand GEP instructions into add and multiply operations. This allows them
6294/// to be analyzed by regular SCEV code.
6295const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6296 assert(GEP->getSourceElementType()->isSized() &&
6297 "GEP source element type must be sized");
6298
6300 for (Value *Index : GEP->indices())
6301 IndexExprs.push_back(getSCEV(Index));
6302 return getGEPExpr(GEP, IndexExprs);
6303}
6304
6305APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6307 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6308 return TrailingZeros >= BitWidth
6310 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6311 };
6312 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6313 // The result is GCD of all operands results.
6314 APInt Res = getConstantMultiple(N->getOperand(0));
6315 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6317 Res, getConstantMultiple(N->getOperand(I)));
6318 return Res;
6319 };
6320
6321 switch (S->getSCEVType()) {
6322 case scConstant:
6323 return cast<SCEVConstant>(S)->getAPInt();
6324 case scPtrToInt:
6325 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6326 case scUDivExpr:
6327 case scVScale:
6328 return APInt(BitWidth, 1);
6329 case scTruncate: {
6330 // Only multiples that are a power of 2 will hold after truncation.
6331 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6332 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6333 return GetShiftedByZeros(TZ);
6334 }
6335 case scZeroExtend: {
6336 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6337 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6338 }
6339 case scSignExtend: {
6340 // Only multiples that are a power of 2 will hold after sext.
6341 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6343 return GetShiftedByZeros(TZ);
6344 }
6345 case scMulExpr: {
6346 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6347 if (M->hasNoUnsignedWrap()) {
6348 // The result is the product of all operand results.
6349 APInt Res = getConstantMultiple(M->getOperand(0));
6350 for (const SCEV *Operand : M->operands().drop_front())
6351 Res = Res * getConstantMultiple(Operand);
6352 return Res;
6353 }
6354
6355 // If there are no wrap guarentees, find the trailing zeros, which is the
6356 // sum of trailing zeros for all its operands.
6357 uint32_t TZ = 0;
6358 for (const SCEV *Operand : M->operands())
6359 TZ += getMinTrailingZeros(Operand);
6360 return GetShiftedByZeros(TZ);
6361 }
6362 case scAddExpr:
6363 case scAddRecExpr: {
6364 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6365 if (N->hasNoUnsignedWrap())
6366 return GetGCDMultiple(N);
6367 // Find the trailing bits, which is the minimum of its operands.
6368 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6369 for (const SCEV *Operand : N->operands().drop_front())
6370 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6371 return GetShiftedByZeros(TZ);
6372 }
6373 case scUMaxExpr:
6374 case scSMaxExpr:
6375 case scUMinExpr:
6376 case scSMinExpr:
6378 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6379 case scUnknown: {
6380 // ask ValueTracking for known bits
6381 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6382 unsigned Known =
6383 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6384 .countMinTrailingZeros();
6385 return GetShiftedByZeros(Known);
6386 }
6387 case scCouldNotCompute:
6388 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6389 }
6390 llvm_unreachable("Unknown SCEV kind!");
6391}
6392
6394 auto I = ConstantMultipleCache.find(S);
6395 if (I != ConstantMultipleCache.end())
6396 return I->second;
6397
6398 APInt Result = getConstantMultipleImpl(S);
6399 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6400 assert(InsertPair.second && "Should insert a new key");
6401 return InsertPair.first->second;
6402}
6403
6405 APInt Multiple = getConstantMultiple(S);
6406 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6407}
6408
6410 return std::min(getConstantMultiple(S).countTrailingZeros(),
6411 (unsigned)getTypeSizeInBits(S->getType()));
6412}
6413
6414/// Helper method to assign a range to V from metadata present in the IR.
6415static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6416 if (Instruction *I = dyn_cast<Instruction>(V)) {
6417 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6418 return getConstantRangeFromMetadata(*MD);
6419 if (const auto *CB = dyn_cast<CallBase>(V))
6420 if (std::optional<ConstantRange> Range = CB->getRange())
6421 return Range;
6422 }
6423 if (auto *A = dyn_cast<Argument>(V))
6424 if (std::optional<ConstantRange> Range = A->getRange())
6425 return Range;
6426
6427 return std::nullopt;
6428}
6429
6431 SCEV::NoWrapFlags Flags) {
6432 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6433 AddRec->setNoWrapFlags(Flags);
6434 UnsignedRanges.erase(AddRec);
6435 SignedRanges.erase(AddRec);
6436 ConstantMultipleCache.erase(AddRec);
6437 }
6438}
6439
6440ConstantRange ScalarEvolution::
6441getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6442 const DataLayout &DL = getDataLayout();
6443
6444 unsigned BitWidth = getTypeSizeInBits(U->getType());
6445 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6446
6447 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6448 // use information about the trip count to improve our available range. Note
6449 // that the trip count independent cases are already handled by known bits.
6450 // WARNING: The definition of recurrence used here is subtly different than
6451 // the one used by AddRec (and thus most of this file). Step is allowed to
6452 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6453 // and other addrecs in the same loop (for non-affine addrecs). The code
6454 // below intentionally handles the case where step is not loop invariant.
6455 auto *P = dyn_cast<PHINode>(U->getValue());
6456 if (!P)
6457 return FullSet;
6458
6459 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6460 // even the values that are not available in these blocks may come from them,
6461 // and this leads to false-positive recurrence test.
6462 for (auto *Pred : predecessors(P->getParent()))
6463 if (!DT.isReachableFromEntry(Pred))
6464 return FullSet;
6465
6466 BinaryOperator *BO;
6467 Value *Start, *Step;
6468 if (!matchSimpleRecurrence(P, BO, Start, Step))
6469 return FullSet;
6470
6471 // If we found a recurrence in reachable code, we must be in a loop. Note
6472 // that BO might be in some subloop of L, and that's completely okay.
6473 auto *L = LI.getLoopFor(P->getParent());
6474 assert(L && L->getHeader() == P->getParent());
6475 if (!L->contains(BO->getParent()))
6476 // NOTE: This bailout should be an assert instead. However, asserting
6477 // the condition here exposes a case where LoopFusion is querying SCEV
6478 // with malformed loop information during the midst of the transform.
6479 // There doesn't appear to be an obvious fix, so for the moment bailout
6480 // until the caller issue can be fixed. PR49566 tracks the bug.
6481 return FullSet;
6482
6483 // TODO: Extend to other opcodes such as mul, and div
6484 switch (BO->getOpcode()) {
6485 default:
6486 return FullSet;
6487 case Instruction::AShr:
6488 case Instruction::LShr:
6489 case Instruction::Shl:
6490 break;
6491 };
6492
6493 if (BO->getOperand(0) != P)
6494 // TODO: Handle the power function forms some day.
6495 return FullSet;
6496
6497 unsigned TC = getSmallConstantMaxTripCount(L);
6498 if (!TC || TC >= BitWidth)
6499 return FullSet;
6500
6501 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6502 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6503 assert(KnownStart.getBitWidth() == BitWidth &&
6504 KnownStep.getBitWidth() == BitWidth);
6505
6506 // Compute total shift amount, being careful of overflow and bitwidths.
6507 auto MaxShiftAmt = KnownStep.getMaxValue();
6508 APInt TCAP(BitWidth, TC-1);
6509 bool Overflow = false;
6510 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6511 if (Overflow)
6512 return FullSet;
6513
6514 switch (BO->getOpcode()) {
6515 default:
6516 llvm_unreachable("filtered out above");
6517 case Instruction::AShr: {
6518 // For each ashr, three cases:
6519 // shift = 0 => unchanged value
6520 // saturation => 0 or -1
6521 // other => a value closer to zero (of the same sign)
6522 // Thus, the end value is closer to zero than the start.
6523 auto KnownEnd = KnownBits::ashr(KnownStart,
6524 KnownBits::makeConstant(TotalShift));
6525 if (KnownStart.isNonNegative())
6526 // Analogous to lshr (simply not yet canonicalized)
6527 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6528 KnownStart.getMaxValue() + 1);
6529 if (KnownStart.isNegative())
6530 // End >=u Start && End <=s Start
6531 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6532 KnownEnd.getMaxValue() + 1);
6533 break;
6534 }
6535 case Instruction::LShr: {
6536 // For each lshr, three cases:
6537 // shift = 0 => unchanged value
6538 // saturation => 0
6539 // other => a smaller positive number
6540 // Thus, the low end of the unsigned range is the last value produced.
6541 auto KnownEnd = KnownBits::lshr(KnownStart,
6542 KnownBits::makeConstant(TotalShift));
6543 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6544 KnownStart.getMaxValue() + 1);
6545 }
6546 case Instruction::Shl: {
6547 // Iff no bits are shifted out, value increases on every shift.
6548 auto KnownEnd = KnownBits::shl(KnownStart,
6549 KnownBits::makeConstant(TotalShift));
6550 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6551 return ConstantRange(KnownStart.getMinValue(),
6552 KnownEnd.getMaxValue() + 1);
6553 break;
6554 }
6555 };
6556 return FullSet;
6557}
6558
6559const ConstantRange &
6560ScalarEvolution::getRangeRefIter(const SCEV *S,
6561 ScalarEvolution::RangeSignHint SignHint) {
6563 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6564 : SignedRanges;
6567
6568 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6569 // SCEVUnknown PHI node.
6570 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6571 if (!Seen.insert(Expr).second)
6572 return;
6573 if (Cache.contains(Expr))
6574 return;
6575 switch (Expr->getSCEVType()) {
6576 case scUnknown:
6577 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6578 break;
6579 [[fallthrough]];
6580 case scConstant:
6581 case scVScale:
6582 case scTruncate:
6583 case scZeroExtend:
6584 case scSignExtend:
6585 case scPtrToInt:
6586 case scAddExpr:
6587 case scMulExpr:
6588 case scUDivExpr:
6589 case scAddRecExpr:
6590 case scUMaxExpr:
6591 case scSMaxExpr:
6592 case scUMinExpr:
6593 case scSMinExpr:
6595 WorkList.push_back(Expr);
6596 break;
6597 case scCouldNotCompute:
6598 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6599 }
6600 };
6601 AddToWorklist(S);
6602
6603 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6604 for (unsigned I = 0; I != WorkList.size(); ++I) {
6605 const SCEV *P = WorkList[I];
6606 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6607 // If it is not a `SCEVUnknown`, just recurse into operands.
6608 if (!UnknownS) {
6609 for (const SCEV *Op : P->operands())
6610 AddToWorklist(Op);
6611 continue;
6612 }
6613 // `SCEVUnknown`'s require special treatment.
6614 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6615 if (!PendingPhiRangesIter.insert(P).second)
6616 continue;
6617 for (auto &Op : reverse(P->operands()))
6618 AddToWorklist(getSCEV(Op));
6619 }
6620 }
6621
6622 if (!WorkList.empty()) {
6623 // Use getRangeRef to compute ranges for items in the worklist in reverse
6624 // order. This will force ranges for earlier operands to be computed before
6625 // their users in most cases.
6626 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6627 getRangeRef(P, SignHint);
6628
6629 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6630 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6631 PendingPhiRangesIter.erase(P);
6632 }
6633 }
6634
6635 return getRangeRef(S, SignHint, 0);
6636}
6637
6638/// Determine the range for a particular SCEV. If SignHint is
6639/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6640/// with a "cleaner" unsigned (resp. signed) representation.
6641const ConstantRange &ScalarEvolution::getRangeRef(
6642 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6644 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6645 : SignedRanges;
6647 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6649
6650 // See if we've computed this range already.
6652 if (I != Cache.end())
6653 return I->second;
6654
6655 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6656 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6657
6658 // Switch to iteratively computing the range for S, if it is part of a deeply
6659 // nested expression.
6661 return getRangeRefIter(S, SignHint);
6662
6663 unsigned BitWidth = getTypeSizeInBits(S->getType());
6664 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6665 using OBO = OverflowingBinaryOperator;
6666
6667 // If the value has known zeros, the maximum value will have those known zeros
6668 // as well.
6669 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6670 APInt Multiple = getNonZeroConstantMultiple(S);
6671 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6672 if (!Remainder.isZero())
6673 ConservativeResult =
6675 APInt::getMaxValue(BitWidth) - Remainder + 1);
6676 }
6677 else {
6679 if (TZ != 0) {
6680 ConservativeResult = ConstantRange(
6682 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6683 }
6684 }
6685
6686 switch (S->getSCEVType()) {
6687 case scConstant:
6688 llvm_unreachable("Already handled above.");
6689 case scVScale:
6690 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6691 case scTruncate: {
6692 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6693 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6694 return setRange(
6695 Trunc, SignHint,
6696 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6697 }
6698 case scZeroExtend: {
6699 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6700 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6701 return setRange(
6702 ZExt, SignHint,
6703 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6704 }
6705 case scSignExtend: {
6706 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6707 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6708 return setRange(
6709 SExt, SignHint,
6710 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6711 }
6712 case scPtrToInt: {
6713 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6714 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6715 return setRange(PtrToInt, SignHint, X);
6716 }
6717 case scAddExpr: {
6718 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6719 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6720 unsigned WrapType = OBO::AnyWrap;
6721 if (Add->hasNoSignedWrap())
6722 WrapType |= OBO::NoSignedWrap;
6723 if (Add->hasNoUnsignedWrap())
6724 WrapType |= OBO::NoUnsignedWrap;
6725 for (const SCEV *Op : drop_begin(Add->operands()))
6726 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6727 RangeType);
6728 return setRange(Add, SignHint,
6729 ConservativeResult.intersectWith(X, RangeType));
6730 }
6731 case scMulExpr: {
6732 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6733 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6734 for (const SCEV *Op : drop_begin(Mul->operands()))
6735 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6736 return setRange(Mul, SignHint,
6737 ConservativeResult.intersectWith(X, RangeType));
6738 }
6739 case scUDivExpr: {
6740 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6741 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6742 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6743 return setRange(UDiv, SignHint,
6744 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6745 }
6746 case scAddRecExpr: {
6747 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6748 // If there's no unsigned wrap, the value will never be less than its
6749 // initial value.
6750 if (AddRec->hasNoUnsignedWrap()) {
6751 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6752 if (!UnsignedMinValue.isZero())
6753 ConservativeResult = ConservativeResult.intersectWith(
6754 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6755 }
6756
6757 // If there's no signed wrap, and all the operands except initial value have
6758 // the same sign or zero, the value won't ever be:
6759 // 1: smaller than initial value if operands are non negative,
6760 // 2: bigger than initial value if operands are non positive.
6761 // For both cases, value can not cross signed min/max boundary.
6762 if (AddRec->hasNoSignedWrap()) {
6763 bool AllNonNeg = true;
6764 bool AllNonPos = true;
6765 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6766 if (!isKnownNonNegative(AddRec->getOperand(i)))
6767 AllNonNeg = false;
6768 if (!isKnownNonPositive(AddRec->getOperand(i)))
6769 AllNonPos = false;
6770 }
6771 if (AllNonNeg)
6772 ConservativeResult = ConservativeResult.intersectWith(
6775 RangeType);
6776 else if (AllNonPos)
6777 ConservativeResult = ConservativeResult.intersectWith(
6779 getSignedRangeMax(AddRec->getStart()) +
6780 1),
6781 RangeType);
6782 }
6783
6784 // TODO: non-affine addrec
6785 if (AddRec->isAffine()) {
6786 const SCEV *MaxBEScev =
6788 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6789 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6790
6791 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6792 // MaxBECount's active bits are all <= AddRec's bit width.
6793 if (MaxBECount.getBitWidth() > BitWidth &&
6794 MaxBECount.getActiveBits() <= BitWidth)
6795 MaxBECount = MaxBECount.trunc(BitWidth);
6796 else if (MaxBECount.getBitWidth() < BitWidth)
6797 MaxBECount = MaxBECount.zext(BitWidth);
6798
6799 if (MaxBECount.getBitWidth() == BitWidth) {
6800 auto RangeFromAffine = getRangeForAffineAR(
6801 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6802 ConservativeResult =
6803 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6804
6805 auto RangeFromFactoring = getRangeViaFactoring(
6806 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6807 ConservativeResult =
6808 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6809 }
6810 }
6811
6812 // Now try symbolic BE count and more powerful methods.
6814 const SCEV *SymbolicMaxBECount =
6816 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6817 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6818 AddRec->hasNoSelfWrap()) {
6819 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6820 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6821 ConservativeResult =
6822 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6823 }
6824 }
6825 }
6826
6827 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6828 }
6829 case scUMaxExpr:
6830 case scSMaxExpr:
6831 case scUMinExpr:
6832 case scSMinExpr:
6833 case scSequentialUMinExpr: {
6835 switch (S->getSCEVType()) {
6836 case scUMaxExpr:
6837 ID = Intrinsic::umax;
6838 break;
6839 case scSMaxExpr:
6840 ID = Intrinsic::smax;
6841 break;
6842 case scUMinExpr:
6844 ID = Intrinsic::umin;
6845 break;
6846 case scSMinExpr:
6847 ID = Intrinsic::smin;
6848 break;
6849 default:
6850 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6851 }
6852
6853 const auto *NAry = cast<SCEVNAryExpr>(S);
6854 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6855 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6856 X = X.intrinsic(
6857 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6858 return setRange(S, SignHint,
6859 ConservativeResult.intersectWith(X, RangeType));
6860 }
6861 case scUnknown: {
6862 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6863 Value *V = U->getValue();
6864
6865 // Check if the IR explicitly contains !range metadata.
6866 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6867 if (MDRange)
6868 ConservativeResult =
6869 ConservativeResult.intersectWith(*MDRange, RangeType);
6870
6871 // Use facts about recurrences in the underlying IR. Note that add
6872 // recurrences are AddRecExprs and thus don't hit this path. This
6873 // primarily handles shift recurrences.
6874 auto CR = getRangeForUnknownRecurrence(U);
6875 ConservativeResult = ConservativeResult.intersectWith(CR);
6876
6877 // See if ValueTracking can give us a useful range.
6878 const DataLayout &DL = getDataLayout();
6879 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6880 if (Known.getBitWidth() != BitWidth)
6881 Known = Known.zextOrTrunc(BitWidth);
6882
6883 // ValueTracking may be able to compute a tighter result for the number of
6884 // sign bits than for the value of those sign bits.
6885 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6886 if (U->getType()->isPointerTy()) {
6887 // If the pointer size is larger than the index size type, this can cause
6888 // NS to be larger than BitWidth. So compensate for this.
6889 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6890 int ptrIdxDiff = ptrSize - BitWidth;
6891 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6892 NS -= ptrIdxDiff;
6893 }
6894
6895 if (NS > 1) {
6896 // If we know any of the sign bits, we know all of the sign bits.
6897 if (!Known.Zero.getHiBits(NS).isZero())
6898 Known.Zero.setHighBits(NS);
6899 if (!Known.One.getHiBits(NS).isZero())
6900 Known.One.setHighBits(NS);
6901 }
6902
6903 if (Known.getMinValue() != Known.getMaxValue() + 1)
6904 ConservativeResult = ConservativeResult.intersectWith(
6905 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6906 RangeType);
6907 if (NS > 1)
6908 ConservativeResult = ConservativeResult.intersectWith(
6910 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6911 RangeType);
6912
6913 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6914 // Strengthen the range if the underlying IR value is a
6915 // global/alloca/heap allocation using the size of the object.
6916 bool CanBeNull, CanBeFreed;
6917 uint64_t DerefBytes =
6918 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6919 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6920 // The highest address the object can start is DerefBytes bytes before
6921 // the end (unsigned max value). If this value is not a multiple of the
6922 // alignment, the last possible start value is the next lowest multiple
6923 // of the alignment. Note: The computations below cannot overflow,
6924 // because if they would there's no possible start address for the
6925 // object.
6926 APInt MaxVal =
6927 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6928 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6929 uint64_t Rem = MaxVal.urem(Align);
6930 MaxVal -= APInt(BitWidth, Rem);
6931 APInt MinVal = APInt::getZero(BitWidth);
6932 if (llvm::isKnownNonZero(V, DL))
6933 MinVal = Align;
6934 ConservativeResult = ConservativeResult.intersectWith(
6935 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6936 }
6937 }
6938
6939 // A range of Phi is a subset of union of all ranges of its input.
6940 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6941 // Make sure that we do not run over cycled Phis.
6942 if (PendingPhiRanges.insert(Phi).second) {
6943 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6944
6945 for (const auto &Op : Phi->operands()) {
6946 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6947 RangeFromOps = RangeFromOps.unionWith(OpRange);
6948 // No point to continue if we already have a full set.
6949 if (RangeFromOps.isFullSet())
6950 break;
6951 }
6952 ConservativeResult =
6953 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6954 bool Erased = PendingPhiRanges.erase(Phi);
6955 assert(Erased && "Failed to erase Phi properly?");
6956 (void)Erased;
6957 }
6958 }
6959
6960 // vscale can't be equal to zero
6961 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6962 if (II->getIntrinsicID() == Intrinsic::vscale) {
6964 ConservativeResult = ConservativeResult.difference(Disallowed);
6965 }
6966
6967 return setRange(U, SignHint, std::move(ConservativeResult));
6968 }
6969 case scCouldNotCompute:
6970 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6971 }
6972
6973 return setRange(S, SignHint, std::move(ConservativeResult));
6974}
6975
6976// Given a StartRange, Step and MaxBECount for an expression compute a range of
6977// values that the expression can take. Initially, the expression has a value
6978// from StartRange and then is changed by Step up to MaxBECount times. Signed
6979// argument defines if we treat Step as signed or unsigned.
6981 const ConstantRange &StartRange,
6982 const APInt &MaxBECount,
6983 bool Signed) {
6984 unsigned BitWidth = Step.getBitWidth();
6985 assert(BitWidth == StartRange.getBitWidth() &&
6986 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6987 // If either Step or MaxBECount is 0, then the expression won't change, and we
6988 // just need to return the initial range.
6989 if (Step == 0 || MaxBECount == 0)
6990 return StartRange;
6991
6992 // If we don't know anything about the initial value (i.e. StartRange is
6993 // FullRange), then we don't know anything about the final range either.
6994 // Return FullRange.
6995 if (StartRange.isFullSet())
6996 return ConstantRange::getFull(BitWidth);
6997
6998 // If Step is signed and negative, then we use its absolute value, but we also
6999 // note that we're moving in the opposite direction.
7000 bool Descending = Signed && Step.isNegative();
7001
7002 if (Signed)
7003 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7004 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7005 // This equations hold true due to the well-defined wrap-around behavior of
7006 // APInt.
7007 Step = Step.abs();
7008
7009 // Check if Offset is more than full span of BitWidth. If it is, the
7010 // expression is guaranteed to overflow.
7011 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7012 return ConstantRange::getFull(BitWidth);
7013
7014 // Offset is by how much the expression can change. Checks above guarantee no
7015 // overflow here.
7016 APInt Offset = Step * MaxBECount;
7017
7018 // Minimum value of the final range will match the minimal value of StartRange
7019 // if the expression is increasing and will be decreased by Offset otherwise.
7020 // Maximum value of the final range will match the maximal value of StartRange
7021 // if the expression is decreasing and will be increased by Offset otherwise.
7022 APInt StartLower = StartRange.getLower();
7023 APInt StartUpper = StartRange.getUpper() - 1;
7024 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7025 : (StartUpper + std::move(Offset));
7026
7027 // It's possible that the new minimum/maximum value will fall into the initial
7028 // range (due to wrap around). This means that the expression can take any
7029 // value in this bitwidth, and we have to return full range.
7030 if (StartRange.contains(MovedBoundary))
7031 return ConstantRange::getFull(BitWidth);
7032
7033 APInt NewLower =
7034 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7035 APInt NewUpper =
7036 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7037 NewUpper += 1;
7038
7039 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7040 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7041}
7042
7043ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7044 const SCEV *Step,
7045 const APInt &MaxBECount) {
7046 assert(getTypeSizeInBits(Start->getType()) ==
7047 getTypeSizeInBits(Step->getType()) &&
7048 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7049 "mismatched bit widths");
7050
7051 // First, consider step signed.
7052 ConstantRange StartSRange = getSignedRange(Start);
7053 ConstantRange StepSRange = getSignedRange(Step);
7054
7055 // If Step can be both positive and negative, we need to find ranges for the
7056 // maximum absolute step values in both directions and union them.
7058 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7060 StartSRange, MaxBECount,
7061 /* Signed = */ true));
7062
7063 // Next, consider step unsigned.
7065 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7066 /* Signed = */ false);
7067
7068 // Finally, intersect signed and unsigned ranges.
7070}
7071
7072ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7073 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7074 ScalarEvolution::RangeSignHint SignHint) {
7075 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7076 assert(AddRec->hasNoSelfWrap() &&
7077 "This only works for non-self-wrapping AddRecs!");
7078 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7079 const SCEV *Step = AddRec->getStepRecurrence(*this);
7080 // Only deal with constant step to save compile time.
7081 if (!isa<SCEVConstant>(Step))
7082 return ConstantRange::getFull(BitWidth);
7083 // Let's make sure that we can prove that we do not self-wrap during
7084 // MaxBECount iterations. We need this because MaxBECount is a maximum
7085 // iteration count estimate, and we might infer nw from some exit for which we
7086 // do not know max exit count (or any other side reasoning).
7087 // TODO: Turn into assert at some point.
7088 if (getTypeSizeInBits(MaxBECount->getType()) >
7089 getTypeSizeInBits(AddRec->getType()))
7090 return ConstantRange::getFull(BitWidth);
7091 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7092 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7093 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7094 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7095 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7096 MaxItersWithoutWrap))
7097 return ConstantRange::getFull(BitWidth);
7098
7099 ICmpInst::Predicate LEPred =
7101 ICmpInst::Predicate GEPred =
7103 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7104
7105 // We know that there is no self-wrap. Let's take Start and End values and
7106 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7107 // the iteration. They either lie inside the range [Min(Start, End),
7108 // Max(Start, End)] or outside it:
7109 //
7110 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7111 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7112 //
7113 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7114 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7115 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7116 // Start <= End and step is positive, or Start >= End and step is negative.
7117 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7118 ConstantRange StartRange = getRangeRef(Start, SignHint);
7119 ConstantRange EndRange = getRangeRef(End, SignHint);
7120 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7121 // If they already cover full iteration space, we will know nothing useful
7122 // even if we prove what we want to prove.
7123 if (RangeBetween.isFullSet())
7124 return RangeBetween;
7125 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7126 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7127 : RangeBetween.isWrappedSet();
7128 if (IsWrappedSet)
7129 return ConstantRange::getFull(BitWidth);
7130
7131 if (isKnownPositive(Step) &&
7132 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7133 return RangeBetween;
7134 if (isKnownNegative(Step) &&
7135 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7136 return RangeBetween;
7137 return ConstantRange::getFull(BitWidth);
7138}
7139
7140ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7141 const SCEV *Step,
7142 const APInt &MaxBECount) {
7143 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7144 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7145
7146 unsigned BitWidth = MaxBECount.getBitWidth();
7147 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7148 getTypeSizeInBits(Step->getType()) == BitWidth &&
7149 "mismatched bit widths");
7150
7151 struct SelectPattern {
7152 Value *Condition = nullptr;
7153 APInt TrueValue;
7154 APInt FalseValue;
7155
7156 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7157 const SCEV *S) {
7158 std::optional<unsigned> CastOp;
7159 APInt Offset(BitWidth, 0);
7160
7162 "Should be!");
7163
7164 // Peel off a constant offset:
7165 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7166 // In the future we could consider being smarter here and handle
7167 // {Start+Step,+,Step} too.
7168 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7169 return;
7170
7171 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7172 S = SA->getOperand(1);
7173 }
7174
7175 // Peel off a cast operation
7176 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7177 CastOp = SCast->getSCEVType();
7178 S = SCast->getOperand();
7179 }
7180
7181 using namespace llvm::PatternMatch;
7182
7183 auto *SU = dyn_cast<SCEVUnknown>(S);
7184 const APInt *TrueVal, *FalseVal;
7185 if (!SU ||
7186 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7187 m_APInt(FalseVal)))) {
7188 Condition = nullptr;
7189 return;
7190 }
7191
7192 TrueValue = *TrueVal;
7193 FalseValue = *FalseVal;
7194
7195 // Re-apply the cast we peeled off earlier
7196 if (CastOp)
7197 switch (*CastOp) {
7198 default:
7199 llvm_unreachable("Unknown SCEV cast type!");
7200
7201 case scTruncate:
7202 TrueValue = TrueValue.trunc(BitWidth);
7203 FalseValue = FalseValue.trunc(BitWidth);
7204 break;
7205 case scZeroExtend:
7206 TrueValue = TrueValue.zext(BitWidth);
7207 FalseValue = FalseValue.zext(BitWidth);
7208 break;
7209 case scSignExtend:
7210 TrueValue = TrueValue.sext(BitWidth);
7211 FalseValue = FalseValue.sext(BitWidth);
7212 break;
7213 }
7214
7215 // Re-apply the constant offset we peeled off earlier
7216 TrueValue += Offset;
7217 FalseValue += Offset;
7218 }
7219
7220 bool isRecognized() { return Condition != nullptr; }
7221 };
7222
7223 SelectPattern StartPattern(*this, BitWidth, Start);
7224 if (!StartPattern.isRecognized())
7225 return ConstantRange::getFull(BitWidth);
7226
7227 SelectPattern StepPattern(*this, BitWidth, Step);
7228 if (!StepPattern.isRecognized())
7229 return ConstantRange::getFull(BitWidth);
7230
7231 if (StartPattern.Condition != StepPattern.Condition) {
7232 // We don't handle this case today; but we could, by considering four
7233 // possibilities below instead of two. I'm not sure if there are cases where
7234 // that will help over what getRange already does, though.
7235 return ConstantRange::getFull(BitWidth);
7236 }
7237
7238 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7239 // construct arbitrary general SCEV expressions here. This function is called
7240 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7241 // say) can end up caching a suboptimal value.
7242
7243 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7244 // C2352 and C2512 (otherwise it isn't needed).
7245
7246 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7247 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7248 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7249 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7250
7251 ConstantRange TrueRange =
7252 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7253 ConstantRange FalseRange =
7254 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7255
7256 return TrueRange.unionWith(FalseRange);
7257}
7258
7259SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7260 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7261 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7262
7263 // Return early if there are no flags to propagate to the SCEV.
7265 if (BinOp->hasNoUnsignedWrap())
7267 if (BinOp->hasNoSignedWrap())
7269 if (Flags == SCEV::FlagAnyWrap)
7270 return SCEV::FlagAnyWrap;
7271
7272 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7273}
7274
7275const Instruction *
7276ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7277 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7278 return &*AddRec->getLoop()->getHeader()->begin();
7279 if (auto *U = dyn_cast<SCEVUnknown>(S))
7280 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7281 return I;
7282 return nullptr;
7283}
7284
7285const Instruction *
7286ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7287 bool &Precise) {
7288 Precise = true;
7289 // Do a bounded search of the def relation of the requested SCEVs.
7292 auto pushOp = [&](const SCEV *S) {
7293 if (!Visited.insert(S).second)
7294 return;
7295 // Threshold of 30 here is arbitrary.
7296 if (Visited.size() > 30) {
7297 Precise = false;
7298 return;
7299 }
7300 Worklist.push_back(S);
7301 };
7302
7303 for (const auto *S : Ops)
7304 pushOp(S);
7305
7306 const Instruction *Bound = nullptr;
7307 while (!Worklist.empty()) {
7308 auto *S = Worklist.pop_back_val();
7309 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7310 if (!Bound || DT.dominates(Bound, DefI))
7311 Bound = DefI;
7312 } else {
7313 for (const auto *Op : S->operands())
7314 pushOp(Op);
7315 }
7316 }
7317 return Bound ? Bound : &*F.getEntryBlock().begin();
7318}
7319
7320const Instruction *
7321ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7322 bool Discard;
7323 return getDefiningScopeBound(Ops, Discard);
7324}
7325
7326bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7327 const Instruction *B) {
7328 if (A->getParent() == B->getParent() &&
7330 B->getIterator()))
7331 return true;
7332
7333 auto *BLoop = LI.getLoopFor(B->getParent());
7334 if (BLoop && BLoop->getHeader() == B->getParent() &&
7335 BLoop->getLoopPreheader() == A->getParent() &&
7337 A->getParent()->end()) &&
7338 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7339 B->getIterator()))
7340 return true;
7341 return false;
7342}
7343
7344bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7345 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7346 visitAll(Op, PC);
7347 return PC.MaybePoison.empty();
7348}
7349
7350bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7351 return !SCEVExprContains(Op, [this](const SCEV *S) {
7352 auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
7353 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7354 // is a non-zero constant, we have to assume the UDiv may be UB.
7355 return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
7356 !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
7357 });
7358}
7359
7360bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7361 // Only proceed if we can prove that I does not yield poison.
7363 return false;
7364
7365 // At this point we know that if I is executed, then it does not wrap
7366 // according to at least one of NSW or NUW. If I is not executed, then we do
7367 // not know if the calculation that I represents would wrap. Multiple
7368 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7369 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7370 // derived from other instructions that map to the same SCEV. We cannot make
7371 // that guarantee for cases where I is not executed. So we need to find a
7372 // upper bound on the defining scope for the SCEV, and prove that I is
7373 // executed every time we enter that scope. When the bounding scope is a
7374 // loop (the common case), this is equivalent to proving I executes on every
7375 // iteration of that loop.
7377 for (const Use &Op : I->operands()) {
7378 // I could be an extractvalue from a call to an overflow intrinsic.
7379 // TODO: We can do better here in some cases.
7380 if (isSCEVable(Op->getType()))
7381 SCEVOps.push_back(getSCEV(Op));
7382 }
7383 auto *DefI = getDefiningScopeBound(SCEVOps);
7384 return isGuaranteedToTransferExecutionTo(DefI, I);
7385}
7386
7387bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7388 // If we know that \c I can never be poison period, then that's enough.
7389 if (isSCEVExprNeverPoison(I))
7390 return true;
7391
7392 // If the loop only has one exit, then we know that, if the loop is entered,
7393 // any instruction dominating that exit will be executed. If any such
7394 // instruction would result in UB, the addrec cannot be poison.
7395 //
7396 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7397 // also handles uses outside the loop header (they just need to dominate the
7398 // single exit).
7399
7400 auto *ExitingBB = L->getExitingBlock();
7401 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7402 return false;
7403
7406
7407 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7408 // things that are known to be poison under that assumption go on the
7409 // Worklist.
7410 KnownPoison.insert(I);
7411 Worklist.push_back(I);
7412
7413 while (!Worklist.empty()) {
7414 const Instruction *Poison = Worklist.pop_back_val();
7415
7416 for (const Use &U : Poison->uses()) {
7417 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7418 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7419 DT.dominates(PoisonUser->getParent(), ExitingBB))
7420 return true;
7421
7422 if (propagatesPoison(U) && L->contains(PoisonUser))
7423 if (KnownPoison.insert(PoisonUser).second)
7424 Worklist.push_back(PoisonUser);
7425 }
7426 }
7427
7428 return false;
7429}
7430
7431ScalarEvolution::LoopProperties
7432ScalarEvolution::getLoopProperties(const Loop *L) {
7433 using LoopProperties = ScalarEvolution::LoopProperties;
7434
7435 auto Itr = LoopPropertiesCache.find(L);
7436 if (Itr == LoopPropertiesCache.end()) {
7437 auto HasSideEffects = [](Instruction *I) {
7438 if (auto *SI = dyn_cast<StoreInst>(I))
7439 return !SI->isSimple();
7440
7441 return I->mayThrow() || I->mayWriteToMemory();
7442 };
7443
7444 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7445 /*HasNoSideEffects*/ true};
7446
7447 for (auto *BB : L->getBlocks())
7448 for (auto &I : *BB) {
7450 LP.HasNoAbnormalExits = false;
7451 if (HasSideEffects(&I))
7452 LP.HasNoSideEffects = false;
7453 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7454 break; // We're already as pessimistic as we can get.
7455 }
7456
7457 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7458 assert(InsertPair.second && "We just checked!");
7459 Itr = InsertPair.first;
7460 }
7461
7462 return Itr->second;
7463}
7464
7466 // A mustprogress loop without side effects must be finite.
7467 // TODO: The check used here is very conservative. It's only *specific*
7468 // side effects which are well defined in infinite loops.
7469 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7470}
7471
7472const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7473 // Worklist item with a Value and a bool indicating whether all operands have
7474 // been visited already.
7477
7478 Stack.emplace_back(V, true);
7479 Stack.emplace_back(V, false);
7480 while (!Stack.empty()) {
7481 auto E = Stack.pop_back_val();
7482 Value *CurV = E.getPointer();
7483
7484 if (getExistingSCEV(CurV))
7485 continue;
7486
7488 const SCEV *CreatedSCEV = nullptr;
7489 // If all operands have been visited already, create the SCEV.
7490 if (E.getInt()) {
7491 CreatedSCEV = createSCEV(CurV);
7492 } else {
7493 // Otherwise get the operands we need to create SCEV's for before creating
7494 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7495 // just use it.
7496 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7497 }
7498
7499 if (CreatedSCEV) {
7500 insertValueToMap(CurV, CreatedSCEV);
7501 } else {
7502 // Queue CurV for SCEV creation, followed by its's operands which need to
7503 // be constructed first.
7504 Stack.emplace_back(CurV, true);
7505 for (Value *Op : Ops)
7506 Stack.emplace_back(Op, false);
7507 }
7508 }
7509
7510 return getExistingSCEV(V);
7511}
7512
7513const SCEV *
7514ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7515 if (!isSCEVable(V->getType()))
7516 return getUnknown(V);
7517
7518 if (Instruction *I = dyn_cast<Instruction>(V)) {
7519 // Don't attempt to analyze instructions in blocks that aren't
7520 // reachable. Such instructions don't matter, and they aren't required
7521 // to obey basic rules for definitions dominating uses which this
7522 // analysis depends on.
7523 if (!DT.isReachableFromEntry(I->getParent()))
7524 return getUnknown(PoisonValue::get(V->getType()));
7525 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7526 return getConstant(CI);
7527 else if (isa<GlobalAlias>(V))
7528 return getUnknown(V);
7529 else if (!isa<ConstantExpr>(V))
7530 return getUnknown(V);
7531
7532 Operator *U = cast<Operator>(V);
7533 if (auto BO =
7534 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7535 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7536 switch (BO->Opcode) {
7537 case Instruction::Add:
7538 case Instruction::Mul: {
7539 // For additions and multiplications, traverse add/mul chains for which we
7540 // can potentially create a single SCEV, to reduce the number of
7541 // get{Add,Mul}Expr calls.
7542 do {
7543 if (BO->Op) {
7544 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7545 Ops.push_back(BO->Op);
7546 break;
7547 }
7548 }
7549 Ops.push_back(BO->RHS);
7550 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7551 dyn_cast<Instruction>(V));
7552 if (!NewBO ||
7553 (BO->Opcode == Instruction::Add &&
7554 (NewBO->Opcode != Instruction::Add &&
7555 NewBO->Opcode != Instruction::Sub)) ||
7556 (BO->Opcode == Instruction::Mul &&
7557 NewBO->Opcode != Instruction::Mul)) {
7558 Ops.push_back(BO->LHS);
7559 break;
7560 }
7561 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7562 // requires a SCEV for the LHS.
7563 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7564 auto *I = dyn_cast<Instruction>(BO->Op);
7565 if (I && programUndefinedIfPoison(I)) {
7566 Ops.push_back(BO->LHS);
7567 break;
7568 }
7569 }
7570 BO = NewBO;
7571 } while (true);
7572 return nullptr;
7573 }
7574 case Instruction::Sub:
7575 case Instruction::UDiv:
7576 case Instruction::URem:
7577 break;
7578 case Instruction::AShr:
7579 case Instruction::Shl:
7580 case Instruction::Xor:
7581 if (!IsConstArg)
7582 return nullptr;
7583 break;
7584 case Instruction::And:
7585 case Instruction::Or:
7586 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7587 return nullptr;
7588 break;
7589 case Instruction::LShr:
7590 return getUnknown(V);
7591 default:
7592 llvm_unreachable("Unhandled binop");
7593 break;
7594 }
7595
7596 Ops.push_back(BO->LHS);
7597 Ops.push_back(BO->RHS);
7598 return nullptr;
7599 }
7600
7601 switch (U->getOpcode()) {
7602 case Instruction::Trunc:
7603 case Instruction::ZExt:
7604 case Instruction::SExt:
7605 case Instruction::PtrToInt:
7606 Ops.push_back(U->getOperand(0));
7607 return nullptr;
7608
7609 case Instruction::BitCast:
7610 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7611 Ops.push_back(U->getOperand(0));
7612 return nullptr;
7613 }
7614 return getUnknown(V);
7615
7616 case Instruction::SDiv:
7617 case Instruction::SRem:
7618 Ops.push_back(U->getOperand(0));
7619 Ops.push_back(U->getOperand(1));
7620 return nullptr;
7621
7622 case Instruction::GetElementPtr:
7623 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7624 "GEP source element type must be sized");
7625 for (Value *Index : U->operands())
7626 Ops.push_back(Index);
7627 return nullptr;
7628
7629 case Instruction::IntToPtr:
7630 return getUnknown(V);
7631
7632 case Instruction::PHI:
7633 // Keep constructing SCEVs' for phis recursively for now.
7634 return nullptr;
7635
7636 case Instruction::Select: {
7637 // Check if U is a select that can be simplified to a SCEVUnknown.
7638 auto CanSimplifyToUnknown = [this, U]() {
7639 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7640 return false;
7641
7642 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7643 if (!ICI)
7644 return false;
7645 Value *LHS = ICI->getOperand(0);
7646 Value *RHS = ICI->getOperand(1);
7647 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7648 ICI->getPredicate() == CmpInst::ICMP_NE) {
7649 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7650 return true;
7651 } else if (getTypeSizeInBits(LHS->getType()) >
7652 getTypeSizeInBits(U->getType()))
7653 return true;
7654 return false;
7655 };
7656 if (CanSimplifyToUnknown())
7657 return getUnknown(U);
7658
7659 for (Value *Inc : U->operands())
7660 Ops.push_back(Inc);
7661 return nullptr;
7662 break;
7663 }
7664 case Instruction::Call:
7665 case Instruction::Invoke:
7666 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7667 Ops.push_back(RV);
7668 return nullptr;
7669 }
7670
7671 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7672 switch (II->getIntrinsicID()) {
7673 case Intrinsic::abs:
7674 Ops.push_back(II->getArgOperand(0));
7675 return nullptr;
7676 case Intrinsic::umax:
7677 case Intrinsic::umin:
7678 case Intrinsic::smax:
7679 case Intrinsic::smin:
7680 case Intrinsic::usub_sat:
7681 case Intrinsic::uadd_sat:
7682 Ops.push_back(II->getArgOperand(0));
7683 Ops.push_back(II->getArgOperand(1));
7684 return nullptr;
7685 case Intrinsic::start_loop_iterations:
7686 case Intrinsic::annotation:
7687 case Intrinsic::ptr_annotation:
7688 Ops.push_back(II->getArgOperand(0));
7689 return nullptr;
7690 default:
7691 break;
7692 }
7693 }
7694 break;
7695 }
7696
7697 return nullptr;
7698}
7699
7700const SCEV *ScalarEvolution::createSCEV(Value *V) {
7701 if (!isSCEVable(V->getType()))
7702 return getUnknown(V);
7703
7704 if (Instruction *I = dyn_cast<Instruction>(V)) {
7705 // Don't attempt to analyze instructions in blocks that aren't
7706 // reachable. Such instructions don't matter, and they aren't required
7707 // to obey basic rules for definitions dominating uses which this
7708 // analysis depends on.
7709 if (!DT.isReachableFromEntry(I->getParent()))
7710 return getUnknown(PoisonValue::get(V->getType()));
7711 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7712 return getConstant(CI);
7713 else if (isa<GlobalAlias>(V))
7714 return getUnknown(V);
7715 else if (!isa<ConstantExpr>(V))
7716 return getUnknown(V);
7717
7718 const SCEV *LHS;
7719 const SCEV *RHS;
7720
7721 Operator *U = cast<Operator>(V);
7722 if (auto BO =
7723 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7724 switch (BO->Opcode) {
7725 case Instruction::Add: {
7726 // The simple thing to do would be to just call getSCEV on both operands
7727 // and call getAddExpr with the result. However if we're looking at a
7728 // bunch of things all added together, this can be quite inefficient,
7729 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7730 // Instead, gather up all the operands and make a single getAddExpr call.
7731 // LLVM IR canonical form means we need only traverse the left operands.
7733 do {
7734 if (BO->Op) {
7735 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7736 AddOps.push_back(OpSCEV);
7737 break;
7738 }
7739
7740 // If a NUW or NSW flag can be applied to the SCEV for this
7741 // addition, then compute the SCEV for this addition by itself
7742 // with a separate call to getAddExpr. We need to do that
7743 // instead of pushing the operands of the addition onto AddOps,
7744 // since the flags are only known to apply to this particular
7745 // addition - they may not apply to other additions that can be
7746 // formed with operands from AddOps.
7747 const SCEV *RHS = getSCEV(BO->RHS);
7748 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7749 if (Flags != SCEV::FlagAnyWrap) {
7750 const SCEV *LHS = getSCEV(BO->LHS);
7751 if (BO->Opcode == Instruction::Sub)
7752 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7753 else
7754 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7755 break;
7756 }
7757 }
7758
7759 if (BO->Opcode == Instruction::Sub)
7760 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7761 else
7762 AddOps.push_back(getSCEV(BO->RHS));
7763
7764 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7765 dyn_cast<Instruction>(V));
7766 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7767 NewBO->Opcode != Instruction::Sub)) {
7768 AddOps.push_back(getSCEV(BO->LHS));
7769 break;
7770 }
7771 BO = NewBO;
7772 } while (true);
7773
7774 return getAddExpr(AddOps);
7775 }
7776
7777 case Instruction::Mul: {
7779 do {
7780 if (BO->Op) {
7781 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7782 MulOps.push_back(OpSCEV);
7783 break;
7784 }
7785
7786 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7787 if (Flags != SCEV::FlagAnyWrap) {
7788 LHS = getSCEV(BO->LHS);
7789 RHS = getSCEV(BO->RHS);
7790 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7791 break;
7792 }
7793 }
7794
7795 MulOps.push_back(getSCEV(BO->RHS));
7796 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7797 dyn_cast<Instruction>(V));
7798 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7799 MulOps.push_back(getSCEV(BO->LHS));
7800 break;
7801 }
7802 BO = NewBO;
7803 } while (true);
7804
7805 return getMulExpr(MulOps);
7806 }
7807 case Instruction::UDiv:
7808 LHS = getSCEV(BO->LHS);
7809 RHS = getSCEV(BO->RHS);
7810 return getUDivExpr(LHS, RHS);
7811 case Instruction::URem:
7812 LHS = getSCEV(BO->LHS);
7813 RHS = getSCEV(BO->RHS);
7814 return getURemExpr(LHS, RHS);
7815 case Instruction::Sub: {
7817 if (BO->Op)
7818 Flags = getNoWrapFlagsFromUB(BO->Op);
7819 LHS = getSCEV(BO->LHS);
7820 RHS = getSCEV(BO->RHS);
7821 return getMinusSCEV(LHS, RHS, Flags);
7822 }
7823 case Instruction::And:
7824 // For an expression like x&255 that merely masks off the high bits,
7825 // use zext(trunc(x)) as the SCEV expression.
7826 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7827 if (CI->isZero())
7828 return getSCEV(BO->RHS);
7829 if (CI->isMinusOne())
7830 return getSCEV(BO->LHS);
7831 const APInt &A = CI->getValue();
7832
7833 // Instcombine's ShrinkDemandedConstant may strip bits out of
7834 // constants, obscuring what would otherwise be a low-bits mask.
7835 // Use computeKnownBits to compute what ShrinkDemandedConstant
7836 // knew about to reconstruct a low-bits mask value.
7837 unsigned LZ = A.countl_zero();
7838 unsigned TZ = A.countr_zero();
7839 unsigned BitWidth = A.getBitWidth();
7840 KnownBits Known(BitWidth);
7841 computeKnownBits(BO->LHS, Known, getDataLayout(),
7842 0, &AC, nullptr, &DT);
7843
7844 APInt EffectiveMask =
7845 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7846 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7847 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7848 const SCEV *LHS = getSCEV(BO->LHS);
7849 const SCEV *ShiftedLHS = nullptr;
7850 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7851 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7852 // For an expression like (x * 8) & 8, simplify the multiply.
7853 unsigned MulZeros = OpC->getAPInt().countr_zero();
7854 unsigned GCD = std::min(MulZeros, TZ);
7855 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7857 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7858 append_range(MulOps, LHSMul->operands().drop_front());
7859 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7860 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7861 }
7862 }
7863 if (!ShiftedLHS)
7864 ShiftedLHS = getUDivExpr(LHS, MulCount);
7865 return getMulExpr(
7867 getTruncateExpr(ShiftedLHS,
7868 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7869 BO->LHS->getType()),
7870 MulCount);
7871 }
7872 }
7873 // Binary `and` is a bit-wise `umin`.
7874 if (BO->LHS->getType()->isIntegerTy(1)) {
7875 LHS = getSCEV(BO->LHS);
7876 RHS = getSCEV(BO->RHS);
7877 return getUMinExpr(LHS, RHS);
7878 }
7879 break;
7880
7881 case Instruction::Or:
7882 // Binary `or` is a bit-wise `umax`.
7883 if (BO->LHS->getType()->isIntegerTy(1)) {
7884 LHS = getSCEV(BO->LHS);
7885 RHS = getSCEV(BO->RHS);
7886 return getUMaxExpr(LHS, RHS);
7887 }
7888 break;
7889
7890 case Instruction::Xor:
7891 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7892 // If the RHS of xor is -1, then this is a not operation.
7893 if (CI->isMinusOne())
7894 return getNotSCEV(getSCEV(BO->LHS));
7895
7896 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7897 // This is a variant of the check for xor with -1, and it handles
7898 // the case where instcombine has trimmed non-demanded bits out
7899 // of an xor with -1.
7900 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7901 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7902 if (LBO->getOpcode() == Instruction::And &&
7903 LCI->getValue() == CI->getValue())
7904 if (const SCEVZeroExtendExpr *Z =
7905 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7906 Type *UTy = BO->LHS->getType();
7907 const SCEV *Z0 = Z->getOperand();
7908 Type *Z0Ty = Z0->getType();
7909 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7910
7911 // If C is a low-bits mask, the zero extend is serving to
7912 // mask off the high bits. Complement the operand and
7913 // re-apply the zext.
7914 if (CI->getValue().isMask(Z0TySize))
7915 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7916
7917 // If C is a single bit, it may be in the sign-bit position
7918 // before the zero-extend. In this case, represent the xor
7919 // using an add, which is equivalent, and re-apply the zext.
7920 APInt Trunc = CI->getValue().trunc(Z0TySize);
7921 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7922 Trunc.isSignMask())
7923 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7924 UTy);
7925 }
7926 }
7927 break;
7928
7929 case Instruction::Shl:
7930 // Turn shift left of a constant amount into a multiply.
7931 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7932 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7933
7934 // If the shift count is not less than the bitwidth, the result of
7935 // the shift is undefined. Don't try to analyze it, because the
7936 // resolution chosen here may differ from the resolution chosen in
7937 // other parts of the compiler.
7938 if (SA->getValue().uge(BitWidth))
7939 break;
7940
7941 // We can safely preserve the nuw flag in all cases. It's also safe to
7942 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7943 // requires special handling. It can be preserved as long as we're not
7944 // left shifting by bitwidth - 1.
7945 auto Flags = SCEV::FlagAnyWrap;
7946 if (BO->Op) {
7947 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7948 if ((MulFlags & SCEV::FlagNSW) &&
7949 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7951 if (MulFlags & SCEV::FlagNUW)
7953 }
7954
7955 ConstantInt *X = ConstantInt::get(
7956 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7957 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7958 }
7959 break;
7960
7961 case Instruction::AShr:
7962 // AShr X, C, where C is a constant.
7963 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7964 if (!CI)
7965 break;
7966
7967 Type *OuterTy = BO->LHS->getType();
7969 // If the shift count is not less than the bitwidth, the result of
7970 // the shift is undefined. Don't try to analyze it, because the
7971 // resolution chosen here may differ from the resolution chosen in
7972 // other parts of the compiler.
7973 if (CI->getValue().uge(BitWidth))
7974 break;
7975
7976 if (CI->isZero())
7977 return getSCEV(BO->LHS); // shift by zero --> noop
7978
7979 uint64_t AShrAmt = CI->getZExtValue();
7980 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7981
7982 Operator *L = dyn_cast<Operator>(BO->LHS);
7983 const SCEV *AddTruncateExpr = nullptr;
7984 ConstantInt *ShlAmtCI = nullptr;
7985 const SCEV *AddConstant = nullptr;
7986
7987 if (L && L->getOpcode() == Instruction::Add) {
7988 // X = Shl A, n
7989 // Y = Add X, c
7990 // Z = AShr Y, m
7991 // n, c and m are constants.
7992
7993 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7994 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7995 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7996 if (AddOperandCI) {
7997 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7998 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7999 // since we truncate to TruncTy, the AddConstant should be of the
8000 // same type, so create a new Constant with type same as TruncTy.
8001 // Also, the Add constant should be shifted right by AShr amount.
8002 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8003 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8004 // we model the expression as sext(add(trunc(A), c << n)), since the
8005 // sext(trunc) part is already handled below, we create a
8006 // AddExpr(TruncExp) which will be used later.
8007 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8008 }
8009 }
8010 } else if (L && L->getOpcode() == Instruction::Shl) {
8011 // X = Shl A, n
8012 // Y = AShr X, m
8013 // Both n and m are constant.
8014
8015 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8016 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8017 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8018 }
8019
8020 if (AddTruncateExpr && ShlAmtCI) {
8021 // We can merge the two given cases into a single SCEV statement,
8022 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8023 // a simpler case. The following code handles the two cases:
8024 //
8025 // 1) For a two-shift sext-inreg, i.e. n = m,
8026 // use sext(trunc(x)) as the SCEV expression.
8027 //
8028 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8029 // expression. We already checked that ShlAmt < BitWidth, so
8030 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8031 // ShlAmt - AShrAmt < Amt.
8032 const APInt &ShlAmt = ShlAmtCI->getValue();
8033 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8035 ShlAmtCI->getZExtValue() - AShrAmt);
8036 const SCEV *CompositeExpr =
8037 getMulExpr(AddTruncateExpr, getConstant(Mul));
8038 if (L->getOpcode() != Instruction::Shl)
8039 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8040
8041 return getSignExtendExpr(CompositeExpr, OuterTy);
8042 }
8043 }
8044 break;
8045 }
8046 }
8047
8048 switch (U->getOpcode()) {
8049 case Instruction::Trunc:
8050 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8051
8052 case Instruction::ZExt:
8053 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8054
8055 case Instruction::SExt:
8056 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8057 dyn_cast<Instruction>(V))) {
8058 // The NSW flag of a subtract does not always survive the conversion to
8059 // A + (-1)*B. By pushing sign extension onto its operands we are much
8060 // more likely to preserve NSW and allow later AddRec optimisations.
8061 //
8062 // NOTE: This is effectively duplicating this logic from getSignExtend:
8063 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8064 // but by that point the NSW information has potentially been lost.
8065 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8066 Type *Ty = U->getType();
8067 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8068 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8069 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8070 }
8071 }
8072 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8073
8074 case Instruction::BitCast:
8075 // BitCasts are no-op casts so we just eliminate the cast.
8076 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8077 return getSCEV(U->getOperand(0));
8078 break;
8079
8080 case Instruction::PtrToInt: {
8081 // Pointer to integer cast is straight-forward, so do model it.
8082 const SCEV *Op = getSCEV(U->getOperand(0));
8083 Type *DstIntTy = U->getType();
8084 // But only if effective SCEV (integer) type is wide enough to represent
8085 // all possible pointer values.
8086 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8087 if (isa<SCEVCouldNotCompute>(IntOp))
8088 return getUnknown(V);
8089 return IntOp;
8090 }
8091 case Instruction::IntToPtr:
8092 // Just don't deal with inttoptr casts.
8093 return getUnknown(V);
8094
8095 case Instruction::SDiv:
8096 // If both operands are non-negative, this is just an udiv.
8097 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8098 isKnownNonNegative(getSCEV(U->getOperand(1))))
8099 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8100 break;
8101
8102 case Instruction::SRem:
8103 // If both operands are non-negative, this is just an urem.
8104 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8105 isKnownNonNegative(getSCEV(U->getOperand(1))))
8106 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8107 break;
8108
8109 case Instruction::GetElementPtr:
8110 return createNodeForGEP(cast<GEPOperator>(U));
8111
8112 case Instruction::PHI:
8113 return createNodeForPHI(cast<PHINode>(U));
8114
8115 case Instruction::Select:
8116 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8117 U->getOperand(2));
8118
8119 case Instruction::Call:
8120 case Instruction::Invoke:
8121 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8122 return getSCEV(RV);
8123
8124 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8125 switch (II->getIntrinsicID()) {
8126 case Intrinsic::abs:
8127 return getAbsExpr(
8128 getSCEV(II->getArgOperand(0)),
8129 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8130 case Intrinsic::umax:
8131 LHS = getSCEV(II->getArgOperand(0));
8132 RHS = getSCEV(II->getArgOperand(1));
8133 return getUMaxExpr(LHS, RHS);
8134 case Intrinsic::umin:
8135 LHS = getSCEV(II->getArgOperand(0));
8136 RHS = getSCEV(II->getArgOperand(1));
8137 return getUMinExpr(LHS, RHS);
8138 case Intrinsic::smax:
8139 LHS = getSCEV(II->getArgOperand(0));
8140 RHS = getSCEV(II->getArgOperand(1));
8141 return getSMaxExpr(LHS, RHS);
8142 case Intrinsic::smin:
8143 LHS = getSCEV(II->getArgOperand(0));
8144 RHS = getSCEV(II->getArgOperand(1));
8145 return getSMinExpr(LHS, RHS);
8146 case Intrinsic::usub_sat: {
8147 const SCEV *X = getSCEV(II->getArgOperand(0));
8148 const SCEV *Y = getSCEV(II->getArgOperand(1));
8149 const SCEV *ClampedY = getUMinExpr(X, Y);
8150 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8151 }
8152 case Intrinsic::uadd_sat: {
8153 const SCEV *X = getSCEV(II->getArgOperand(0));
8154 const SCEV *Y = getSCEV(II->getArgOperand(1));
8155 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8156 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8157 }
8158 case Intrinsic::start_loop_iterations:
8159 case Intrinsic::annotation:
8160 case Intrinsic::ptr_annotation:
8161 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8162 // just eqivalent to the first operand for SCEV purposes.
8163 return getSCEV(II->getArgOperand(0));
8164 case Intrinsic::vscale:
8165 return getVScale(II->getType());
8166 default:
8167 break;
8168 }
8169 }
8170 break;
8171 }
8172
8173 return getUnknown(V);
8174}
8175
8176//===----------------------------------------------------------------------===//
8177// Iteration Count Computation Code
8178//
8179
// Convenience overload: evaluates the trip count (exit count + 1) in an
// integer type one bit wider than the exit count's type, and passes no loop,
// so no loop-entry guard can be consulted by the three-argument overload.
8181 if (isa<SCEVCouldNotCompute>(ExitCount))
8182 return getCouldNotCompute();
8183
8184 auto *ExitCountType = ExitCount->getType();
8185 assert(ExitCountType->isIntegerTy());
// One extra bit is enough to hold (2^N - 1) + 1 for an N-bit exit count.
8186 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8187 1 + ExitCountType->getScalarSizeInBits());
8188 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8189}
8190
8192 Type *EvalTy,
8193 const Loop *L) {
// Returns the trip count (ExitCount + 1) as a SCEV of type EvalTy, or
// CouldNotCompute when ExitCount itself is not computable. A non-null L lets
// a loop-entry guard (ExitCount != -1) prove that the +1 cannot overflow.
8194 if (isa<SCEVCouldNotCompute>(ExitCount))
8195 return getCouldNotCompute();
8196
8197 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8198 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8199
// True when ExitCount + 1 cannot wrap in ExitCount's own type: either its
// unsigned range excludes the all-ones value, or (only when L is given) the
// loop entry is guarded by ExitCount != -1.
8200 auto CanAddOneWithoutOverflow = [&]() {
8201 ConstantRange ExitCountRange =
8202 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8203 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8204 return true;
8205
8206 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8207 getMinusOne(ExitCount->getType()));
8208 };
8209
8210 // If we need to zero extend the backedge count, check if we can add one to
8211 // it prior to zero extending without overflow. Provided this is safe, it
8212 // allows better simplification of the +1.
8213 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8214 return getZeroExtendExpr(
8215 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8216
8217 // Get the total trip count from the count by adding 1. This may wrap.
// getTruncateOrZeroExtend truncates when EvalTy is narrower than the exit
// count's type, so the +1 is then performed modulo EvalTy's width.
8218 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8219}
8220
8221static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8222 if (!ExitCount)
8223 return 0;
8224
8225 ConstantInt *ExitConst = ExitCount->getValue();
8226
8227 // Guard against huge trip counts.
8228 if (ExitConst->getValue().getActiveBits() > 32)
8229 return 0;
8230
8231 // In case of integer overflow, this returns 0, which is correct.
8232 return ((unsigned)ExitConst->getZExtValue()) + 1;
8233}
8234
// Exact trip count of L as a small unsigned constant. The result is 0 when the
// exact backedge-taken count is not a constant, needs more than 32 bits, or
// wraps to 0 when incremented (see getConstantTripCount).
8236 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8237 return getConstantTripCount(ExitCount);
8238}
8239
/// Returns the trip count of L when it exits through ExitingBlock (that
/// block's exit count + 1) as a small unsigned constant. The result is 0 if
/// the exit count is not a constant, needs more than 32 bits, or wraps to 0
/// when incremented. ExitingBlock must be non-null and must exit L (asserted).
8240unsigned
8242 const BasicBlock *ExitingBlock) {
8243 assert(ExitingBlock && "Must pass a non-null exiting block!");
8244 assert(L->isLoopExiting(ExitingBlock) &&
8245 "Exiting block must actually branch out of the loop!");
8246 const SCEVConstant *ExitCount =
8247 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8248 return getConstantTripCount(ExitCount);
8249}
8250
8252 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
// Maximum trip count of L as a small unsigned constant (0 if unknown, too
// large, or wrapping). When Predicates is non-null the predicated constant-max
// backedge-taken count is used; the predicates it relies on are expected to
// be collected into *Predicates.
8253
8254 const auto *MaxExitCount =
8255 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8257 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8258}
8259
// Combine the trip multiples of every exiting block of L with gcd, so the
// result divides the trip count whichever exit is taken. Returns 1 when L has
// no exiting blocks.
8261 SmallVector<BasicBlock *, 8> ExitingBlocks;
8262 L->getExitingBlocks(ExitingBlocks);
8263
8264 std::optional<unsigned> Res;
8265 for (auto *ExitingBB : ExitingBlocks) {
8266 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
// The first exit seeds Res; gcd(x, x) == x then leaves it unchanged.
8267 if (!Res)
8268 Res = Multiple;
8269 Res = (unsigned)std::gcd(*Res, Multiple);
8270 }
8271 return Res.value_or(1);
8272}
8273
8275 const SCEV *ExitCount) {
// Returns a constant that the trip count (ExitCount + 1) is known to be a
// multiple of, capped to fit in 32 bits, or 1 if ExitCount is not computable.
8276 if (ExitCount == getCouldNotCompute())
8277 return 1;
8278
8279 // Trip count = exit count + 1, after applying loop guards to the exit count.
8280 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8281
8282 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8283 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8284 // the greatest power of 2 divisor less than 2^32.
// The trailing-zero count is capped at 31 so the shift stays within a 32-bit
// unsigned value.
8285 return Multiple.getActiveBits() > 32
8286 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8287 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8288}
8289
8290/// Returns the largest constant divisor of the trip count of this loop as a
8291/// normal unsigned value, if possible. This means that the actual trip count is
8292/// always a multiple of the returned value (don't forget the trip count could
8293/// very well be zero as well!).
8294///
8295/// Returns 1 if the trip count is unknown or not guaranteed to be the
8296/// multiple of a constant (which is also the case if the trip count is simply
8297/// constant, use getSmallConstantTripCount for that case). Will also return 1
8298/// if the trip count is very large (>= 2^32).
/// NOTE(review): the constant-trip-count case above depends on what
/// getNonZeroConstantMultiple returns for a constant trip count; confirm.
8299///
8300/// As explained in the comments for getSmallConstantTripCount, this assumes
8301/// that control exits the loop via ExitingBlock.
8302unsigned
8304 const BasicBlock *ExitingBlock) {
8305 assert(ExitingBlock && "Must pass a non-null exiting block!");
8306 assert(L->isLoopExiting(ExitingBlock) &&
8307 "Exiting block must actually branch out of the loop!");
8308 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8309 return getSmallConstantTripMultiple(L, ExitCount);
8310}
8311
8313 const BasicBlock *ExitingBlock,
8314 ExitCountKind Kind) {
// Look up ExitingBlock's count in the unpredicated BackedgeTakenInfo for L,
// selecting the exact, symbolic-max or constant-max variant by Kind.
8315 switch (Kind) {
8316 case Exact:
8317 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8318 case SymbolicMaximum:
8319 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8320 case ConstantMaximum:
8321 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8322 };
8323 llvm_unreachable("Invalid ExitCountKind!");
8324}
8325
8327 const Loop *L, const BasicBlock *ExitingBlock,
8329 switch (Kind) {
// Same selection as getExitCount, but against the predicated
// BackedgeTakenInfo; predicates the chosen count relies on are passed
// through to the BackedgeTakenInfo accessor via Predicates.
8330 case Exact:
8331 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8332 Predicates);
8333 case SymbolicMaximum:
8334 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8335 Predicates);
8336 case ConstantMaximum:
8337 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8338 Predicates);
8339 };
8340 llvm_unreachable("Invalid ExitCountKind!");
8341}
8342
// Exact backedge-taken count of L with predicates allowed; the predicates of
// every exit used are appended to Preds by BackedgeTakenInfo::getExact.
8345 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8346}
8347
8349 ExitCountKind Kind) {
// Whole-loop backedge-taken count from the unpredicated BackedgeTakenInfo,
// selecting the exact, constant-max or symbolic-max variant by Kind.
8350 switch (Kind) {
8351 case Exact:
8352 return getBackedgeTakenInfo(L).getExact(L, this);
8353 case ConstantMaximum:
8354 return getBackedgeTakenInfo(L).getConstantMax(this);
8355 case SymbolicMaximum:
8356 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8357 };
8358 llvm_unreachable("Invalid ExitCountKind!");
8359}
8360
// Symbolic maximum backedge-taken count of L with predicates allowed; Preds is
// handed to BackedgeTakenInfo::getSymbolicMax to collect predicates.
8363 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8364}
8365
// Constant maximum backedge-taken count of L with predicates allowed; the
// predicates of any predicated exit are appended to Preds.
8368 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8369}
8370
// True if L's backedge is taken either exactly the constant-max number of
// times or zero times (see BackedgeTakenInfo::isConstantMaxOrZero).
8372 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8373}
8374
8375/// Push PHI nodes in the header of the given loop onto the given Worklist.
/// Only PHIs not already present in Visited are pushed, and each pushed PHI is
/// recorded in Visited.
8376static void PushLoopPHIs(const Loop *L,
8379 BasicBlock *Header = L->getHeader();
8380
8381 // Push all Loop-header PHIs onto the Worklist stack.
8382 for (PHINode &PN : Header->phis())
8383 if (Visited.insert(&PN).second)
8384 Worklist.push_back(&PN);
8385}
8386
8387ScalarEvolution::BackedgeTakenInfo &
8388ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8389 auto &BTI = getBackedgeTakenInfo(L);
8390 if (BTI.hasFullInfo())
8391 return BTI;
8392
8393 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8394
8395 if (!Pair.second)
8396 return Pair.first->second;
8397
8398 BackedgeTakenInfo Result =
8399 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8400
8401 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8402}
8403
/// Return the (unpredicated) backedge-taken info for L, computing and caching
/// it in BackedgeTakenCounts on first use.
8404ScalarEvolution::BackedgeTakenInfo &
8405ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8406 // Initially insert an invalid entry for this loop. If the insertion
8407 // succeeds, proceed to actually compute a backedge-taken count and
8408 // update the value. The temporary CouldNotCompute value tells SCEV
8409 // code elsewhere that it shouldn't attempt to request a new
8410 // backedge-taken count, which could result in infinite recursion.
8411 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8412 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8413 if (!Pair.second)
8414 return Pair.first->second;
8415
8416 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8417 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8418 // must be cleared in this scope.
8419 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8420
8421 // Now that we know more about the trip count for this loop, forget any
8422 // existing SCEV values for PHI nodes in this loop since they are only
8423 // conservative estimates made without the benefit of trip count
8424 // information. This invalidation is not necessary for correctness, and is
8425 // only done to produce more precise results.
8426 if (Result.hasAnyInfo()) {
8427 // Invalidate any expression using an addrec in this loop.
8429 auto LoopUsersIt = LoopUsers.find(L);
8430 if (LoopUsersIt != LoopUsers.end())
8431 append_range(ToForget, LoopUsersIt->second);
8432 forgetMemoizedResults(ToForget);
8433
8434 // Invalidate constant-evolved loop header phis.
8435 for (PHINode &PN : L->getHeader()->phis())
8436 ConstantEvolutionLoopExitValue.erase(&PN);
8437 }
8438
8439 // Re-lookup the insert position, since the call to
8440 // computeBackedgeTakenCount above could result in a
8441 // recursive call to getBackedgeTakenInfo (on a different
8442 // loop), which would invalidate the iterator computed
8443 // earlier.
8444 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8445}
8446
8448 // This method is intended to forget all info about loops. It should
8449 // invalidate caches as if the following happened:
8450 // - The trip counts of all loops have changed arbitrarily
8451 // - Every llvm::Value has been updated in place to produce a different
8452 // result.
// Every cache listed below is emptied wholesale instead of being invalidated
// entry by entry.
8453 BackedgeTakenCounts.clear();
8454 PredicatedBackedgeTakenCounts.clear();
8455 BECountUsers.clear();
8456 LoopPropertiesCache.clear();
8457 ConstantEvolutionLoopExitValue.clear();
8458 ValueExprMap.clear();
8459 ValuesAtScopes.clear();
8460 ValuesAtScopesUsers.clear();
8461 LoopDispositions.clear();
8462 BlockDispositions.clear();
8463 UnsignedRanges.clear();
8464 SignedRanges.clear();
8465 ExprValueMap.clear();
8466 HasRecMap.clear();
8467 ConstantMultipleCache.clear();
8468 PredicatedSCEVRewrites.clear();
8469 FoldCache.clear();
8470 FoldCacheUser.clear();
8471}
/// Drain Worklist. For each popped instruction that is SCEVable (or is a
/// with.overflow intrinsic), drop its cached entry from ValueExprMap, record
/// the dropped SCEV in ToForget, clear any constant-evolution exit value cached
/// for a PHI, and queue its def-use children via PushDefUseChildren (which is
/// given Visited). Other instructions are skipped and their users are not
/// visited. Callers in this file pass ToForget to forgetMemoizedResults
/// afterwards.
8472void ScalarEvolution::visitAndClearUsers(
8476 while (!Worklist.empty()) {
8477 Instruction *I = Worklist.pop_back_val();
8478 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8479 continue;
8480
8482 ValueExprMap.find_as(static_cast<Value *>(I));
8483 if (It != ValueExprMap.end()) {
// NOTE(review): if eraseValueFromMap removes this same ValueExprMap entry,
// It->second on the next line is read after the erase; consider copying
// It->second into a local before calling eraseValueFromMap (confirm).
8484 eraseValueFromMap(It->first);
8485 ToForget.push_back(It->second);
8486 if (PHINode *PN = dyn_cast<PHINode>(I))
8487 ConstantEvolutionLoopExitValue.erase(PN);
8488 }
8489
8490 PushDefUseChildren(I, Worklist, Visited);
8491 }
8492}
8493
// Forget cached information about L and every loop nested inside it: trip
// counts (plain and predicated), predicated rewrites, loop properties, SCEVs
// of header PHIs and their users, and every SCEV recorded as a user of the
// loop. Invalidation of the collected SCEVs is done once at the end.
8495 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8499
8500 // Iterate over all the loops and sub-loops to drop SCEV information.
8501 while (!LoopWorklist.empty()) {
8502 auto *CurrL = LoopWorklist.pop_back_val();
8503
8504 // Drop any stored trip count value.
8505 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8506 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8507
8508 // Drop information about predicated SCEV rewrites for this loop.
// erase(I++) hands erase() the current entry while I already points at the
// next one.
8509 for (auto I = PredicatedSCEVRewrites.begin();
8510 I != PredicatedSCEVRewrites.end();) {
8511 std::pair<const SCEV *, const Loop *> Entry = I->first;
8512 if (Entry.second == CurrL)
8513 PredicatedSCEVRewrites.erase(I++);
8514 else
8515 ++I;
8516 }
8517
// Collect every SCEV registered as using this loop for invalidation.
8518 auto LoopUsersItr = LoopUsers.find(CurrL);
8519 if (LoopUsersItr != LoopUsers.end()) {
8520 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8521 LoopUsersItr->second.end());
8522 }
8523
8524 // Drop information about expressions based on loop-header PHIs.
8525 PushLoopPHIs(CurrL, Worklist, Visited);
8526 visitAndClearUsers(Worklist, Visited, ToForget);
8527
8528 LoopPropertiesCache.erase(CurrL);
8529 // Forget all contained loops too, to avoid dangling entries in the
8530 // ValuesAtScopes map.
8531 LoopWorklist.append(CurrL->begin(), CurrL->end());
8532 }
8533 forgetMemoizedResults(ToForget);
8534}
8535
// Forget L's outermost enclosing loop; forgetLoop also walks nested loops, so
// L itself is covered.
8537 forgetLoop(L->getOutermostLoop());
8538}
8539
// Forget cached SCEV information for V. Only instructions are handled; other
// values are ignored.
8541 Instruction *I = dyn_cast<Instruction>(V);
8542 if (!I) return;
8543
8544 // Drop cached information for I and, via PushDefUseChildren, its users.
8548 Worklist.push_back(I);
8549 Visited.insert(I);
8550 visitAndClearUsers(Worklist, Visited, ToForget);
8551
8552 forgetMemoizedResults(ToForget);
8553}
8554
// Invalidate what SCEV knows about V (expected to be an LCSSA phi of L) after
// it gains a predecessor: SCEVs built on in-loop Unknowns/AddRecs reachable
// from V's existing SCEV are forgotten, then V itself is forgotten.
8556 if (!isSCEVable(V->getType()))
8557 return;
8558
8559 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8560 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8561 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8562 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8563 if (const SCEV *S = getExistingSCEV(V)) {
8564 struct InvalidationRootCollector {
8565 Loop *L;
8567
8568 InvalidationRootCollector(Loop *L) : L(L) {}
8569
// Record in-loop SCEVUnknowns and AddRecs of L (or its sub-loops) as roots;
// returning true keeps descending into operands.
8570 bool follow(const SCEV *S) {
8571 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8572 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8573 if (L->contains(I))
8574 Roots.push_back(S);
8575 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8576 if (L->contains(AddRec->getLoop()))
8577 Roots.push_back(S);
8578 }
8579 return true;
8580 }
// Never stop early: the whole expression is visited.
8581 bool isDone() const { return false; }
8582 };
8583
8584 InvalidationRootCollector C(L);
8585 visitAll(S, C);
8586 forgetMemoizedResults(C.Roots);
8587 }
8588
8589 // Also perform the normal invalidation.
8590 forgetValue(V);
8591}
8592
8593void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8594
// Invalidate cached block and loop dispositions, either all of them (V null)
// or those of V's existing SCEV and, transitively, of its SCEV users.
8596 // Unless a specific value is passed to invalidation, completely clear both
8597 // caches.
8598 if (!V) {
8599 BlockDispositions.clear();
8600 LoopDispositions.clear();
8601 return;
8602 }
8603
8604 if (!isSCEVable(V->getType()))
8605 return;
8606
8607 const SCEV *S = getExistingSCEV(V);
8608 if (!S)
8609 return;
8610
8611 // Invalidate the block and loop dispositions cached for S. Dispositions of
8612 // S's users may change if S's disposition changes (e.g. a user may change to
8613 // loop-invariant, if S changes to loop invariant), so also invalidate
8614 // dispositions of S's users recursively.
8615 SmallVector<const SCEV *, 8> Worklist = {S};
8617 while (!Worklist.empty()) {
8618 const SCEV *Curr = Worklist.pop_back_val();
8619 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8620 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
// Prune: if neither cache held an entry for Curr, its users are not walked.
8621 if (!LoopDispoRemoved && !BlockDispoRemoved)
8622 continue;
8623 auto Users = SCEVUsers.find(Curr);
8624 if (Users != SCEVUsers.end())
8625 for (const auto *User : Users->second)
8626 if (Seen.insert(User).second)
8627 Worklist.push_back(User);
8628 }
8629}
8630
8631/// Get the exact loop backedge taken count considering all loop exits. A
8632/// computable result can only be returned for loops with all exiting blocks
8633/// dominating the latch. howFarToZero assumes that the limit of each loop test
8634/// is never skipped. This is a valid assumption as long as the loop exits via
8635/// that test. For precise results, it is the caller's responsibility to specify
8636/// the relevant loop exiting block using getExact(ExitingBlock, SE).
///
/// Returns CouldNotCompute if any exit is not computable, no exits were
/// recorded, or the loop has no unique latch. When Preds is non-null, the
/// predicates of every recorded exit are appended to it.
8637const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8638 const Loop *L, ScalarEvolution *SE,
8640 // If any exits were not computable, the loop is not computable.
8641 if (!isComplete() || ExitNotTaken.empty())
8642 return SE->getCouldNotCompute();
8643
8644 const BasicBlock *Latch = L->getLoopLatch();
8645 // All exiting blocks we have collected must dominate the only backedge.
8646 if (!Latch)
8647 return SE->getCouldNotCompute();
8648
8649 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8650 // count is simply a minimum out of all these calculated exit counts.
8652 for (const auto &ENT : ExitNotTaken) {
8653 const SCEV *BECount = ENT.ExactNotTaken;
8654 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8655 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8656 "We should only have known counts for exiting blocks that dominate "
8657 "latch!");
8658
8659 Ops.push_back(BECount);
8660
8661 if (Preds)
8662 append_range(*Preds, ENT.Predicates);
8663
8664 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8665 "Predicate should be always true!");
8666 }
8667
8668 // If an earlier exit exits on the first iteration (exit count zero), then
8669 // a later poison exit count should not propagate into the result. These are
8670 // exactly the semantics provided by umin_seq.
8671 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8672}
8673
8674const ScalarEvolution::ExitNotTakenInfo *
8675ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8676 const BasicBlock *ExitingBlock,
8677 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8678 for (const auto &ENT : ExitNotTaken)
8679 if (ENT.ExitingBlock == ExitingBlock) {
8680 if (ENT.hasAlwaysTruePredicate())
8681 return &ENT;
8682 else if (Predicates) {
8683 append_range(*Predicates, ENT.Predicates);
8684 return &ENT;
8685 }
8686 }
8687
8688 return nullptr;
8689}
8690
8691/// getConstantMax - Get the constant max backedge taken count for the loop.
8692const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8693 ScalarEvolution *SE,
8694 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8695 if (!getConstantMax())
8696 return SE->getCouldNotCompute();
8697
8698 for (const auto &ENT : ExitNotTaken)
8699 if (!ENT.hasAlwaysTruePredicate()) {
8700 if (!Predicates)
8701 return SE->getCouldNotCompute();
8702 append_range(*Predicates, ENT.Predicates);
8703 }
8704
8705 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8706 isa<SCEVConstant>(getConstantMax())) &&
8707 "No point in having a non-constant max backedge taken count!");
8708 return getConstantMax();
8709}
8710
/// Return the symbolic maximum backedge-taken count: the sequential umin of
/// every exit's known SymbolicMaxNotTaken, or CouldNotCompute if none is known.
/// The result is computed once and cached in SymbolicMax.
/// NOTE(review): predicates are appended to *Predicates only on the call that
/// computes and caches SymbolicMax; later calls return the cached value
/// without appending anything, so a second predicated caller receives no
/// predicates. Confirm whether callers rely on collecting them every time.
8711const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8712 const Loop *L, ScalarEvolution *SE,
8714 if (!SymbolicMax) {
8715 // Form an expression for the maximum exit count possible for this loop. We
8716 // merge the max and exact information to approximate a version of
8717 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8718 // constants.
8720
8721 for (const auto &ENT : ExitNotTaken) {
8722 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8723 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8724 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8725 "We should only have known counts for exiting blocks that "
8726 "dominate latch!");
8727 ExitCounts.push_back(ExitCount);
8728 if (Predicates)
8729 append_range(*Predicates, ENT.Predicates);
8730
8731 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8732 "Predicate should be always true!");
8733 }
8734 }
8735 if (ExitCounts.empty())
8736 SymbolicMax = SE->getCouldNotCompute();
8737 else
8738 SymbolicMax =
8739 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8740 }
8741 return SymbolicMax;
8742}
8743
8744bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8745 ScalarEvolution *SE) const {
8746 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8747 return !ENT.hasAlwaysTruePredicate();
8748 };
8749 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8750}
8751
// Convenience constructor: E is used as the exact, constant-max and
// symbolic-max count, with MaxOrZero set to false.
8753 : ExitLimit(E, E, E, false) {}
8754
8756 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8757 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8759 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8760 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
// Records the exact, constant-max and symbolic-max not-taken counts, checks
// (in asserts) that they are mutually consistent, and flattens PredLists into
// Predicates.
8761 // If we prove the max count is zero, so is the symbolic bound. This happens
8762 // in practice due to differences in a) how context sensitive we've chosen
8763 // to be and b) how we reason about bounds implied by UB.
8764 if (ConstantMaxNotTaken->isZero()) {
8766 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8767 }
8768
8769 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8770 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8771 "Exact is not allowed to be less precise than Constant Max");
8772 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8773 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8774 "Exact is not allowed to be less precise than Symbolic Max");
8775 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8776 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8777 "Symbolic Max is not allowed to be less precise than Constant Max");
8778 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8779 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8780 "No point in having a non-constant max backedge taken count!");
// Flatten PredLists into Predicates, dropping duplicates while keeping the
// order in which predicates are first seen.
8782 for (const auto PredList : PredLists)
8783 for (const auto *P : PredList) {
8784 if (SeenPreds.contains(P))
8785 continue;
8786 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8787 SeenPreds.insert(P);
8788 Predicates.push_back(P);
8789 }
8790 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8791 "Backedge count should be int");
8792 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8794 "Max backedge count should be int");
8795}
8796
8798 const SCEV *ConstantMaxNotTaken,
8799 const SCEV *SymbolicMaxNotTaken,
8800 bool MaxOrZero,
// Wrap the single predicate list in a one-element list and delegate to the
// constructor that accepts several predicate lists.
8802 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8803 ArrayRef({PredList})) {}
8804
8805/// Build a BackedgeTakenInfo by copying the not-taken counts and predicates
8806/// of each recorded exit in ExitCounts into the ExitNotTaken list.
8807ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8809 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8810 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8811 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8812
8813 ExitNotTaken.reserve(ExitCounts.size());
// One ExitNotTakenInfo per (exiting block, ExitLimit) pair, in input order.
8814 std::transform(ExitCounts.begin(), ExitCounts.end(),
8815 std::back_inserter(ExitNotTaken),
8816 [&](const EdgeExitInfo &EEI) {
8817 BasicBlock *ExitBB = EEI.first;
8818 const ExitLimit &EL = EEI.second;
8819 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8820 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8821 EL.Predicates);
8822 });
8823 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8824 isa<SCEVConstant>(ConstantMax)) &&
8825 "No point in having a non-constant max backedge taken count!");
8826}
8827
8828/// Compute the number of times the backedge of the specified loop will execute.
/// The ExitLimit of each exiting block is computed, except for exits whose
/// branch condition is a constant that never exits. Exits with a known
/// symbolic max are recorded, and the result is marked complete only if every
/// considered exit has a known exact count. The constant max is derived as
/// described in the body. With AllowPredicates, exit limits may carry SCEV
/// predicates.
8829ScalarEvolution::BackedgeTakenInfo
8830ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8831 bool AllowPredicates) {
8832 SmallVector<BasicBlock *, 8> ExitingBlocks;
8833 L->getExitingBlocks(ExitingBlocks);
8834
8835 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8836
8838 bool CouldComputeBECount = true;
8839 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8840 const SCEV *MustExitMaxBECount = nullptr;
8841 const SCEV *MayExitMaxBECount = nullptr;
8842 bool MustExitMaxOrZero = false;
// Counted over all exiting blocks, including constant-condition exits that
// the loop below skips.
8843 bool IsOnlyExit = ExitingBlocks.size() == 1;
8844
8845 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8846 // and compute maxBECount.
8847 // Do a union of all the predicates here.
8848 for (BasicBlock *ExitBB : ExitingBlocks) {
8849 // We canonicalize untaken exits to br (constant), ignore them so that
8850 // proving an exit untaken doesn't negatively impact our ability to reason
8851 // about the loop as whole.
8852 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8853 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8854 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8855 if (ExitIfTrue == CI->isZero())
8856 continue;
8857 }
8858
8859 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8860
8861 assert((AllowPredicates || EL.Predicates.empty()) &&
8862 "Predicated exit limit when predicates are not allowed!");
8863
8864 // 1. For each exit that can be computed, add an entry to ExitCounts.
8865 // CouldComputeBECount is true only if all exits can be computed.
8866 if (EL.ExactNotTaken != getCouldNotCompute())
8867 ++NumExitCountsComputed;
8868 else
8869 // We couldn't compute an exact value for this exit, so
8870 // we won't be able to compute an exact value for the loop.
8871 CouldComputeBECount = false;
8872 // Remember exit count if either exact or symbolic is known. Because
8873 // Exact always implies symbolic, only check symbolic.
8874 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8875 ExitCounts.emplace_back(ExitBB, EL);
8876 else {
8877 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8878 "Exact is known but symbolic isn't?");
8879 ++NumExitCountsNotComputed;
8880 }
8881
8882 // 2. Derive the loop's MaxBECount from each exit's max number of
8883 // non-exiting iterations. Partition the loop exits into two kinds:
8884 // LoopMustExits and LoopMayExits.
8885 //
8886 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8887 // is a LoopMayExit. If any computable LoopMustExit is found, then
8888 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8889 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8890 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8891 // any
8892 // computable EL.ConstantMaxNotTaken.
8893 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8894 DT.dominates(ExitBB, Latch)) {
8895 if (!MustExitMaxBECount) {
8896 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8897 MustExitMaxOrZero = EL.MaxOrZero;
8898 } else {
8899 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8900 EL.ConstantMaxNotTaken);
8901 }
// Once MayExitMaxBECount is CouldNotCompute it stays that way (this branch
// is no longer entered).
8902 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8903 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8904 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8905 else {
8906 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8907 EL.ConstantMaxNotTaken);
8908 }
8909 }
8910 }
8911 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8912 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8913 // The loop backedge will be taken the maximum or zero times if there's
8914 // a single exit that must be taken the maximum or zero times.
8915 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8916
8917 // Remember which SCEVs are used in exit limits for invalidation purposes.
8918 // We only care about non-constant SCEVs here, so we can ignore
8919 // EL.ConstantMaxNotTaken
8920 // and MaxBECount, which must be SCEVConstant.
8921 for (const auto &Pair : ExitCounts) {
8922 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8923 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8924 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8925 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8926 {L, AllowPredicates});
8927 }
8928 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8929 MaxBECount, MaxOrZero);
8930}
8931
8933ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8934 bool IsOnlyExit, bool AllowPredicates) {
8935 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8936 // If our exiting block does not dominate the latch, then its connection with
8937 // loop's exit limit may be far from trivial.
8938 const BasicBlock *Latch = L->getLoopLatch();
8939 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8940 return getCouldNotCompute();
8941
8942 Instruction *Term = ExitingBlock->getTerminator();
8943 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8944 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8945 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8946 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8947 "It should have one successor in loop and one exit block!");
8948 // Proceed to the next level to examine the exit condition expression.
8949 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8950 /*ControlsOnlyExit=*/IsOnlyExit,
8951 AllowPredicates);
8952 }
8953
8954 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8955 // For switch, make sure that there is a single exit from the loop.
8956 BasicBlock *Exit = nullptr;
8957 for (auto *SBB : successors(ExitingBlock))
8958 if (!L->contains(SBB)) {
8959 if (Exit) // Multiple exit successors.
8960 return getCouldNotCompute();
8961 Exit = SBB;
8962 }
8963 assert(Exit && "Exiting block must have at least one exit");
8964 return computeExitLimitFromSingleExitSwitch(
8965 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8966 }
8967
8968 return getCouldNotCompute();
8969}
8970
8972 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8973 bool AllowPredicates) {
// Entry point for analyzing one exit condition: create a fresh ExitLimitCache
// for this (L, ExitIfTrue, AllowPredicates) combination and delegate to the
// memoizing computeExitLimitFromCondCached. Sub-conditions reached while
// splitting an and/or tree share this cache, so each (ExitCond,
// ControlsOnlyExit) pair is analyzed at most once per call.
8974 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8975 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8976 ControlsOnlyExit, AllowPredicates);
8977}
8978
8979std::optional<ScalarEvolution::ExitLimit>
8980ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8981 bool ExitIfTrue, bool ControlsOnlyExit,
8982 bool AllowPredicates) {
8983 (void)this->L;
8984 (void)this->ExitIfTrue;
8985 (void)this->AllowPredicates;
8986
8987 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8988 this->AllowPredicates == AllowPredicates &&
8989 "Variance in assumed invariant key components!");
8990 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8991 if (Itr == TripCountMap.end())
8992 return std::nullopt;
8993 return Itr->second;
8994}
8995
8996void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8997 bool ExitIfTrue,
8998 bool ControlsOnlyExit,
8999 bool AllowPredicates,
9000 const ExitLimit &EL) {
9001 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9002 this->AllowPredicates == AllowPredicates &&
9003 "Variance in assumed invariant key components!");
9004
9005 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9006 assert(InsertResult.second && "Expected successful insertion!");
9007 (void)InsertResult;
9008 (void)ExitIfTrue;
9009}
9010
9011ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9012 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9013 bool ControlsOnlyExit, bool AllowPredicates) {
9014
9015 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9016 AllowPredicates))
9017 return *MaybeEL;
9018
9019 ExitLimit EL = computeExitLimitFromCondImpl(
9020 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9021 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9022 return EL;
9023}
9024
// Compute the exit limit of the (sub)condition ExitCond. Strategies are tried
// in order: split an and/or into its operands; analyze an icmp (retrying with
// SCEV predicates when they are permitted and the first attempt lacks full
// information); fold a constant condition; rewrite the overflow bit of an
// x.with.overflow intrinsic with a constant RHS as an icmp; and finally fall
// back to computeExitCountExhaustively.
9025ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9026 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9027 bool ControlsOnlyExit, bool AllowPredicates) {
9028 // Handle BinOp conditions (And, Or).
9029 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9030 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9031 return *LimitFromBinOp;
9032
9033 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9034 // Proceed to the next level to examine the icmp.
9035 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
// First attempt: AllowPredicates is not forwarded here, so no SCEV predicates
// are requested explicitly.
9036 ExitLimit EL =
9037 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9038 if (EL.hasFullInfo() || !AllowPredicates)
9039 return EL;
9040
9041 // Try again, but use SCEV predicates this time.
9042 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9043 ControlsOnlyExit,
9044 /*AllowPredicates=*/true);
9045 }
9046
9047 // Check for a constant condition. These are normally stripped out by
9048 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9049 // preserve the CFG and is temporarily leaving constant conditions
9050 // in place.
9051 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9052 if (ExitIfTrue == !CI->getZExtValue())
9053 // The backedge is always taken.
9054 return getCouldNotCompute();
9055 // The backedge is never taken.
9056 return getZero(CI->getType());
9057 }
9058
9059 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9060 // with a constant step, we can form an equivalent icmp predicate and figure
9061 // out how many iterations will be taken before we exit.
9062 const WithOverflowInst *WO;
9063 const APInt *C;
9064 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9065 match(WO->getRHS(), m_APInt(C))) {
9066 ConstantRange NWR =
9068 WO->getNoWrapKind());
9069 CmpInst::Predicate Pred;
9070 APInt NewRHSC, Offset;
9071 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9072 if (!ExitIfTrue)
9073 Pred = ICmpInst::getInversePredicate(Pred);
9074 auto *LHS = getSCEV(WO->getLHS());
9075 if (Offset != 0)
9077 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9078 ControlsOnlyExit, AllowPredicates);
9079 if (EL.hasAnyInfo())
9080 return EL;
9081 }
9082
9083 // If it's not an integer or pointer comparison then compute it the hard way.
9084 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9085}
9086
9087std::optional<ScalarEvolution::ExitLimit>
9088ScalarEvolution::computeExitLimitFromCondFromBinOp(
9089 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9090 bool ControlsOnlyExit, bool AllowPredicates) {
9091 // Check if the controlling expression for this loop is an And or Or.
9092 Value *Op0, *Op1;
9093 bool IsAnd = false;
9094 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9095 IsAnd = true;
9096 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9097 IsAnd = false;
9098 else
9099 return std::nullopt;
9100
9101 // EitherMayExit is true in these two cases:
9102 // br (and Op0 Op1), loop, exit
9103 // br (or Op0 Op1), exit, loop
9104 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9105 ExitLimit EL0 = computeExitLimitFromCondCached(
9106 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9107 AllowPredicates);
9108 ExitLimit EL1 = computeExitLimitFromCondCached(
9109 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9110 AllowPredicates);
9111
9112 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9113 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9114 if (isa<ConstantInt>(Op1))
9115 return Op1 == NeutralElement ? EL0 : EL1;
9116 if (isa<ConstantInt>(Op0))
9117 return Op0 == NeutralElement ? EL1 : EL0;
9118
9119 const SCEV *BECount = getCouldNotCompute();
9120 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9121 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9122 if (EitherMayExit) {
9123 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9124 // Both conditions must be same for the loop to continue executing.
9125 // Choose the less conservative count.
9126 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9127 EL1.ExactNotTaken != getCouldNotCompute()) {
9128 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9129 UseSequentialUMin);
9130 }
9131 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9132 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9133 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9134 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9135 else
9136 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9137 EL1.ConstantMaxNotTaken);
9138 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9139 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9140 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9141 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9142 else
9143 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9144 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9145 } else {
9146 // Both conditions must be same at the same time for the loop to exit.
9147 // For now, be conservative.
9148 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9149 BECount = EL0.ExactNotTaken;
9150 }
9151
9152 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9153 // to be more aggressive when computing BECount than when computing
9154 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9155 // and
9156 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9157 // EL1.ConstantMaxNotTaken to not.
9158 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9159 !isa<SCEVCouldNotCompute>(BECount))
9160 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9161 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9162 SymbolicMaxBECount =
9163 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9164 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9165 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9166}
9167
9168ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9169 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9170 bool AllowPredicates) {
9171 // If the condition was exit on true, convert the condition to exit on false
9172 CmpPredicate Pred;
9173 if (!ExitIfTrue)
9174 Pred = ExitCond->getCmpPredicate();
9175 else
9176 Pred = ExitCond->getInverseCmpPredicate();
9177 const ICmpInst::Predicate OriginalPred = Pred;
9178
9179 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9180 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9181
9182 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9183 AllowPredicates);
9184 if (EL.hasAnyInfo())
9185 return EL;
9186
9187 auto *ExhaustiveCount =
9188 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9189
9190 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9191 return ExhaustiveCount;
9192
9193 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9194 ExitCond->getOperand(1), L, OriginalPred);
9195}
// Core exit-limit computation for the comparison "LHS Pred RHS", where Pred is
// the predicate under which the loop keeps iterating (the loop exits when it
// is false). The operands are first evaluated at the scope of L, a
// loop-invariant LHS is moved to the right-hand side, and the comparison is
// simplified. Then, in order:
//  * an AddRec of L compared against a constant is answered via value ranges;
//  * if this condition must exit the loop (ControllingFiniteLoop) and RHS is
//    loop-invariant, no-wrap flags of the induction variable are strengthened;
//  * the predicate is reduced to a howFarToZero / howFarToNonZero /
//    howManyLessThans / howManyGreaterThans query.
// Returns SCEVCouldNotCompute when none of these produce information.
9196ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9197 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9198 bool ControlsOnlyExit, bool AllowPredicates) {
9199
9200 // Try to evaluate any dependencies out of the loop.
9201 LHS = getSCEVAtScope(LHS, L);
9202 RHS = getSCEVAtScope(RHS, L);
9203
9204 // At this point, we would like to compute how many iterations of the
9205 // loop the predicate will return true for these inputs.
9206 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9207 // If there is a loop-invariant, force it into the RHS.
9208 std::swap(LHS, RHS);
9210 }
9211
9212 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9214 // Simplify the operands before analyzing them.
9215 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9216
9217 // If we have a comparison of a chrec against a constant, try to use value
9218 // ranges to answer this query.
9219 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9220 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9221 if (AddRec->getLoop() == L) {
9222 // Form the constant range.
9223 ConstantRange CompRange =
9224 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9225
9226 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9227 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9228 }
9229
9230 // If this loop must exit based on this condition (or execute undefined
9231 // behaviour), see if we can improve wrap flags. This is essentially
9232 // a must execute style proof.
9233 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9234 // If we can prove the test sequence produced must repeat the same values
9235 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9236 // because if it did, we'd have an infinite (undefined) loop.
9237 // TODO: We can peel off any functions which are invertible *in L*. Loop
9238 // invariant terms are effectively constants for our purposes here.
9239 auto *InnerLHS = LHS;
9240 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9241 InnerLHS = ZExt->getOperand();
9242 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9243 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9244 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9245 /*OrNegative=*/true)) {
9246 auto Flags = AR->getNoWrapFlags();
9247 Flags = setFlags(Flags, SCEV::FlagNW);
9250 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9251 }
9252
9253 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9254 // From no-self-wrap, this follows trivially from the fact that every
9255 // (un)signed-wrapped, but not self-wrapped value must be less than the
9256 // last value before (un)signed wrap. Since we know that last value
9257 // didn't exit, nor will any smaller one.
9258 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9259 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9260 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9261 AR && AR->getLoop() == L && AR->isAffine() &&
9262 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9263 isKnownPositive(AR->getStepRecurrence(*this))) {
9264 auto Flags = AR->getNoWrapFlags();
9265 Flags = setFlags(Flags, WrapType);
9268 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9269 }
9270 }
9271 }
9272
// Reduce the (possibly simplified) continue-predicate to one of the
// specialized exit-count helpers. Cases that produce no information fall out
// of the switch and yield SCEVCouldNotCompute.
9273 switch (Pred) {
9274 case ICmpInst::ICMP_NE: { // while (X != Y)
9275 // Convert to: while (X-Y != 0)
9276 if (LHS->getType()->isPointerTy()) {
9278 if (isa<SCEVCouldNotCompute>(LHS))
9279 return LHS;
9280 }
9281 if (RHS->getType()->isPointerTy()) {
9283 if (isa<SCEVCouldNotCompute>(RHS))
9284 return RHS;
9285 }
9286 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9287 AllowPredicates);
9288 if (EL.hasAnyInfo())
9289 return EL;
9290 break;
9291 }
9292 case ICmpInst::ICMP_EQ: { // while (X == Y)
9293 // Convert to: while (X-Y == 0)
9294 if (LHS->getType()->isPointerTy()) {
9296 if (isa<SCEVCouldNotCompute>(LHS))
9297 return LHS;
9298 }
9299 if (RHS->getType()->isPointerTy()) {
9301 if (isa<SCEVCouldNotCompute>(RHS))
9302 return RHS;
9303 }
9304 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9305 if (EL.hasAnyInfo()) return EL;
9306 break;
9307 }
9308 case ICmpInst::ICMP_SLE:
9309 case ICmpInst::ICMP_ULE:
9310 // Since the loop is finite, an invariant RHS cannot include the boundary
9311 // value, otherwise it would loop forever.
9312 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9313 !isLoopInvariant(RHS, L)) {
9314 // Otherwise, perform the addition in a wider type, to avoid overflow.
9315 // If the LHS is an addrec with the appropriate nowrap flag, the
9316 // extension will be sunk into it and the exit count can be analyzed.
9317 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9318 if (!OldType)
9319 break;
9320 // Prefer doubling the bitwidth over adding a single bit to make it more
9321 // likely that we use a legal type.
9322 auto *NewType =
9323 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9324 if (ICmpInst::isSigned(Pred)) {
9325 LHS = getSignExtendExpr(LHS, NewType);
9326 RHS = getSignExtendExpr(RHS, NewType);
9327 } else {
9328 LHS = getZeroExtendExpr(LHS, NewType);
9329 RHS = getZeroExtendExpr(RHS, NewType);
9330 }
9331 }
// Rewrite X <= Y as X < Y + 1 and handle it as a less-than below.
9332 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9333 [[fallthrough]];
9334 case ICmpInst::ICMP_SLT:
9335 case ICmpInst::ICMP_ULT: { // while (X < Y)
9336 bool IsSigned = ICmpInst::isSigned(Pred);
9337 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9338 AllowPredicates);
9339 if (EL.hasAnyInfo())
9340 return EL;
9341 break;
9342 }
9343 case ICmpInst::ICMP_SGE:
9344 case ICmpInst::ICMP_UGE:
9345 // Since the loop is finite, an invariant RHS cannot include the boundary
9346 // value, otherwise it would loop forever.
9347 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9348 !isLoopInvariant(RHS, L))
9349 break;
// Rewrite X >= Y as X > Y - 1 and handle it as a greater-than below.
9350 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9351 [[fallthrough]];
9352 case ICmpInst::ICMP_SGT:
9353 case ICmpInst::ICMP_UGT: { // while (X > Y)
9354 bool IsSigned = ICmpInst::isSigned(Pred);
9355 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9356 AllowPredicates);
9357 if (EL.hasAnyInfo())
9358 return EL;
9359 break;
9360 }
9361 default:
9362 break;
9363 }
9364
9365 return getCouldNotCompute();
9366}
9367
9369ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9370 SwitchInst *Switch,
9371 BasicBlock *ExitingBlock,
9372 bool ControlsOnlyExit) {
9373 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9374
9375 // Give up if the exit is the default dest of a switch.
9376 if (Switch->getDefaultDest() == ExitingBlock)
9377 return getCouldNotCompute();
9378
9379 assert(L->contains(Switch->getDefaultDest()) &&
9380 "Default case must not exit the loop!");
9381 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9382 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9383
9384 // while (X != Y) --> while (X-Y != 0)
9385 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9386 if (EL.hasAnyInfo())
9387 return EL;
9388
9389 return getCouldNotCompute();
9390}
9391
// Evaluate the add recurrence AddRec at the constant iteration number C and
// return the resulting constant. The evaluation is expected (and asserted) to
// fold to a SCEVConstant.
9392static ConstantInt *
9394 ScalarEvolution &SE) {
9395 const SCEV *InVal = SE.getConstant(C);
9396 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9397 assert(isa<SCEVConstant>(Val) &&
9398 "Evaluation of SCEV at constant didn't fold correctly?");
9399 return cast<SCEVConstant>(Val)->getValue();
9400}
9401
// Bound the backedge-taken count of an exit test "LHS Pred RHSV", where Pred
// is the condition under which the backedge is taken, RHSV is a ConstantInt,
// and LHS is a header PHI (optionally shifted once more) that is shifted by a
// positive constant on every backedge. Such recurrences stabilize (lshr/shl to
// 0; ashr to 0 or -1 depending on the known sign of the start value). If Pred
// is false on the stabilized value, the loop must exit. The returned ExitLimit
// then has no exact count, and uses an UpperBound computed from the bit width
// as both the constant and the symbolic maximum. Otherwise returns
// SCEVCouldNotCompute.
9402ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9403 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9404 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9405 if (!RHS)
9406 return getCouldNotCompute();
9407
9408 const BasicBlock *Latch = L->getLoopLatch();
9409 if (!Latch)
9410 return getCouldNotCompute();
9411
9412 const BasicBlock *Predecessor = L->getLoopPredecessor();
9413 if (!Predecessor)
9414 return getCouldNotCompute();
9415
9416 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9417 // Return LHS in OutLHS and shift_op in OutOpCode.
9418 auto MatchPositiveShift =
9419 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9420
9421 using namespace PatternMatch;
9422
9423 ConstantInt *ShiftAmt;
9424 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9425 OutOpCode = Instruction::LShr;
9426 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9427 OutOpCode = Instruction::AShr;
9428 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9429 OutOpCode = Instruction::Shl;
9430 else
9431 return false;
9432
9433 return ShiftAmt->getValue().isStrictlyPositive();
9434 };
9435
9436 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9437 //
9438 // loop:
9439 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9440 // %iv.shifted = lshr i32 %iv, <positive constant>
9441 //
9442 // Return true on a successful match. Return the corresponding PHI node (%iv
9443 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
// NOTE(review): the lambda's parameter V is never read. It is shadowed by the
// local V declared in the nested block, and the lambda instead reads and
// overwrites the captured outer LHS. The only call passes LHS, and LHS is not
// used after that call, so this is currently harmless but fragile.
9444 auto MatchShiftRecurrence =
9445 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9446 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9447
9448 {
9450 Value *V;
9451
9452 // If we encounter a shift instruction, "peel off" the shift operation,
9453 // and remember that we did so. Later when we inspect %iv's backedge
9454 // value, we will make sure that the backedge value uses the same
9455 // operation.
9456 //
9457 // Note: the peeled shift operation does not have to be the same
9458 // instruction as the one feeding into the PHI's backedge value. We only
9459 // really care about it being the same *kind* of shift instruction --
9460 // that's all that is required for our later inferences to hold.
9461 if (MatchPositiveShift(LHS, V, OpC)) {
9462 PostShiftOpCode = OpC;
9463 LHS = V;
9464 }
9465 }
9466
9467 PNOut = dyn_cast<PHINode>(LHS);
9468 if (!PNOut || PNOut->getParent() != L->getHeader())
9469 return false;
9470
9471 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9472 Value *OpLHS;
9473
9474 return
9475 // The backedge value for the PHI node must be a shift by a positive
9476 // amount
9477 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9478
9479 // of the PHI node itself
9480 OpLHS == PNOut &&
9481
9482 // and the kind of shift should match the kind of shift we peeled
9483 // off, if any.
9484 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9485 };
9486
9487 PHINode *PN;
9489 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9490 return getCouldNotCompute();
9491
9492 const DataLayout &DL = getDataLayout();
9493
9494 // The key rationale for this optimization is that for some kinds of shift
9495 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9496 // within a finite number of iterations. If the condition guarding the
9497 // backedge (in the sense that the backedge is taken if the condition is true)
9498 // is false for the value the shift recurrence stabilizes to, then we know
9499 // that the backedge is taken only a finite number of times.
9500
9501 ConstantInt *StableValue = nullptr;
9502 switch (OpCode) {
9503 default:
9504 llvm_unreachable("Impossible case!");
9505
9506 case Instruction::AShr: {
9507 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9508 // bitwidth(K) iterations.
9509 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9510 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9511 Predecessor->getTerminator(), &DT);
9512 auto *Ty = cast<IntegerType>(RHS->getType());
9513 if (Known.isNonNegative())
9514 StableValue = ConstantInt::get(Ty, 0);
9515 else if (Known.isNegative())
9516 StableValue = ConstantInt::get(Ty, -1, true);
9517 else
9518 return getCouldNotCompute();
9519
9520 break;
9521 }
9522 case Instruction::LShr:
9523 case Instruction::Shl:
9524 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9525 // stabilize to 0 in at most bitwidth(K) iterations.
9526 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9527 break;
9528 }
9529
// Does the backedge condition still hold once the recurrence has stabilized?
9530 auto *Result =
9531 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9532 assert(Result->getType()->isIntegerTy(1) &&
9533 "Otherwise cannot be an operand to a branch instruction");
9534
9535 if (Result->isZeroValue()) {
9536 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9537 const SCEV *UpperBound =
9539 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9540 }
9541
9542 return getCouldNotCompute();
9543}
9544
9545/// Return true if we can constant fold an instruction of the specified type,
9546/// assuming that all operands were constants.
9547static bool CanConstantFold(const Instruction *I) {
9548 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9549 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9550 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9551 return true;
9552
9553 if (const CallInst *CI = dyn_cast<CallInst>(I))
9554 if (const Function *F = CI->getCalledFunction())
9555 return canConstantFoldCallTo(CI, F);
9556 return false;
9557}
9558
9559/// Determine whether this instruction can constant evolve within this loop
9560/// assuming its operands can all constant evolve.
9561static bool canConstantEvolve(Instruction *I, const Loop *L) {
9562 // An instruction outside of the loop can't be derived from a loop PHI.
9563 if (!L->contains(I)) return false;
9564
9565 if (isa<PHINode>(I)) {
9566 // We don't currently keep track of the control flow needed to evaluate
9567 // PHIs, so we cannot handle PHIs inside of loops.
9568 return L->getHeader() == I->getParent();
9569 }
9570
9571 // If we won't be able to constant fold this expression even if the operands
9572 // are constants, bail early.
9573 return CanConstantFold(I);
9574}
9575
9576/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9577/// recursing through each instruction operand until reaching a loop header phi.
///
/// Returns the single header PHI that every non-constant operand of UseInst
/// (transitively) derives from. Returns null if some operand cannot
/// constant-evolve, if no PHI is reached, or if two different PHIs are
/// reached. Results for visited non-PHI operands are memoized in PHIMap.
/// A memoized null looks the same as "not yet visited" to lookup(), but a
/// null result ends the search anyway. Depth is the recursion depth and grows
/// by one per recursive call.
9578static PHINode *
9581 unsigned Depth) {
9583 return nullptr;
9584
9585 // Otherwise, we can evaluate this instruction if all of its operands are
9586 // constant or derived from a PHI node themselves.
9587 PHINode *PHI = nullptr;
9588 for (Value *Op : UseInst->operands()) {
9589 if (isa<Constant>(Op)) continue;
9590
9591 Instruction *OpInst = dyn_cast<Instruction>(Op);
9592 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9593
9594 PHINode *P = dyn_cast<PHINode>(OpInst);
9595 if (!P)
9596 // If this operand is already visited, reuse the prior result.
9597 // We may have P != PHI if this is the deepest point at which the
9598 // inconsistent paths meet.
9599 P = PHIMap.lookup(OpInst);
9600 if (!P) {
9601 // Recurse and memoize the results, whether a phi is found or not.
9602 // This recursive call invalidates pointers into PHIMap.
9603 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9604 PHIMap[OpInst] = P;
9605 }
9606 if (!P)
9607 return nullptr; // Not evolving from PHI
9608 if (PHI && PHI != P)
9609 return nullptr; // Evolving from multiple different PHIs.
9610 PHI = P;
9611 }
9612 // This is an expression evolving from a single constant PHI!
9613 return PHI;
9614}
9615
9616/// getConstantEvolvingPHI - Given an LLVM value and a loop, return the PHI
9617/// node in the loop header that V is derived from, or V itself if it is such
9618/// a PHI. Each operation along the way must pass canConstantEvolve, and its
9619/// operands must either be constants or values derived from that same PHI.
9620/// If this expression does not fit with these constraints, return null.
9622 Instruction *I = dyn_cast<Instruction>(V);
9623 if (!I || !canConstantEvolve(I, L)) return nullptr;
9624
9625 if (PHINode *PN = dyn_cast<PHINode>(I))
9626 return PN;
9627
9628 // Record non-constant instructions contained by the loop.
9630 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9631}
9632
9633/// EvaluateExpression - Given an expression that passes the
9634/// getConstantEvolvingPHI predicate, evaluate its value using the constants
9635/// that Vals assigns to the loop's PHI nodes. Instructions evaluated along
/// the way are memoized into Vals; an operand whose evaluation failed is
/// recorded as null. If we can't fold this expression for some
9636/// reason, return null.
9639 const DataLayout &DL,
9640 const TargetLibraryInfo *TLI) {
9641 // Convenient constant check, but redundant for recursive calls.
9642 if (Constant *C = dyn_cast<Constant>(V)) return C;
9643 Instruction *I = dyn_cast<Instruction>(V);
9644 if (!I) return nullptr;
9645
9646 if (Constant *C = Vals.lookup(I)) return C;
9647
9648 // An instruction inside the loop depends on a value outside the loop that we
9649 // weren't given a mapping for, or a value such as a call inside the loop.
9650 if (!canConstantEvolve(I, L)) return nullptr;
9651
9652 // An unmapped PHI can be due to a branch or another loop inside this loop,
9653 // or due to this not being the initial iteration through a loop where we
9654 // couldn't compute the evolution of this particular PHI last time.
9655 if (isa<PHINode>(I)) return nullptr;
9656
9657 std::vector<Constant*> Operands(I->getNumOperands());
9658
9659 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9660 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9661 if (!Operand) {
9662 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9663 if (!Operands[i]) return nullptr;
9664 continue;
9665 }
9666 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9667 Vals[Operand] = C;
9668 if (!C) return nullptr;
9669 Operands[i] = C;
9670 }
9671
9672 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9673 /*AllowNonDeterministic=*/false);
9674}
9675
9676
9677// If every incoming value to PN except the one for BB is a specific Constant,
// return that, else return nullptr. The same Constant may arrive along several
// edges. A non-Constant incoming value, or two different Constants, yields
// nullptr, as does a PHI whose incoming edges all come from BB.
9678// (Used to find a PHI's start value when BB is the loop latch.)
9680 Constant *IncomingVal = nullptr;
9681
9682 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9683 if (PN->getIncomingBlock(i) == BB)
9684 continue;
9685
9686 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9687 if (!CurrentVal)
9688 return nullptr;
9689
9690 if (IncomingVal != CurrentVal) {
// A second, different constant means there is no unique start value.
9691 if (IncomingVal)
9692 return nullptr;
9693 IncomingVal = CurrentVal;
9694 }
9695 }
9696
9697 return IncomingVal;
9698}
9699
9700/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9701/// in the header of its containing loop, we know the loop executes a
9702/// constant number of times, and the PHI node is just a recurrence
9703/// involving constants, fold it.
9704Constant *
9705ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9706 const APInt &BEs,
9707 const Loop *L) {
9708 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9709 if (!Inserted)
9710 return I->second;
9711
9713 return nullptr; // Not going to evaluate it.
9714
9715 Constant *&RetVal = I->second;
9716
9718 BasicBlock *Header = L->getHeader();
9719 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9720
9721 BasicBlock *Latch = L->getLoopLatch();
9722 if (!Latch)
9723 return nullptr;
9724
9725 for (PHINode &PHI : Header->phis()) {
9726 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9727 CurrentIterVals[&PHI] = StartCST;
9728 }
9729 if (!CurrentIterVals.count(PN))
9730 return RetVal = nullptr;
9731
9732 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9733
9734 // Execute the loop symbolically to determine the exit value.
9735 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9736 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9737
9738 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9739 unsigned IterationNum = 0;
9740 const DataLayout &DL = getDataLayout();
9741 for (; ; ++IterationNum) {
9742 if (IterationNum == NumIterations)
9743 return RetVal = CurrentIterVals[PN]; // Got exit value!
9744
9745 // Compute the value of the PHIs for the next iteration.
9746 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9748 Constant *NextPHI =
9749 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9750 if (!NextPHI)
9751 return nullptr; // Couldn't evaluate!
9752 NextIterVals[PN] = NextPHI;
9753
9754 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9755
9756 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9757 // cease to be able to evaluate one of them or if they stop evolving,
9758 // because that doesn't necessarily prevent us from computing PN.
9760 for (const auto &I : CurrentIterVals) {
9761 PHINode *PHI = dyn_cast<PHINode>(I.first);
9762 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9763 PHIsToCompute.emplace_back(PHI, I.second);
9764 }
9765 // We use two distinct loops because EvaluateExpression may invalidate any
9766 // iterators into CurrentIterVals.
9767 for (const auto &I : PHIsToCompute) {
9768 PHINode *PHI = I.first;
9769 Constant *&NextPHI = NextIterVals[PHI];
9770 if (!NextPHI) { // Not already computed.
9771 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9772 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9773 }
9774 if (NextPHI != I.second)
9775 StoppedEvolving = false;
9776 }
9777
9778 // If all entries in CurrentIterVals == NextIterVals then we can stop
9779 // iterating, the loop can't continue to change.
9780 if (StoppedEvolving)
9781 return RetVal = CurrentIterVals[PN];
9782
9783 CurrentIterVals.swap(NextIterVals);
9784 }
9785}
9786
9787const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9788 Value *Cond,
9789 bool ExitWhen) {
9791 if (!PN) return getCouldNotCompute();
9792
9793 // If the loop is canonicalized, the PHI will have exactly two entries.
9794 // That's the only form we support here.
9795 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9796
9798 BasicBlock *Header = L->getHeader();
9799 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9800
9801 BasicBlock *Latch = L->getLoopLatch();
9802 assert(Latch && "Should follow from NumIncomingValues == 2!");
9803
9804 for (PHINode &PHI : Header->phis()) {
9805 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9806 CurrentIterVals[&PHI] = StartCST;
9807 }
9808 if (!CurrentIterVals.count(PN))
9809 return getCouldNotCompute();
9810
9811 // Okay, we find a PHI node that defines the trip count of this loop. Execute
9812 // the loop symbolically to determine when the condition gets a value of
9813 // "ExitWhen".
9814 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9815 const DataLayout &DL = getDataLayout();
9816 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9817 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9818 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9819
9820 // Couldn't symbolically evaluate.
9821 if (!CondVal) return getCouldNotCompute();
9822
9823 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9824 ++NumBruteForceTripCountsComputed;
9825 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9826 }
9827
9828 // Update all the PHI nodes for the next iteration.
9830
9831 // Create a list of which PHIs we need to compute. We want to do this before
9832 // calling EvaluateExpression on them because that may invalidate iterators
9833 // into CurrentIterVals.
9834 SmallVector<PHINode *, 8> PHIsToCompute;
9835 for (const auto &I : CurrentIterVals) {
9836 PHINode *PHI = dyn_cast<PHINode>(I.first);
9837 if (!PHI || PHI->getParent() != Header) continue;
9838 PHIsToCompute.push_back(PHI);
9839 }
9840 for (PHINode *PHI : PHIsToCompute) {
9841 Constant *&NextPHI = NextIterVals[PHI];
9842 if (NextPHI) continue; // Already computed!
9843
9844 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9845 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9846 }
9847 CurrentIterVals.swap(NextIterVals);
9848 }
9849
9850 // Too many iterations were needed to evaluate.
9851 return getCouldNotCompute();
9852}
9853
9854const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9856 ValuesAtScopes[V];
9857 // Check to see if we've folded this expression at this loop before.
9858 for (auto &LS : Values)
9859 if (LS.first == L)
9860 return LS.second ? LS.second : V;
9861
9862 Values.emplace_back(L, nullptr);
9863
9864 // Otherwise compute it.
9865 const SCEV *C = computeSCEVAtScope(V, L);
9866 for (auto &LS : reverse(ValuesAtScopes[V]))
9867 if (LS.first == L) {
9868 LS.second = C;
9869 if (!isa<SCEVConstant>(C))
9870 ValuesAtScopesUsers[C].push_back({L, V});
9871 break;
9872 }
9873 return C;
9874}
9875
9876/// This builds up a Constant using the ConstantExpr interface. That way, we
9877/// will return Constants for objects which aren't represented by a
9878/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9879/// Returns NULL if the SCEV isn't representable as a Constant.
9881 switch (V->getSCEVType()) {
9882 case scCouldNotCompute:
9883 case scAddRecExpr:
9884 case scVScale:
9885 return nullptr;
9886 case scConstant:
9887 return cast<SCEVConstant>(V)->getValue();
9888 case scUnknown:
9889 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9890 case scPtrToInt: {
9891 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9892 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9893 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9894
9895 return nullptr;
9896 }
9897 case scTruncate: {
9898 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9899 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9900 return ConstantExpr::getTrunc(CastOp, ST->getType());
9901 return nullptr;
9902 }
9903 case scAddExpr: {
9904 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9905 Constant *C = nullptr;
9906 for (const SCEV *Op : SA->operands()) {
9908 if (!OpC)
9909 return nullptr;
9910 if (!C) {
9911 C = OpC;
9912 continue;
9913 }
9914 assert(!C->getType()->isPointerTy() &&
9915 "Can only have one pointer, and it must be last");
9916 if (OpC->getType()->isPointerTy()) {
9917 // The offsets have been converted to bytes. We can add bytes using
9918 // an i8 GEP.
9920 OpC, C);
9921 } else {
9922 C = ConstantExpr::getAdd(C, OpC);
9923 }
9924 }
9925 return C;
9926 }
9927 case scMulExpr:
9928 case scSignExtend:
9929 case scZeroExtend:
9930 case scUDivExpr:
9931 case scSMaxExpr:
9932 case scUMaxExpr:
9933 case scSMinExpr:
9934 case scUMinExpr:
9936 return nullptr;
9937 }
9938 llvm_unreachable("Unknown SCEV kind!");
9939}
9940
9941const SCEV *
9942ScalarEvolution::getWithOperands(const SCEV *S,
9944 switch (S->getSCEVType()) {
9945 case scTruncate:
9946 case scZeroExtend:
9947 case scSignExtend:
9948 case scPtrToInt:
9949 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9950 case scAddRecExpr: {
9951 auto *AddRec = cast<SCEVAddRecExpr>(S);
9952 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9953 }
9954 case scAddExpr:
9955 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9956 case scMulExpr:
9957 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9958 case scUDivExpr:
9959 return getUDivExpr(NewOps[0], NewOps[1]);
9960 case scUMaxExpr:
9961 case scSMaxExpr:
9962 case scUMinExpr:
9963 case scSMinExpr:
9964 return getMinMaxExpr(S->getSCEVType(), NewOps);
9966 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9967 case scConstant:
9968 case scVScale:
9969 case scUnknown:
9970 return S;
9971 case scCouldNotCompute:
9972 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9973 }
9974 llvm_unreachable("Unknown SCEV kind!");
9975}
9976
9977const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9978 switch (V->getSCEVType()) {
9979 case scConstant:
9980 case scVScale:
9981 return V;
9982 case scAddRecExpr: {
9983 // If this is a loop recurrence for a loop that does not contain L, then we
9984 // are dealing with the final value computed by the loop.
9985 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9986 // First, attempt to evaluate each operand.
9987 // Avoid performing the look-up in the common case where the specified
9988 // expression has no loop-variant portions.
9989 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9990 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9991 if (OpAtScope == AddRec->getOperand(i))
9992 continue;
9993
9994 // Okay, at least one of these operands is loop variant but might be
9995 // foldable. Build a new instance of the folded commutative expression.
9997 NewOps.reserve(AddRec->getNumOperands());
9998 append_range(NewOps, AddRec->operands().take_front(i));
9999 NewOps.push_back(OpAtScope);
10000 for (++i; i != e; ++i)
10001 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10002
10003 const SCEV *FoldedRec = getAddRecExpr(
10004 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10005 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10006 // The addrec may be folded to a nonrecurrence, for example, if the
10007 // induction variable is multiplied by zero after constant folding. Go
10008 // ahead and return the folded value.
10009 if (!AddRec)
10010 return FoldedRec;
10011 break;
10012 }
10013
10014 // If the scope is outside the addrec's loop, evaluate it by using the
10015 // loop exit value of the addrec.
10016 if (!AddRec->getLoop()->contains(L)) {
10017 // To evaluate this recurrence, we need to know how many times the AddRec
10018 // loop iterates. Compute this now.
10019 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10020 if (BackedgeTakenCount == getCouldNotCompute())
10021 return AddRec;
10022
10023 // Then, evaluate the AddRec.
10024 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10025 }
10026
10027 return AddRec;
10028 }
10029 case scTruncate:
10030 case scZeroExtend:
10031 case scSignExtend:
10032 case scPtrToInt:
10033 case scAddExpr:
10034 case scMulExpr:
10035 case scUDivExpr:
10036 case scUMaxExpr:
10037 case scSMaxExpr:
10038 case scUMinExpr:
10039 case scSMinExpr:
10040 case scSequentialUMinExpr: {
10041 ArrayRef<const SCEV *> Ops = V->operands();
10042 // Avoid performing the look-up in the common case where the specified
10043 // expression has no loop-variant portions.
10044 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10045 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10046 if (OpAtScope != Ops[i]) {
10047 // Okay, at least one of these operands is loop variant but might be
10048 // foldable. Build a new instance of the folded commutative expression.
10050 NewOps.reserve(Ops.size());
10051 append_range(NewOps, Ops.take_front(i));
10052 NewOps.push_back(OpAtScope);
10053
10054 for (++i; i != e; ++i) {
10055 OpAtScope = getSCEVAtScope(Ops[i], L);
10056 NewOps.push_back(OpAtScope);
10057 }
10058
10059 return getWithOperands(V, NewOps);
10060 }
10061 }
10062 // If we got here, all operands are loop invariant.
10063 return V;
10064 }
10065 case scUnknown: {
10066 // If this instruction is evolved from a constant-evolving PHI, compute the
10067 // exit value from the loop without using SCEVs.
10068 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10069 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10070 if (!I)
10071 return V; // This is some other type of SCEVUnknown, just return it.
10072
10073 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10074 const Loop *CurrLoop = this->LI[I->getParent()];
10075 // Looking for loop exit value.
10076 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10077 PN->getParent() == CurrLoop->getHeader()) {
10078 // Okay, there is no closed form solution for the PHI node. Check
10079 // to see if the loop that contains it has a known backedge-taken
10080 // count. If so, we may be able to force computation of the exit
10081 // value.
10082 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10083 // This trivial case can show up in some degenerate cases where
10084 // the incoming IR has not yet been fully simplified.
10085 if (BackedgeTakenCount->isZero()) {
10086 Value *InitValue = nullptr;
10087 bool MultipleInitValues = false;
10088 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10089 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10090 if (!InitValue)
10091 InitValue = PN->getIncomingValue(i);
10092 else if (InitValue != PN->getIncomingValue(i)) {
10093 MultipleInitValues = true;
10094 break;
10095 }
10096 }
10097 }
10098 if (!MultipleInitValues && InitValue)
10099 return getSCEV(InitValue);
10100 }
10101 // Do we have a loop invariant value flowing around the backedge
10102 // for a loop which must execute the backedge?
10103 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10104 isKnownNonZero(BackedgeTakenCount) &&
10105 PN->getNumIncomingValues() == 2) {
10106
10107 unsigned InLoopPred =
10108 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10109 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10110 if (CurrLoop->isLoopInvariant(BackedgeVal))
10111 return getSCEV(BackedgeVal);
10112 }
10113 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10114 // Okay, we know how many times the containing loop executes. If
10115 // this is a constant evolving PHI node, get the final value at
10116 // the specified iteration number.
10117 Constant *RV =
10118 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10119 if (RV)
10120 return getSCEV(RV);
10121 }
10122 }
10123 }
10124
10125 // Okay, this is an expression that we cannot symbolically evaluate
10126 // into a SCEV. Check to see if it's possible to symbolically evaluate
10127 // the arguments into constants, and if so, try to constant propagate the
10128 // result. This is particularly useful for computing loop exit values.
10129 if (!CanConstantFold(I))
10130 return V; // This is some other type of SCEVUnknown, just return it.
10131
10133 Operands.reserve(I->getNumOperands());
10134 bool MadeImprovement = false;
10135 for (Value *Op : I->operands()) {
10136 if (Constant *C = dyn_cast<Constant>(Op)) {
10137 Operands.push_back(C);
10138 continue;
10139 }
10140
10141 // If any of the operands is non-constant and if they are
10142 // non-integer and non-pointer, don't even try to analyze them
10143 // with scev techniques.
10144 if (!isSCEVable(Op->getType()))
10145 return V;
10146
10147 const SCEV *OrigV = getSCEV(Op);
10148 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10149 MadeImprovement |= OrigV != OpV;
10150
10152 if (!C)
10153 return V;
10154 assert(C->getType() == Op->getType() && "Type mismatch");
10155 Operands.push_back(C);
10156 }
10157
10158 // Check to see if getSCEVAtScope actually made an improvement.
10159 if (!MadeImprovement)
10160 return V; // This is some other type of SCEVUnknown, just return it.
10161
10162 Constant *C = nullptr;
10163 const DataLayout &DL = getDataLayout();
10164 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10165 /*AllowNonDeterministic=*/false);
10166 if (!C)
10167 return V;
10168 return getSCEV(C);
10169 }
10170 case scCouldNotCompute:
10171 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10172 }
10173 llvm_unreachable("Unknown SCEV type!");
10174}
10175
10177 return getSCEVAtScope(getSCEV(V), L);
10178}
10179
10180const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10181 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10182 return stripInjectiveFunctions(ZExt->getOperand());
10183 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10184 return stripInjectiveFunctions(SExt->getOperand());
10185 return S;
10186}
10187
10188/// Finds the minimum unsigned root of the following equation:
10189///
10190/// A * X = B (mod N)
10191///
10192/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10193/// A and B isn't important.
10194///
10195/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10196static const SCEV *
10199
10200 ScalarEvolution &SE) {
10201 uint32_t BW = A.getBitWidth();
10202 assert(BW == SE.getTypeSizeInBits(B->getType()));
10203 assert(A != 0 && "A must be non-zero.");
10204
10205 // 1. D = gcd(A, N)
10206 //
10207 // The gcd of A and N may have only one prime factor: 2. The number of
10208 // trailing zeros in A is its multiplicity
10209 uint32_t Mult2 = A.countr_zero();
10210 // D = 2^Mult2
10211
10212 // 2. Check if B is divisible by D.
10213 //
10214 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10215 // is not less than multiplicity of this prime factor for D.
10216 if (SE.getMinTrailingZeros(B) < Mult2) {
10217 // Check if we can prove there's no remainder using URem.
10218 const SCEV *URem =
10219 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10220 const SCEV *Zero = SE.getZero(B->getType());
10221 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10222 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10223 if (!Predicates)
10224 return SE.getCouldNotCompute();
10225
10226 // Avoid adding a predicate that is known to be false.
10227 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10228 return SE.getCouldNotCompute();
10229 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10230 }
10231 }
10232
10233 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10234 // modulo (N / D).
10235 //
10236 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10237 // (N / D) in general. The inverse itself always fits into BW bits, though,
10238 // so we immediately truncate it.
10239 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10240 APInt I = AD.multiplicativeInverse().zext(BW);
10241
10242 // 4. Compute the minimum unsigned root of the equation:
10243 // I * (B / D) mod (N / D)
10244 // To simplify the computation, we factor out the divide by D:
10245 // (I * B mod N) / D
10246 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10247 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10248}
10249
10250/// For a given quadratic addrec, generate coefficients of the corresponding
10251/// quadratic equation, multiplied by a common value to ensure that they are
10252/// integers.
10253/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10254/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10255/// were multiplied by, and BitWidth is the bit width of the original addrec
10256/// coefficients.
10257/// This function returns std::nullopt if the addrec coefficients are not
/// compile-time constants.
10259static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10261 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10262 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10263 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10264 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10265 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10266 << *AddRec << '\n');
10267
10268 // We currently can only solve this if the coefficients are constants.
10269 if (!LC || !MC || !NC) {
10270 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10271 return std::nullopt;
10272 }
10273
10274 APInt L = LC->getAPInt();
10275 APInt M = MC->getAPInt();
10276 APInt N = NC->getAPInt();
10277 assert(!N.isZero() && "This is not a quadratic addrec");
10278
10279 unsigned BitWidth = LC->getAPInt().getBitWidth();
10280 unsigned NewWidth = BitWidth + 1;
10281 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10282 << BitWidth << '\n');
10283 // The sign-extension (as opposed to a zero-extension) here matches the
10284 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10285 N = N.sext(NewWidth);
10286 M = M.sext(NewWidth);
10287 L = L.sext(NewWidth);
10288
10289 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10290 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10291 // L+M, L+2M+N, L+3M+3N, ...
10292 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10293 //
10294 // The equation Acc = 0 is then
10295 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10296 // In a quadratic form it becomes:
10297 // N n^2 + (2M-N) n + 2L = 0.
10298
10299 APInt A = N;
10300 APInt B = 2 * M - A;
10301 APInt C = 2 * L;
10302 APInt T = APInt(NewWidth, 2);
10303 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10304 << "x + " << C << ", coeff bw: " << NewWidth
10305 << ", multiplied by " << T << '\n');
10306 return std::make_tuple(A, B, C, T, BitWidth);
10307}
10308
10309/// Helper function to compare optional APInts:
10310/// (a) if X and Y both exist, return min(X, Y),
10311/// (b) if neither X nor Y exist, return std::nullopt,
10312/// (c) if exactly one of X and Y exists, return that value.
10313static std::optional<APInt> MinOptional(std::optional<APInt> X,
10314 std::optional<APInt> Y) {
10315 if (X && Y) {
10316 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10317 APInt XW = X->sext(W);
10318 APInt YW = Y->sext(W);
10319 return XW.slt(YW) ? *X : *Y;
10320 }
10321 if (!X && !Y)
10322 return std::nullopt;
10323 return X ? *X : *Y;
10324}
10325
10326/// Helper function to truncate an optional APInt to a given BitWidth.
10327/// When solving addrec-related equations, it is preferable to return a value
10328/// that has the same bit width as the original addrec's coefficients. If the
10329/// solution fits in the original bit width, truncate it (except for i1).
10330/// Returning a value of a different bit width may inhibit some optimizations.
10331///
10332/// In general, a solution to a quadratic equation generated from an addrec
10333/// may require BW+1 bits, where BW is the bit width of the addrec's
10334/// coefficients. The reason is that the coefficients of the quadratic
10335/// equation are BW+1 bits wide (to avoid truncation when converting from
10336/// the addrec to the equation).
10337static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10338 unsigned BitWidth) {
10339 if (!X)
10340 return std::nullopt;
10341 unsigned W = X->getBitWidth();
10342 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10343 return X->trunc(BitWidth);
10344 return X;
10345}
10346
10347/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10348/// iterations. The values L, M, N are assumed to be signed, and they
10349/// should all have the same bit widths.
10350/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10351/// where BW is the bit width of the addrec's coefficients.
10352/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10353/// returned as such, otherwise the bit width of the returned value may
10354/// be greater than BW.
10355///
10356/// This function returns std::nullopt if
10357/// (a) the addrec coefficients are not constant, or
10358/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10359/// like x^2 = 5, no integer solutions exist, in other cases an integer
10360/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10361static std::optional<APInt>
10363 APInt A, B, C, M;
10364 unsigned BitWidth;
10365 auto T = GetQuadraticEquation(AddRec);
10366 if (!T)
10367 return std::nullopt;
10368
10369 std::tie(A, B, C, M, BitWidth) = *T;
10370 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10371 std::optional<APInt> X =
10373 if (!X)
10374 return std::nullopt;
10375
10376 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10377 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10378 if (!V->isZero())
10379 return std::nullopt;
10380
10381 return TruncIfPossible(X, BitWidth);
10382}
10383
10384/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10385/// iterations. The values M, N are assumed to be signed, and they
10386/// should all have the same bit widths.
10387/// Find the least n such that c(n) does not belong to the given range,
10388/// while c(n-1) does.
10389///
10390/// This function returns std::nullopt if
10391/// (a) the addrec coefficients are not constant, or
10392/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10393/// bounds of the range.
10394static std::optional<APInt>
10396 const ConstantRange &Range, ScalarEvolution &SE) {
10397 assert(AddRec->getOperand(0)->isZero() &&
10398 "Starting value of addrec should be 0");
10399 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10400 << Range << ", addrec " << *AddRec << '\n');
10401 // This case is handled in getNumIterationsInRange. Here we can assume that
10402 // we start in the range.
10403 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10404 "Addrec's initial value should be in range");
10405
10406 APInt A, B, C, M;
10407 unsigned BitWidth;
10408 auto T = GetQuadraticEquation(AddRec);
10409 if (!T)
10410 return std::nullopt;
10411
10412 // Be careful about the return value: there can be two reasons for not
10413 // returning an actual number. First, if no solutions to the equations
10414 // were found, and second, if the solutions don't leave the given range.
10415 // The first case means that the actual solution is "unknown", the second
10416 // means that it's known, but not valid. If the solution is unknown, we
10417 // cannot make any conclusions.
10418 // Return a pair: the optional solution and a flag indicating if the
10419 // solution was found.
10420 auto SolveForBoundary =
10421 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10422 // Solve for signed overflow and unsigned overflow, pick the lower
10423 // solution.
10424 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10425 << Bound << " (before multiplying by " << M << ")\n");
10426 Bound *= M; // The quadratic equation multiplier.
10427
10428 std::optional<APInt> SO;
10429 if (BitWidth > 1) {
10430 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10431 "signed overflow\n");
10433 }
10434 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10435 "unsigned overflow\n");
10436 std::optional<APInt> UO =
10438
10439 auto LeavesRange = [&] (const APInt &X) {
10440 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10441 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10442 if (Range.contains(V0->getValue()))
10443 return false;
10444 // X should be at least 1, so X-1 is non-negative.
10445 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10446 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10447 if (Range.contains(V1->getValue()))
10448 return true;
10449 return false;
10450 };
10451
10452 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10453 // can be a solution, but the function failed to find it. We cannot treat it
10454 // as "no solution".
10455 if (!SO || !UO)
10456 return {std::nullopt, false};
10457
10458 // Check the smaller value first to see if it leaves the range.
10459 // At this point, both SO and UO must have values.
10460 std::optional<APInt> Min = MinOptional(SO, UO);
10461 if (LeavesRange(*Min))
10462 return { Min, true };
10463 std::optional<APInt> Max = Min == SO ? UO : SO;
10464 if (LeavesRange(*Max))
10465 return { Max, true };
10466
10467 // Solutions were found, but were eliminated, hence the "true".
10468 return {std::nullopt, true};
10469 };
10470
10471 std::tie(A, B, C, M, BitWidth) = *T;
10472 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10473 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10474 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10475 auto SL = SolveForBoundary(Lower);
10476 auto SU = SolveForBoundary(Upper);
  // If any of the solutions was unknown, no meaningful conclusions can
10478 // be made.
10479 if (!SL.second || !SU.second)
10480 return std::nullopt;
10481
10482 // Claim: The correct solution is not some value between Min and Max.
10483 //
10484 // Justification: Assuming that Min and Max are different values, one of
10485 // them is when the first signed overflow happens, the other is when the
10486 // first unsigned overflow happens. Crossing the range boundary is only
10487 // possible via an overflow (treating 0 as a special case of it, modeling
10488 // an overflow as crossing k*2^W for some k).
10489 //
10490 // The interesting case here is when Min was eliminated as an invalid
10491 // solution, but Max was not. The argument is that if there was another
10492 // overflow between Min and Max, it would also have been eliminated if
10493 // it was considered.
10494 //
10495 // For a given boundary, it is possible to have two overflows of the same
10496 // type (signed/unsigned) without having the other type in between: this
10497 // can happen when the vertex of the parabola is between the iterations
10498 // corresponding to the overflows. This is only possible when the two
10499 // overflows cross k*2^W for the same k. In such case, if the second one
10500 // left the range (and was the first one to do so), the first overflow
10501 // would have to enter the range, which would mean that either we had left
10502 // the range before or that we started outside of it. Both of these cases
10503 // are contradictions.
10504 //
10505 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10506 // solution is not some value between the Max for this boundary and the
10507 // Min of the other boundary.
10508 //
10509 // Justification: Assume that we had such Max_A and Min_B corresponding
10510 // to range boundaries A and B and such that Max_A < Min_B. If there was
10511 // a solution between Max_A and Min_B, it would have to be caused by an
10512 // overflow corresponding to either A or B. It cannot correspond to B,
10513 // since Min_B is the first occurrence of such an overflow. If it
10514 // corresponded to A, it would have to be either a signed or an unsigned
10515 // overflow that is larger than both eliminated overflows for A. But
10516 // between the eliminated overflows and this overflow, the values would
10517 // cover the entire value space, thus crossing the other boundary, which
10518 // is a contradiction.
10519
10520 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10521}
10522
10523ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10524 const Loop *L,
10525 bool ControlsOnlyExit,
10526 bool AllowPredicates) {
10527
10528 // This is only used for loops with a "x != y" exit test. The exit condition
10529 // is now expressed as a single expression, V = x-y. So the exit test is
  // effectively V != 0. We know, and take advantage of the fact, that this
  // expression is only used in a comparison-against-zero context.
10532
10534 // If the value is a constant
10535 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10536 // If the value is already zero, the branch will execute zero times.
10537 if (C->getValue()->isZero()) return C;
10538 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10539 }
10540
10541 const SCEVAddRecExpr *AddRec =
10542 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10543
10544 if (!AddRec && AllowPredicates)
10545 // Try to make this an AddRec using runtime tests, in the first X
10546 // iterations of this loop, where X is the SCEV expression found by the
10547 // algorithm below.
10548 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10549
10550 if (!AddRec || AddRec->getLoop() != L)
10551 return getCouldNotCompute();
10552
10553 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10554 // the quadratic equation to solve it.
10555 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10556 // We can only use this value if the chrec ends up with an exact zero
10557 // value at this index. When solving for "X*X != 5", for example, we
10558 // should not accept a root of 2.
10559 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10560 const auto *R = cast<SCEVConstant>(getConstant(*S));
10561 return ExitLimit(R, R, R, false, Predicates);
10562 }
10563 return getCouldNotCompute();
10564 }
10565
10566 // Otherwise we can only handle this if it is affine.
10567 if (!AddRec->isAffine())
10568 return getCouldNotCompute();
10569
10570 // If this is an affine expression, the execution count of this branch is
10571 // the minimum unsigned root of the following equation:
10572 //
10573 // Start + Step*N = 0 (mod 2^BW)
10574 //
10575 // equivalent to:
10576 //
10577 // Step*N = -Start (mod 2^BW)
10578 //
10579 // where BW is the common bit width of Start and Step.
10580
10581 // Get the initial value for the loop.
10582 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10583 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10584
10585 if (!isLoopInvariant(Step, L))
10586 return getCouldNotCompute();
10587
10588 LoopGuards Guards = LoopGuards::collect(L, *this);
10589 // Specialize step for this loop so we get context sensitive facts below.
10590 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10591
10592 // For positive steps (counting up until unsigned overflow):
10593 // N = -Start/Step (as unsigned)
10594 // For negative steps (counting down to zero):
10595 // N = Start/-Step
10596 // First compute the unsigned distance from zero in the direction of Step.
10597 bool CountDown = isKnownNegative(StepWLG);
10598 if (!CountDown && !isKnownNonNegative(StepWLG))
10599 return getCouldNotCompute();
10600
10601 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10602 // Handle unitary steps, which cannot wraparound.
10603 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10604 // N = Distance (as unsigned)
10605
10606 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10607 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10608 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10609
10610 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10611 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10612 // case, and see if we can improve the bound.
10613 //
10614 // Explicitly handling this here is necessary because getUnsignedRange
10615 // isn't context-sensitive; it doesn't know that we only care about the
10616 // range inside the loop.
10617 const SCEV *Zero = getZero(Distance->getType());
10618 const SCEV *One = getOne(Distance->getType());
10619 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10620 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10621 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10622 // as "unsigned_max(Distance + 1) - 1".
10623 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10624 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10625 }
10626 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10627 Predicates);
10628 }
10629
10630 // If the condition controls loop exit (the loop exits only if the expression
10631 // is true) and the addition is no-wrap we can use unsigned divide to
10632 // compute the backedge count. In this case, the step may not divide the
10633 // distance, but we don't care because if the condition is "missed" the loop
10634 // will have undefined behavior due to wrapping.
10635 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10636 loopHasNoAbnormalExits(AddRec->getLoop())) {
10637
10638 // If the stride is zero, the loop must be infinite. In C++, most loops
10639 // are finite by assumption, in which case the step being zero implies
10640 // UB must execute if the loop is entered.
10641 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10642 return getCouldNotCompute();
10643
10644 const SCEV *Exact =
10645 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10646 const SCEV *ConstantMax = getCouldNotCompute();
10647 if (Exact != getCouldNotCompute()) {
10649 ConstantMax =
10651 }
10652 const SCEV *SymbolicMax =
10653 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10654 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10655 }
10656
10657 // Solve the general equation.
10658 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10659 if (!StepC || StepC->getValue()->isZero())
10660 return getCouldNotCompute();
10662 StepC->getAPInt(), getNegativeSCEV(Start),
10663 AllowPredicates ? &Predicates : nullptr, *this);
10664
10665 const SCEV *M = E;
10666 if (E != getCouldNotCompute()) {
10667 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10668 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10669 }
10670 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10671 return ExitLimit(E, M, S, false, Predicates);
10672}
10673
/// Compute the exit limit for a loop that keeps iterating while V == 0.
/// Only the constant case is handled: a non-zero constant gives zero
/// backedge-taken iterations. Everything else gives CouldNotCompute.
/// NOTE(review): the return-type line of this definition (original 10674) is
/// missing from this listing.
10675ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10676 // Loops that look like: while (X == 0) are very strange indeed. We don't
10677 // handle them yet except for the trivial case. This could be expanded in the
10678 // future as needed.
10679
10680 // If the value is a constant, check to see if it is known to be non-zero
10681 // already. If so, the backedge will execute zero times.
10682 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10683 if (!C->getValue()->isZero())
10684 return getZero(C->getType());
10685 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10686 }
10687
10688 // We could implement others, but I really doubt anyone writes loops like
10689 // this, and if they did, they would already be constant folded.
10690 return getCouldNotCompute();
10691}
10692
10693std::pair<const BasicBlock *, const BasicBlock *>
10694ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10695 const {
10696 // If the block has a unique predecessor, then there is no path from the
10697 // predecessor to the block that does not go through the direct edge
10698 // from the predecessor to the block.
10699 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10700 return {Pred, BB};
10701
10702 // A loop's header is defined to be a block that dominates the loop.
10703 // If the header has a unique predecessor outside the loop, it must be
10704 // a block that has exactly one successor that can reach the loop.
10705 if (const Loop *L = LI.getLoopFor(BB))
10706 return {L->getLoopPredecessor(), L->getHeader()};
10707
10708 return {nullptr, BB};
10709}
10710
10711/// SCEV structural equivalence is usually sufficient for testing whether two
10712/// expressions are equal, however for the purposes of looking for a condition
10713/// guarding a loop, it can be useful to be a little more general, since a
10714/// front-end may have replicated the controlling expression.
10715static bool HasSameValue(const SCEV *A, const SCEV *B) {
10716 // Quick check to see if they are the same SCEV.
10717 if (A == B) return true;
10718
10719 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10720 // Not all instructions that are "identical" compute the same value. For
10721 // instance, two distinct alloca instructions allocating the same type are
10722 // identical and do not read memory; but compute distinct values.
10723 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10724 };
10725
10726 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10727 // two different instructions with the same value. Check for this case.
10728 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10729 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10730 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10731 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10732 if (ComputesEqualValues(AI, BI))
10733 return true;
10734
10735 // Otherwise assume they may have a different value.
10736 return false;
10737}
10738
10739static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10740 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10741 if (!Add || Add->getNumOperands() != 2)
10742 return false;
10743 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10744 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10745 LHS = Add->getOperand(1);
10746 RHS = ME->getOperand(1);
10747 return true;
10748 }
10749 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10750 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10751 LHS = Add->getOperand(0);
10752 RHS = ME->getOperand(1);
10753 return true;
10754 }
10755 return false;
10756}
10757
/// Canonicalize the integer comparison "LHS Pred RHS" in place.
///
/// The function performs these rewrites:
///  - Folds comparisons that can be decided outright into the trivial forms
///    '0 == 0' / '0 != 0'. This covers two constants, a full or empty exact
///    constant range, and operands with the same value.
///  - Moves a constant operand to the right-hand side.
///  - Puts an AddRec on the left when the other operand is invariant in its
///    loop.
///  - Turns inequalities against a constant into equalities when the exact
///    range allows it.
///  - Folds "(-1 * a) + b ==/!= 0" into "a ==/!= b".
///  - Rewrites non-strict comparisons (<=, >=) into strict ones when adjusting
///    an operand by one cannot overflow.
/// After any change it recurses, up to a depth limit of 3.
///
/// Return value: `return Changed;` is only reached when Changed is false, and
/// every other path returns either TrivialCase(...) (always true) or the
/// recursive result. The function therefore returns true only when some
/// level folded the comparison to a trivial form. A plain operand or
/// predicate rewrite can still return false.
///
/// NOTE(review): this listing is missing the signature start (original
/// 10758) and lines 10764, 10782, 10793, 10806, 10871, 10881, 10886, 10894,
/// 10899, 10907 and 10923; see the inline notes.
10759 const SCEV *&RHS, unsigned Depth) {
10760 bool Changed = false;
10761 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10762 // '0 != 0'.
10763 auto TrivialCase = [&](bool TriviallyTrue) {
// NOTE(review): original line 10764 is missing. Per the comment above, it
// presumably sets both operands to zero.
10765 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10766 return true;
10767 };
10768 // If we hit the max recursion limit bail out.
10769 if (Depth >= 3)
10770 return false;
10771
10772 // Canonicalize a constant to the right side.
10773 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10774 // Check for both operands constant.
10775 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10776 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10777 return TrivialCase(false);
10778 return TrivialCase(true);
10779 }
10780 // Otherwise swap the operands to put the constant on the right.
10781 std::swap(LHS, RHS);
// NOTE(review): original line 10782 is missing. Swapping LHS and RHS is only
// sound if Pred is swapped too, which presumably happened on that line.
10783 Changed = true;
10784 }
10785
10786 // If we're comparing an addrec with a value which is loop-invariant in the
10787 // addrec's loop, put the addrec on the left. Also make a dominance check,
10788 // as both operands could be addrecs loop-invariant in each other's loop.
10789 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10790 const Loop *L = AR->getLoop();
10791 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10792 std::swap(LHS, RHS);
// NOTE(review): original line 10793 is missing (presumably the matching
// predicate swap).
10794 Changed = true;
10795 }
10796 }
10797
10798 // If there's a constant operand, canonicalize comparisons with boundary
10799 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10800 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10801 const APInt &RA = RC->getAPInt();
10802
10803 bool SimplifiedByConstantRange = false;
10804
10805 if (!ICmpInst::isEquality(Pred)) {
// NOTE(review): original line 10806 is missing. It presumably declares
// `ExactCR`, the exact range of LHS values satisfying "LHS Pred RA".
10807 if (ExactCR.isFullSet())
10808 return TrivialCase(true);
10809 if (ExactCR.isEmptySet())
10810 return TrivialCase(false);
10811
10812 APInt NewRHS;
10813 CmpInst::Predicate NewPred;
10814 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10815 ICmpInst::isEquality(NewPred)) {
10816 // We were able to convert an inequality to an equality.
10817 Pred = NewPred;
10818 RHS = getConstant(NewRHS);
10819 Changed = SimplifiedByConstantRange = true;
10820 }
10821 }
10822
10823 if (!SimplifiedByConstantRange) {
10824 switch (Pred) {
10825 default:
10826 break;
10827 case ICmpInst::ICMP_EQ:
10828 case ICmpInst::ICMP_NE:
10829 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10830 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10831 Changed = true;
10832 break;
10833
10834 // The "Should have been caught earlier!" messages refer to the fact
10835 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10836 // should have fired on the corresponding cases, and canonicalized the
10837 // check to trivial case.
10838
10839 case ICmpInst::ICMP_UGE:
10840 assert(!RA.isMinValue() && "Should have been caught earlier!");
10841 Pred = ICmpInst::ICMP_UGT;
10842 RHS = getConstant(RA - 1);
10843 Changed = true;
10844 break;
10845 case ICmpInst::ICMP_ULE:
10846 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10847 Pred = ICmpInst::ICMP_ULT;
10848 RHS = getConstant(RA + 1);
10849 Changed = true;
10850 break;
10851 case ICmpInst::ICMP_SGE:
10852 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10853 Pred = ICmpInst::ICMP_SGT;
10854 RHS = getConstant(RA - 1);
10855 Changed = true;
10856 break;
10857 case ICmpInst::ICMP_SLE:
10858 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10859 Pred = ICmpInst::ICMP_SLT;
10860 RHS = getConstant(RA + 1);
10861 Changed = true;
10862 break;
10863 }
10864 }
10865 }
10866
10867 // Check for obvious equality.
10868 if (HasSameValue(LHS, RHS)) {
10869 if (ICmpInst::isTrueWhenEqual(Pred))
10870 return TrivialCase(true);
// NOTE(review): original line 10871 is missing (presumably a guard on the
// false-when-equal case).
10872 return TrivialCase(false);
10873 }
10874
10875 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10876 // adding or subtracting 1 from one of the operands.
// Each rewrite is taken only when the operand's range shows that the +/-1
// cannot wrap at the relevant signed or unsigned boundary.
// NOTE(review): the continuation lines of several getAddExpr calls below
// (original 10881, 10886, 10894, 10899, 10907, 10923) are missing, so those
// calls are unterminated in this listing.
10877 switch (Pred) {
10878 case ICmpInst::ICMP_SLE:
10879 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10880 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10882 Pred = ICmpInst::ICMP_SLT;
10883 Changed = true;
10884 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10885 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10887 Pred = ICmpInst::ICMP_SLT;
10888 Changed = true;
10889 }
10890 break;
10891 case ICmpInst::ICMP_SGE:
10892 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10893 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10895 Pred = ICmpInst::ICMP_SGT;
10896 Changed = true;
10897 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10898 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10900 Pred = ICmpInst::ICMP_SGT;
10901 Changed = true;
10902 }
10903 break;
10904 case ICmpInst::ICMP_ULE:
10905 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10906 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10908 Pred = ICmpInst::ICMP_ULT;
10909 Changed = true;
10910 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10911 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10912 Pred = ICmpInst::ICMP_ULT;
10913 Changed = true;
10914 }
10915 break;
10916 case ICmpInst::ICMP_UGE:
10917 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10918 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10919 Pred = ICmpInst::ICMP_UGT;
10920 Changed = true;
10921 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10922 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10924 Pred = ICmpInst::ICMP_UGT;
10925 Changed = true;
10926 }
10927 break;
10928 default:
10929 break;
10930 }
10931
10932 // TODO: More simplifications are possible here.
10933
10934 // Recursively simplify until we either hit a recursion limit or nothing
10935 // changes.
10936 if (Changed)
10937 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10938
// Reached only when nothing changed at this level, so this always returns
// false.
10939 return Changed;
10940}
10941
// NOTE(review): the following are bodies of several small range-based SCEV
// predicates. Their signature lines (original 10942, 10946-10947, 10950,
// 10954-10955, 10958) are missing from this listing, and two of the bodies
// are missing entirely.
//
// S is negative on every path iff even its signed maximum is negative.
10943 return getSignedRangeMax(S).isNegative();
10944}
10945
// NOTE(review): the body of this predicate is missing from the listing.
10948}
10949
// S is non-negative iff its signed minimum is not negative.
10951 return !getSignedRangeMin(S).isNegative();
10952}
10953
// NOTE(review): the body of this predicate is missing from the listing.
10956}
10957
// S is non-zero iff its unsigned minimum is non-zero. A sign extension is
// zero exactly when its operand is, so the query is forwarded to the operand,
// whose range may be more precise.
10959 // Query push down for cases where the unsigned range is
10960 // less than sufficient.
10961 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10962 return isKnownNonZero(SExt->getOperand(0));
10963 return getUnsignedRangeMin(S) != 0;
10964}
10965
/// Return true if S is known to be a power of two. \p OrNegative also
/// accepts a constant whose negation is a power of two. \p OrZero accepts
/// zero, which matters only for the multiply case below.
/// NOTE(review): the first line of this signature (original 10966) is missing
/// from this listing.
10967 bool OrNegative) {
// Leaf test: a constant (negated) power of two, or vscale when the function
// carries a vscale_range attribute.
10968 auto NonRecursive = [this, OrNegative](const SCEV *S) {
10969 if (auto *C = dyn_cast<SCEVConstant>(S))
10970 return C->getAPInt().isPowerOf2() ||
10971 (OrNegative && C->getAPInt().isNegatedPowerOf2());
10972
10973 // The vscale_range indicates vscale is a power-of-two.
10974 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
10975 };
10976
10977 if (NonRecursive(S))
10978 return true;
10979
// A product of powers of two is again a power of two, except that it may wrap
// to zero. Hence the product must be known non-zero unless OrZero is set.
10980 auto *Mul = dyn_cast<SCEVMulExpr>(S);
10981 if (!Mul)
10982 return false;
10983 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
10984}
10985
/// Split S into its value on entry to loop L and its post-increment value
/// for L. If the entry value cannot be computed, CouldNotCompute is returned
/// in both slots.
/// NOTE(review): the line naming this function (original 10987) is missing
/// from this listing.
10986std::pair<const SCEV *, const SCEV *>
10988 // Compute SCEV on entry of loop L.
10989 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10990 if (Start == getCouldNotCompute())
10991 return { Start, Start };
10992 // Compute post increment SCEV for loop L.
10993 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10994 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10995 return { Start, PostInc };
10996}
10997
/// Prove "LHS Pred RHS" by induction over the most deeply dominated loop
/// used by either operand. The predicate must hold for the loop-entry values
/// (base case) and for the post-increment values on the backedge (step).
/// NOTE(review): the signature line (original 10998) and the declaration of
/// `LoopsUsed` (original 11001) are missing from this listing.
10999 const SCEV *RHS) {
11000 // First collect all loops.
11002 getUsedLoops(LHS, LoopsUsed);
11003 getUsedLoops(RHS, LoopsUsed);
11004
11005 if (LoopsUsed.empty())
11006 return false;
11007
11008 // Domination relationship must be a linear order on collected loops.
11009#ifndef NDEBUG
11010 for (const auto *L1 : LoopsUsed)
11011 for (const auto *L2 : LoopsUsed)
11012 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11013 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11014 "Domination relationship is not a linear order");
11015#endif
11016
// MDL: the maximum under "header properly dominates", i.e. the loop whose
// header is dominated by the headers of all other collected loops.
11017 const Loop *MDL =
11018 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11019 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11020 });
11021
11022 // Get init and post increment value for LHS.
11023 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11024 // if LHS contains unknown non-invariant SCEV then bail out.
11025 if (SplitLHS.first == getCouldNotCompute())
11026 return false;
11027 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11028 // Get init and post increment value for RHS.
11029 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11030 // if RHS contains unknown non-invariant SCEV then bail out.
11031 if (SplitRHS.first == getCouldNotCompute())
11032 return false;
11033 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11034 // It is possible that init SCEV contains an invariant load but it does
11035 // not dominate MDL and is not available at MDL loop entry, so we should
11036 // check it here.
11037 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11038 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11039 return false;
11040
11041 // It seems backedge guard check is faster than entry one so in some cases
11042 // it can speed up whole estimation by short circuit
11043 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11044 SplitRHS.second) &&
11045 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11046}
11047
/// Return true if "LHS Pred RHS" is known to hold. The operands are first
/// canonicalized, and the return value of that step is intentionally
/// ignored. The check then tries induction, splitting and simple
/// non-recursive reasoning, in that order.
/// NOTE(review): the signature line (original 11048) is missing from this
/// listing.
11049 const SCEV *RHS) {
11050 // Canonicalize the inputs first.
11051 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11052
11053 if (isKnownViaInduction(Pred, LHS, RHS))
11054 return true;
11055
11056 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11057 return true;
11058
11059 // Otherwise see what can be done with some simple reasoning.
11060 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11061}
11062
/// Evaluate "LHS Pred RHS" without context: true if known to hold, false if
/// known not to hold, std::nullopt if unknown.
/// NOTE(review): the signature line (original 11063) and line 11068 are
/// missing from this listing. As shown, `return std::nullopt;` is
/// unreachable. The dropped line presumably guards `return false;`, for
/// example by checking the inverse predicate; confirm against the upstream
/// source.
11064 const SCEV *LHS,
11065 const SCEV *RHS) {
11066 if (isKnownPredicate(Pred, LHS, RHS))
11067 return true;
11069 return false;
11070 return std::nullopt;
11071}
11072
/// Return true if "LHS Pred RHS" is known to hold at \p CtxI, either
/// unconditionally or through a second check that uses the context.
/// NOTE(review): the signature line (original 11073) and the second operand
/// of `||` (original 11078) are missing from this listing, so the return
/// statement is incomplete as shown.
11074 const SCEV *RHS,
11075 const Instruction *CtxI) {
11076 // TODO: Analyze guards and assumes from Context's block.
11077 return isKnownPredicate(Pred, LHS, RHS) ||
11079}
11080
/// Like evaluatePredicate, but also uses conditions that guard entry to
/// CtxI's basic block. Returns true/false when decided and std::nullopt
/// otherwise.
/// NOTE(review): the signature line (original 11082) and lines 11090-11091
/// are missing from this listing. As shown, `return std::nullopt;` is
/// unreachable. The dropped lines presumably guard `return false;`.
11081std::optional<bool>
11083 const SCEV *RHS, const Instruction *CtxI) {
11084 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11085 if (KnownWithoutContext)
11086 return KnownWithoutContext;
11087
11088 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11089 return true;
11092 return false;
11093 return std::nullopt;
11094}
11095
/// Induction argument for an AddRec: "LHS Pred RHS" holds on every iteration
/// if it holds for the start value on loop entry and for the post-increment
/// value on the backedge.
/// NOTE(review): the signature line (original 11096) is missing from this
/// listing.
11097 const SCEVAddRecExpr *LHS,
11098 const SCEV *RHS) {
11099 const Loop *L = LHS->getLoop();
11100 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11101 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11102}
11103
/// Classify how "LHS Pred X" changes as LHS's loop iterates (monotonically
/// increasing or decreasing), or return std::nullopt if this cannot be
/// determined. Debug builds also check that the swapped predicate yields the
/// opposite classification.
/// NOTE(review): the line naming this function and its first parameter
/// (original 11105) is missing from this listing.
11104std::optional<ScalarEvolution::MonotonicPredicateType>
11106 ICmpInst::Predicate Pred) {
11107 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11108
11109#ifndef NDEBUG
11110 // Verify an invariant: swapping the predicate should turn a monotonically
11111 // increasing change to a monotonically decreasing one, and vice versa.
11112 if (Result) {
11113 auto ResultSwapped =
11114 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11115
// NOTE(review): ResultSwapped is dereferenced without checking has_value().
// This relies on the swapped predicate passing the same relational, signedness
// and no-wrap checks in the Impl, which presumably always holds.
11116 assert(*ResultSwapped != *Result &&
11117 "monotonicity should flip as we flip the predicate");
11118 }
11119#endif
11120
11121 return Result;
11122}
11123
/// Implementation of the monotonicity classification. The predicate must be
/// relational, and the AddRec must carry the matching no-wrap flag (nuw for
/// unsigned, nsw for signed). The direction then follows from the sign of
/// the step.
/// NOTE(review): original lines 11149, 11159 and 11162 are missing from this
/// listing. They are presumably the return statements of the unsigned branch
/// and of the two step-sign checks. As shown, the `if` statements at
/// 11158/11161 have no body of their own.
11124std::optional<ScalarEvolution::MonotonicPredicateType>
11125ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11126 ICmpInst::Predicate Pred) {
11127 // A zero step value for LHS means the induction variable is essentially a
11128 // loop invariant value. We don't really depend on the predicate actually
11129 // flipping from false to true (for increasing predicates, and the other way
11130 // around for decreasing predicates), all we care about is that *if* the
11131 // predicate changes then it only changes from false to true.
11132 //
11133 // A zero step value in itself is not very useful, but there may be places
11134 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11135 // as general as possible.
11136
11137 // Only handle LE/LT/GE/GT predicates.
11138 if (!ICmpInst::isRelational(Pred))
11139 return std::nullopt;
11140
11141 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11142 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11143 "Should be greater or less!");
11144
11145 // Check that LHS does not wrap.
11146 if (ICmpInst::isUnsigned(Pred)) {
11147 if (!LHS->hasNoUnsignedWrap())
11148 return std::nullopt;
11150 }
11151 assert(ICmpInst::isSigned(Pred) &&
11152 "Relational predicate is either signed or unsigned!");
11153 if (!LHS->hasNoSignedWrap())
11154 return std::nullopt;
11155
11156 const SCEV *Step = LHS->getStepRecurrence(*this);
11157
11158 if (isKnownNonNegative(Step))
11160
11161 if (isKnownNonPositive(Step))
11163
// The step's sign is unknown, so no monotonicity can be claimed.
11164 return std::nullopt;
11165}
11166
/// Try to replace the loop-variant comparison "LHS Pred RHS" in loop L with
/// an equivalent loop-invariant one, returned as a LoopInvariantPredicate.
/// It uses the monotonicity of an AddRec operand. When \p CtxI is given, it
/// also uses facts known at that instruction for unsigned </<= comparisons.
/// Returns std::nullopt if neither applies.
/// NOTE(review): original lines 11168, 11205, 11208, 11209, 11243 and 11245
/// are missing from this listing; see the inline notes.
11167std::optional<ScalarEvolution::LoopInvariantPredicate>
11169 const SCEV *LHS, const SCEV *RHS,
11170 const Loop *L,
11171 const Instruction *CtxI) {
11172 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11173 if (!isLoopInvariant(RHS, L)) {
11174 if (!isLoopInvariant(LHS, L))
11175 return std::nullopt;
11176
11177 std::swap(LHS, RHS);
11178 Pred = ICmpInst::getSwappedPredicate(Pred);
11179 }
11180
11181 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11182 if (!ArLHS || ArLHS->getLoop() != L)
11183 return std::nullopt;
11184
11185 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11186 if (!MonotonicType)
11187 return std::nullopt;
11188 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11189 // true as the loop iterates, and the backedge is control dependent on
11190 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11191 //
11192 // * if the predicate was false in the first iteration then the predicate
11193 // is never evaluated again, since the loop exits without taking the
11194 // backedge.
11195 // * if the predicate was true in the first iteration then it will
11196 // continue to be true for all future iterations since it is
11197 // monotonically increasing.
11198 //
11199 // For both the above possibilities, we can replace the loop varying
11200 // predicate with its value on the first iteration of the loop (which is
11201 // loop invariant).
11202 //
11203 // A similar reasoning applies for a monotonically decreasing predicate, by
11204 // replacing true with false and false with true in the above two bullets.
// NOTE(review): original line 11205 is missing; `Increasing` below is not
// declared in this listing.
11206 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11207
// NOTE(review): original lines 11208-11209 are missing. They presumably check
// P on loop entry and begin the `return` that this orphaned argument ends.
11210 RHS);
11211
11212 if (!CtxI)
11213 return std::nullopt;
11214 // Try to prove via context.
11215 // TODO: Support other cases.
11216 switch (Pred) {
11217 default:
11218 break;
11219 case ICmpInst::ICMP_ULE:
11220 case ICmpInst::ICMP_ULT: {
11221 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11222 // Given preconditions
11223 // (1) ArLHS does not cross the border of positive and negative parts of
11224 // range because of:
11225 // - Positive step; (TODO: lift this limitation)
11226 // - nuw - does not cross zero boundary;
11227 // - nsw - does not cross SINT_MAX boundary;
11228 // (2) ArLHS <s RHS
11229 // (3) RHS >=s 0
11230 // we can replace the loop variant ArLHS <u RHS condition with loop
11231 // invariant Start(ArLHS) <u RHS.
11232 //
11233 // Because of (1) there are two options:
11234 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11235 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11236 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11237 // Because of (2) ArLHS <u RHS is trivially true.
11238 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11239 // We can strengthen this to Start(ArLHS) <u RHS.
11240 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
// NOTE(review): original line 11243 is missing from the condition below.
// Precondition (3), RHS >=s 0, is not among the visible conjuncts and is
// presumably the dropped one. Line 11245, which starts the return statement,
// is also missing.
11241 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11242 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11244 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11246 RHS);
11247 }
11248 }
11249
11250 return std::nullopt;
11251}
11252
/// Find a loop-invariant form of "LHS Pred RHS" that is valid during the
/// first \p MaxIter iterations of L. If the direct attempt fails and MaxIter
/// is a umin, each of its operands is tried in turn.
/// NOTE(review): the signature line (original 11254) and the lines naming the
/// callee of both calls (original 11257 and 11267) are missing from this
/// listing. The callee is presumably the ...Impl helper.
11253std::optional<ScalarEvolution::LoopInvariantPredicate>
11255 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11256 const Instruction *CtxI, const SCEV *MaxIter) {
11258 Pred, LHS, RHS, L, CtxI, MaxIter))
11259 return LIP;
11260 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11261 // Number of iterations expressed as UMIN isn't always great for expressing
11262 // the value on the last iteration. If the straightforward approach didn't
11263 // work, try the following trick: if a predicate is invariant for X, it
11264 // is also invariant for umin(X, ...). So try to find something that works
11265 // among subexpressions of MaxIter expressed as umin.
11266 for (auto *Op : UMin->operands())
11268 Pred, LHS, RHS, L, CtxI, Op))
11269 return LIP;
11270 return std::nullopt;
11271}
11272
/// For an AddRec of loop L with step +1 or -1, show that "LHS Pred RHS" can
/// be replaced by "Start Pred RHS" during the first \p MaxIter iterations.
/// Two facts are required:
///  - the predicate holds on the backedge for the value at iteration MaxIter;
///  - Start and that last value are ordered without wrapping (checked at
///    \p CtxI).
/// NOTE(review): original lines 11274 (signature), 11291 and 11326 are
/// missing from this listing; see the inline notes.
11273std::optional<ScalarEvolution::LoopInvariantPredicate>
11275 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11276 const Instruction *CtxI, const SCEV *MaxIter) {
11277 // Try to prove the following set of facts:
11278 // - The predicate is monotonic in the iteration space.
11279 // - If the check does not fail on the 1st iteration:
11280 // - No overflow will happen during first MaxIter iterations;
11281 // - It will not fail on the MaxIter'th iteration.
11282 // If the check does fail on the 1st iteration, we leave the loop and no
11283 // other checks matter.
11284
11285 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11286 if (!isLoopInvariant(RHS, L)) {
11287 if (!isLoopInvariant(LHS, L))
11288 return std::nullopt;
11289
11290 std::swap(LHS, RHS);
// NOTE(review): original line 11291 is missing. As in the sibling function
// above, the swap must be paired with swapping Pred.
11292 }
11293
11294 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11295 if (!AR || AR->getLoop() != L)
11296 return std::nullopt;
11297
11298 // The predicate must be relational (i.e. <, <=, >=, >).
11299 if (!ICmpInst::isRelational(Pred))
11300 return std::nullopt;
11301
11302 // TODO: Support steps other than +/- 1.
11303 const SCEV *Step = AR->getStepRecurrence(*this);
11304 auto *One = getOne(Step->getType());
11305 auto *MinusOne = getNegativeSCEV(One);
11306 if (Step != One && Step != MinusOne)
11307 return std::nullopt;
11308
11309 // Type mismatch here means that MaxIter is potentially larger than max
11310 // unsigned value in start type, which means we cannot prove no wrap for the
11311 // indvar.
11312 if (AR->getType() != MaxIter->getType())
11313 return std::nullopt;
11314
11315 // Value of IV on suggested last iteration.
11316 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11317 // Does it still meet the requirement?
11318 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11319 return std::nullopt;
11320 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11321 // not exceed max unsigned value of this type), this effectively proves
11322 // that there is no wrap during the iteration. To prove that there is no
11323 // signed/unsigned wrap, we need to check that
11324 // Start <= Last for step = 1 or Start >= Last for step = -1.
// NOTE(review): original line 11326, the initializer of NoOverflowPred, is
// missing from this listing.
11325 ICmpInst::Predicate NoOverflowPred =
11327 if (Step == MinusOne)
11328 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11329 const SCEV *Start = AR->getStart();
11330 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11331 return std::nullopt;
11332
11333 // Everything is fine.
11334 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11335}
11336
11337bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11338 const SCEV *LHS,
11339 const SCEV *RHS) {
11340 if (HasSameValue(LHS, RHS))
11341 return ICmpInst::isTrueWhenEqual(Pred);
11342
11343 // This code is split out from isKnownPredicate because it is called from
11344 // within isLoopEntryGuardedByCond.
11345
11346 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11347 const ConstantRange &RangeRHS) {
11348 return RangeLHS.icmp(Pred, RangeRHS);
11349 };
11350
11351 // The check at the top of the function catches the case where the values are
11352 // known to be equal.
11353 if (Pred == CmpInst::ICMP_EQ)
11354 return false;
11355
11356 if (Pred == CmpInst::ICMP_NE) {
11357 auto SL = getSignedRange(LHS);
11358 auto SR = getSignedRange(RHS);
11359 if (CheckRanges(SL, SR))
11360 return true;
11361 auto UL = getUnsignedRange(LHS);
11362 auto UR = getUnsignedRange(RHS);
11363 if (CheckRanges(UL, UR))
11364 return true;
11365 auto *Diff = getMinusSCEV(LHS, RHS);
11366 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11367 }
11368
11369 if (CmpInst::isSigned(Pred)) {
11370 auto SL = getSignedRange(LHS);
11371 auto SR = getSignedRange(RHS);
11372 return CheckRanges(SL, SR);
11373 }
11374
11375 auto UL = getUnsignedRange(LHS);
11376 auto UR = getUnsignedRange(RHS);
11377 return CheckRanges(UL, UR);
11378}
11379
11380bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11381 const SCEV *LHS,
11382 const SCEV *RHS) {
11383 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11384 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11385 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11386 // OutC1 and OutC2.
11387 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11388 APInt &OutC1, APInt &OutC2,
11389 SCEV::NoWrapFlags ExpectedFlags) {
11390 const SCEV *XNonConstOp, *XConstOp;
11391 const SCEV *YNonConstOp, *YConstOp;
11392 SCEV::NoWrapFlags XFlagsPresent;
11393 SCEV::NoWrapFlags YFlagsPresent;
11394
11395 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11396 XConstOp = getZero(X->getType());
11397 XNonConstOp = X;
11398 XFlagsPresent = ExpectedFlags;
11399 }
11400 if (!isa<SCEVConstant>(XConstOp) ||
11401 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11402 return false;
11403
11404 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11405 YConstOp = getZero(Y->getType());
11406 YNonConstOp = Y;
11407 YFlagsPresent = ExpectedFlags;
11408 }
11409
11410 if (!isa<SCEVConstant>(YConstOp) ||
11411 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11412 return false;
11413
11414 if (YNonConstOp != XNonConstOp)
11415 return false;
11416
11417 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11418 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11419
11420 return true;
11421 };
11422
11423 APInt C1;
11424 APInt C2;
11425
11426 switch (Pred) {
11427 default:
11428 break;
11429
11430 case ICmpInst::ICMP_SGE:
11431 std::swap(LHS, RHS);
11432 [[fallthrough]];
11433 case ICmpInst::ICMP_SLE:
11434 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11435 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11436 return true;
11437
11438 break;
11439
11440 case ICmpInst::ICMP_SGT:
11441 std::swap(LHS, RHS);
11442 [[fallthrough]];
11443 case ICmpInst::ICMP_SLT:
11444 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11445 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11446 return true;
11447
11448 break;
11449
11450 case ICmpInst::ICMP_UGE:
11451 std::swap(LHS, RHS);
11452 [[fallthrough]];
11453 case ICmpInst::ICMP_ULE:
11454 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11455 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11456 return true;
11457
11458 break;
11459
11460 case ICmpInst::ICMP_UGT:
11461 std::swap(LHS, RHS);
11462 [[fallthrough]];
11463 case ICmpInst::ICMP_ULT:
11464 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11465 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11466 return true;
11467 break;
11468 }
11469
11470 return false;
11471}
11472
11473bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11474 const SCEV *LHS,
11475 const SCEV *RHS) {
11476 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11477 return false;
11478
11479 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11480 // the stack can result in exponential time complexity.
11481 SaveAndRestore Restore(ProvingSplitPredicate, true);
11482
11483 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11484 //
11485 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11486 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11487 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11488 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11489 // use isKnownPredicate later if needed.
11490 return isKnownNonNegative(RHS) &&
11493}
11494
11495bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11496 const SCEV *LHS, const SCEV *RHS) {
11497 // No need to even try if we know the module has no guards.
11498 if (!HasGuards)
11499 return false;
11500
11501 return any_of(*BB, [&](const Instruction &I) {
11502 using namespace llvm::PatternMatch;
11503
11504 Value *Condition;
11505 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11506 m_Value(Condition))) &&
11507 isImpliedCond(Pred, LHS, RHS, Condition, false);
11508 });
11509}
11510
11511/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11512/// protected by a conditional between LHS and RHS. This is used to
11513/// to eliminate casts.
11515 CmpPredicate Pred,
11516 const SCEV *LHS,
11517 const SCEV *RHS) {
11518 // Interpret a null as meaning no loop, where there is obviously no guard
11519 // (interprocedural conditions notwithstanding). Do not bother about
11520 // unreachable loops.
11521 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11522 return true;
11523
11524 if (VerifyIR)
11525 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11526 "This cannot be done on broken IR!");
11527
11528
11529 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11530 return true;
11531
11532 BasicBlock *Latch = L->getLoopLatch();
11533 if (!Latch)
11534 return false;
11535
11536 BranchInst *LoopContinuePredicate =
11537 dyn_cast<BranchInst>(Latch->getTerminator());
11538 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11539 isImpliedCond(Pred, LHS, RHS,
11540 LoopContinuePredicate->getCondition(),
11541 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11542 return true;
11543
11544 // We don't want more than one activation of the following loops on the stack
11545 // -- that can lead to O(n!) time complexity.
11546 if (WalkingBEDominatingConds)
11547 return false;
11548
11549 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11550
11551 // See if we can exploit a trip count to prove the predicate.
11552 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11553 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11554 if (LatchBECount != getCouldNotCompute()) {
11555 // We know that Latch branches back to the loop header exactly
11556 // LatchBECount times. This means the backdege condition at Latch is
11557 // equivalent to "{0,+,1} u< LatchBECount".
11558 Type *Ty = LatchBECount->getType();
11559 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11560 const SCEV *LoopCounter =
11561 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11562 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11563 LatchBECount))
11564 return true;
11565 }
11566
11567 // Check conditions due to any @llvm.assume intrinsics.
11568 for (auto &AssumeVH : AC.assumptions()) {
11569 if (!AssumeVH)
11570 continue;
11571 auto *CI = cast<CallInst>(AssumeVH);
11572 if (!DT.dominates(CI, Latch->getTerminator()))
11573 continue;
11574
11575 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11576 return true;
11577 }
11578
11579 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11580 return true;
11581
11582 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11583 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11584 assert(DTN && "should reach the loop header before reaching the root!");
11585
11586 BasicBlock *BB = DTN->getBlock();
11587 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11588 return true;
11589
11590 BasicBlock *PBB = BB->getSinglePredecessor();
11591 if (!PBB)
11592 continue;
11593
11594 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11595 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11596 continue;
11597
11598 Value *Condition = ContinuePredicate->getCondition();
11599
11600 // If we have an edge `E` within the loop body that dominates the only
11601 // latch, the condition guarding `E` also guards the backedge. This
11602 // reasoning works only for loops with a single latch.
11603
11604 BasicBlockEdge DominatingEdge(PBB, BB);
11605 if (DominatingEdge.isSingleEdge()) {
11606 // We're constructively (and conservatively) enumerating edges within the
11607 // loop body that dominate the latch. The dominator tree better agree
11608 // with us on this:
11609 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11610
11611 if (isImpliedCond(Pred, LHS, RHS, Condition,
11612 BB != ContinuePredicate->getSuccessor(0)))
11613 return true;
11614 }
11615 }
11616
11617 return false;
11618}
11619
11621 CmpPredicate Pred,
11622 const SCEV *LHS,
11623 const SCEV *RHS) {
11624 // Do not bother proving facts for unreachable code.
11625 if (!DT.isReachableFromEntry(BB))
11626 return true;
11627 if (VerifyIR)
11628 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11629 "This cannot be done on broken IR!");
11630
11631 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11632 // the facts (a >= b && a != b) separately. A typical situation is when the
11633 // non-strict comparison is known from ranges and non-equality is known from
11634 // dominating predicates. If we are proving strict comparison, we always try
11635 // to prove non-equality and non-strict comparison separately.
11636 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11637 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11638 bool ProvedNonStrictComparison = false;
11639 bool ProvedNonEquality = false;
11640
11641 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11642 if (!ProvedNonStrictComparison)
11643 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11644 if (!ProvedNonEquality)
11645 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11646 if (ProvedNonStrictComparison && ProvedNonEquality)
11647 return true;
11648 return false;
11649 };
11650
11651 if (ProvingStrictComparison) {
11652 auto ProofFn = [&](CmpPredicate P) {
11653 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11654 };
11655 if (SplitAndProve(ProofFn))
11656 return true;
11657 }
11658
11659 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11660 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11661 const Instruction *CtxI = &BB->front();
11662 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11663 return true;
11664 if (ProvingStrictComparison) {
11665 auto ProofFn = [&](CmpPredicate P) {
11666 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11667 };
11668 if (SplitAndProve(ProofFn))
11669 return true;
11670 }
11671 return false;
11672 };
11673
11674 // Starting at the block's predecessor, climb up the predecessor chain, as long
11675 // as there are predecessors that can be found that have unique successors
11676 // leading to the original block.
11677 const Loop *ContainingLoop = LI.getLoopFor(BB);
11678 const BasicBlock *PredBB;
11679 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11680 PredBB = ContainingLoop->getLoopPredecessor();
11681 else
11682 PredBB = BB->getSinglePredecessor();
11683 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11684 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11685 const BranchInst *BlockEntryPredicate =
11686 dyn_cast<BranchInst>(Pair.first->getTerminator());
11687 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11688 continue;
11689
11690 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11691 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11692 return true;
11693 }
11694
11695 // Check conditions due to any @llvm.assume intrinsics.
11696 for (auto &AssumeVH : AC.assumptions()) {
11697 if (!AssumeVH)
11698 continue;
11699 auto *CI = cast<CallInst>(AssumeVH);
11700 if (!DT.dominates(CI, BB))
11701 continue;
11702
11703 if (ProveViaCond(CI->getArgOperand(0), false))
11704 return true;
11705 }
11706
11707 // Check conditions due to any @llvm.experimental.guard intrinsics.
11708 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11709 F.getParent(), Intrinsic::experimental_guard);
11710 if (GuardDecl)
11711 for (const auto *GU : GuardDecl->users())
11712 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11713 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11714 if (ProveViaCond(Guard->getArgOperand(0), false))
11715 return true;
11716 return false;
11717}
11718
11720 const SCEV *LHS,
11721 const SCEV *RHS) {
11722 // Interpret a null as meaning no loop, where there is obviously no guard
11723 // (interprocedural conditions notwithstanding).
11724 if (!L)
11725 return false;
11726
11727 // Both LHS and RHS must be available at loop entry.
11729 "LHS is not available at Loop Entry");
11731 "RHS is not available at Loop Entry");
11732
11733 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11734 return true;
11735
11736 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11737}
11738
11739bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11740 const SCEV *RHS,
11741 const Value *FoundCondValue, bool Inverse,
11742 const Instruction *CtxI) {
11743 // False conditions implies anything. Do not bother analyzing it further.
11744 if (FoundCondValue ==
11745 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11746 return true;
11747
11748 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11749 return false;
11750
11751 auto ClearOnExit =
11752 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11753
11754 // Recursively handle And and Or conditions.
11755 const Value *Op0, *Op1;
11756 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11757 if (!Inverse)
11758 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11759 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11760 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11761 if (Inverse)
11762 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11763 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11764 }
11765
11766 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11767 if (!ICI) return false;
11768
11769 // Now that we found a conditional branch that dominates the loop or controls
11770 // the loop latch. Check to see if it is the comparison we are looking for.
11771 CmpPredicate FoundPred;
11772 if (Inverse)
11773 FoundPred = ICI->getInverseCmpPredicate();
11774 else
11775 FoundPred = ICI->getCmpPredicate();
11776
11777 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11778 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11779
11780 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11781}
11782
11783bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11784 const SCEV *RHS, CmpPredicate FoundPred,
11785 const SCEV *FoundLHS, const SCEV *FoundRHS,
11786 const Instruction *CtxI) {
11787 // Balance the types.
11788 if (getTypeSizeInBits(LHS->getType()) <
11789 getTypeSizeInBits(FoundLHS->getType())) {
11790 // For unsigned and equality predicates, try to prove that both found
11791 // operands fit into narrow unsigned range. If so, try to prove facts in
11792 // narrow types.
11793 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11794 !FoundRHS->getType()->isPointerTy()) {
11795 auto *NarrowType = LHS->getType();
11796 auto *WideType = FoundLHS->getType();
11797 auto BitWidth = getTypeSizeInBits(NarrowType);
11798 const SCEV *MaxValue = getZeroExtendExpr(
11800 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11801 MaxValue) &&
11802 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11803 MaxValue)) {
11804 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11805 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11806 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11807 TruncFoundRHS, CtxI))
11808 return true;
11809 }
11810 }
11811
11812 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11813 return false;
11814 if (CmpInst::isSigned(Pred)) {
11815 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11816 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11817 } else {
11818 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11819 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11820 }
11821 } else if (getTypeSizeInBits(LHS->getType()) >
11822 getTypeSizeInBits(FoundLHS->getType())) {
11823 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11824 return false;
11825 if (CmpInst::isSigned(FoundPred)) {
11826 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11827 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11828 } else {
11829 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11830 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11831 }
11832 }
11833 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11834 FoundRHS, CtxI);
11835}
11836
11837bool ScalarEvolution::isImpliedCondBalancedTypes(
11838 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11839 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11841 getTypeSizeInBits(FoundLHS->getType()) &&
11842 "Types should be balanced!");
11843 // Canonicalize the query to match the way instcombine will have
11844 // canonicalized the comparison.
11845 if (SimplifyICmpOperands(Pred, LHS, RHS))
11846 if (LHS == RHS)
11847 return CmpInst::isTrueWhenEqual(Pred);
11848 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11849 if (FoundLHS == FoundRHS)
11850 return CmpInst::isFalseWhenEqual(FoundPred);
11851
11852 // Check to see if we can make the LHS or RHS match.
11853 if (LHS == FoundRHS || RHS == FoundLHS) {
11854 if (isa<SCEVConstant>(RHS)) {
11855 std::swap(FoundLHS, FoundRHS);
11856 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11857 } else {
11858 std::swap(LHS, RHS);
11860 }
11861 }
11862
11863 // Check whether the found predicate is the same as the desired predicate.
11864 // FIXME: use CmpPredicate::getMatching here.
11865 if (FoundPred == static_cast<CmpInst::Predicate>(Pred))
11866 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11867
11868 // Check whether swapping the found predicate makes it the same as the
11869 // desired predicate.
11870 // FIXME: use CmpPredicate::getMatching here.
11871 if (ICmpInst::getSwappedCmpPredicate(FoundPred) ==
11872 static_cast<CmpInst::Predicate>(Pred)) {
11873 // We can write the implication
11874 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11875 // using one of the following ways:
11876 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11877 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11878 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11879 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11880 // Forms 1. and 2. require swapping the operands of one condition. Don't
11881 // do this if it would break canonical constant/addrec ordering.
11882 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11883 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11884 CtxI);
11885 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11886 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11887
11888 // There's no clear preference between forms 3. and 4., try both. Avoid
11889 // forming getNotSCEV of pointer values as the resulting subtract is
11890 // not legal.
11891 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11892 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11893 FoundLHS, FoundRHS, CtxI))
11894 return true;
11895
11896 if (!FoundLHS->getType()->isPointerTy() &&
11897 !FoundRHS->getType()->isPointerTy() &&
11898 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11899 getNotSCEV(FoundRHS), CtxI))
11900 return true;
11901
11902 return false;
11903 }
11904
11905 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11906 CmpInst::Predicate P2) {
11907 assert(P1 != P2 && "Handled earlier!");
11908 return CmpInst::isRelational(P2) &&
11910 };
11911 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11912 // Unsigned comparison is the same as signed comparison when both the
11913 // operands are non-negative or negative.
11914 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11915 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11916 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11917 // Create local copies that we can freely swap and canonicalize our
11918 // conditions to "le/lt".
11919 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11920 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11921 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11922 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11923 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
11924 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
11925 std::swap(CanonicalLHS, CanonicalRHS);
11926 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11927 }
11928 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11929 "Must be!");
11930 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11931 ICmpInst::isLE(CanonicalFoundPred)) &&
11932 "Must be!");
11933 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11934 // Use implication:
11935 // x <u y && y >=s 0 --> x <s y.
11936 // If we can prove the left part, the right part is also proven.
11937 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11938 CanonicalRHS, CanonicalFoundLHS,
11939 CanonicalFoundRHS);
11940 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11941 // Use implication:
11942 // x <s y && y <s 0 --> x <u y.
11943 // If we can prove the left part, the right part is also proven.
11944 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11945 CanonicalRHS, CanonicalFoundLHS,
11946 CanonicalFoundRHS);
11947 }
11948
11949 // Check if we can make progress by sharpening ranges.
11950 if (FoundPred == ICmpInst::ICMP_NE &&
11951 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11952
11953 const SCEVConstant *C = nullptr;
11954 const SCEV *V = nullptr;
11955
11956 if (isa<SCEVConstant>(FoundLHS)) {
11957 C = cast<SCEVConstant>(FoundLHS);
11958 V = FoundRHS;
11959 } else {
11960 C = cast<SCEVConstant>(FoundRHS);
11961 V = FoundLHS;
11962 }
11963
11964 // The guarding predicate tells us that C != V. If the known range
11965 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11966 // range we consider has to correspond to same signedness as the
11967 // predicate we're interested in folding.
11968
11969 APInt Min = ICmpInst::isSigned(Pred) ?
11971
11972 if (Min == C->getAPInt()) {
11973 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11974 // This is true even if (Min + 1) wraps around -- in case of
11975 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11976
11977 APInt SharperMin = Min + 1;
11978
11979 switch (Pred) {
11980 case ICmpInst::ICMP_SGE:
11981 case ICmpInst::ICMP_UGE:
11982 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11983 // RHS, we're done.
11984 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11985 CtxI))
11986 return true;
11987 [[fallthrough]];
11988
11989 case ICmpInst::ICMP_SGT:
11990 case ICmpInst::ICMP_UGT:
11991 // We know from the range information that (V `Pred` Min ||
11992 // V == Min). We know from the guarding condition that !(V
11993 // == Min). This gives us
11994 //
11995 // V `Pred` Min || V == Min && !(V == Min)
11996 // => V `Pred` Min
11997 //
11998 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11999
12000 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12001 return true;
12002 break;
12003
12004 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12005 case ICmpInst::ICMP_SLE:
12006 case ICmpInst::ICMP_ULE:
12007 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12008 LHS, V, getConstant(SharperMin), CtxI))
12009 return true;
12010 [[fallthrough]];
12011
12012 case ICmpInst::ICMP_SLT:
12013 case ICmpInst::ICMP_ULT:
12014 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12015 LHS, V, getConstant(Min), CtxI))
12016 return true;
12017 break;
12018
12019 default:
12020 // No change
12021 break;
12022 }
12023 }
12024 }
12025
12026 // Check whether the actual condition is beyond sufficient.
12027 if (FoundPred == ICmpInst::ICMP_EQ)
12028 if (ICmpInst::isTrueWhenEqual(Pred))
12029 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12030 return true;
12031 if (Pred == ICmpInst::ICMP_NE)
12032 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12033 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12034 return true;
12035
12036 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12037 return true;
12038
12039 // Otherwise assume the worst.
12040 return false;
12041}
12042
12043bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12044 const SCEV *&L, const SCEV *&R,
12045 SCEV::NoWrapFlags &Flags) {
12046 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12047 if (!AE || AE->getNumOperands() != 2)
12048 return false;
12049
12050 L = AE->getOperand(0);
12051 R = AE->getOperand(1);
12052 Flags = AE->getNoWrapFlags();
12053 return true;
12054}
12055
12056std::optional<APInt>
12058 // We avoid subtracting expressions here because this function is usually
12059 // fairly deep in the call stack (i.e. is called many times).
12060
12061 unsigned BW = getTypeSizeInBits(More->getType());
12062 APInt Diff(BW, 0);
12063 APInt DiffMul(BW, 1);
12064 // Try various simplifications to reduce the difference to a constant. Limit
12065 // the number of allowed simplifications to keep compile-time low.
12066 for (unsigned I = 0; I < 8; ++I) {
12067 if (More == Less)
12068 return Diff;
12069
12070 // Reduce addrecs with identical steps to their start value.
12071 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12072 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12073 const auto *MAR = cast<SCEVAddRecExpr>(More);
12074
12075 if (LAR->getLoop() != MAR->getLoop())
12076 return std::nullopt;
12077
12078 // We look at affine expressions only; not for correctness but to keep
12079 // getStepRecurrence cheap.
12080 if (!LAR->isAffine() || !MAR->isAffine())
12081 return std::nullopt;
12082
12083 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12084 return std::nullopt;
12085
12086 Less = LAR->getStart();
12087 More = MAR->getStart();
12088 continue;
12089 }
12090
12091 // Try to match a common constant multiply.
12092 auto MatchConstMul =
12093 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12094 auto *M = dyn_cast<SCEVMulExpr>(S);
12095 if (!M || M->getNumOperands() != 2 ||
12096 !isa<SCEVConstant>(M->getOperand(0)))
12097 return std::nullopt;
12098 return {
12099 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12100 };
12101 if (auto MatchedMore = MatchConstMul(More)) {
12102 if (auto MatchedLess = MatchConstMul(Less)) {
12103 if (MatchedMore->second == MatchedLess->second) {
12104 More = MatchedMore->first;
12105 Less = MatchedLess->first;
12106 DiffMul *= MatchedMore->second;
12107 continue;
12108 }
12109 }
12110 }
12111
12112 // Try to cancel out common factors in two add expressions.
12114 auto Add = [&](const SCEV *S, int Mul) {
12115 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12116 if (Mul == 1) {
12117 Diff += C->getAPInt() * DiffMul;
12118 } else {
12119 assert(Mul == -1);
12120 Diff -= C->getAPInt() * DiffMul;
12121 }
12122 } else
12123 Multiplicity[S] += Mul;
12124 };
12125 auto Decompose = [&](const SCEV *S, int Mul) {
12126 if (isa<SCEVAddExpr>(S)) {
12127 for (const SCEV *Op : S->operands())
12128 Add(Op, Mul);
12129 } else
12130 Add(S, Mul);
12131 };
12132 Decompose(More, 1);
12133 Decompose(Less, -1);
12134
12135 // Check whether all the non-constants cancel out, or reduce to new
12136 // More/Less values.
12137 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12138 for (const auto &[S, Mul] : Multiplicity) {
12139 if (Mul == 0)
12140 continue;
12141 if (Mul == 1) {
12142 if (NewMore)
12143 return std::nullopt;
12144 NewMore = S;
12145 } else if (Mul == -1) {
12146 if (NewLess)
12147 return std::nullopt;
12148 NewLess = S;
12149 } else
12150 return std::nullopt;
12151 }
12152
12153 // Values stayed the same, no point in trying further.
12154 if (NewMore == More || NewLess == Less)
12155 return std::nullopt;
12156
12157 More = NewMore;
12158 Less = NewLess;
12159
12160 // Reduced to constant.
12161 if (!More && !Less)
12162 return Diff;
12163
12164 // Left with variable on only one side, bail out.
12165 if (!More || !Less)
12166 return std::nullopt;
12167 }
12168
12169 // Did not reduce to constant.
12170 return std::nullopt;
12171}
12172
12173bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12174 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12175 const SCEV *FoundRHS, const Instruction *CtxI) {
12176 // Try to recognize the following pattern:
12177 //
12178 // FoundRHS = ...
12179 // ...
12180 // loop:
12181 // FoundLHS = {Start,+,W}
12182 // context_bb: // Basic block from the same loop
12183 // known(Pred, FoundLHS, FoundRHS)
12184 //
12185 // If some predicate is known in the context of a loop, it is also known on
12186 // each iteration of this loop, including the first iteration. Therefore, in
12187 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12188 // prove the original pred using this fact.
12189 if (!CtxI)
12190 return false;
12191 const BasicBlock *ContextBB = CtxI->getParent();
12192 // Make sure AR varies in the context block.
12193 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12194 const Loop *L = AR->getLoop();
12195 // Make sure that context belongs to the loop and executes on 1st iteration
12196 // (if it ever executes at all).
12197 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12198 return false;
12199 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12200 return false;
12201 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12202 }
12203
12204 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12205 const Loop *L = AR->getLoop();
12206 // Make sure that context belongs to the loop and executes on 1st iteration
12207 // (if it ever executes at all).
12208 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12209 return false;
12210 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12211 return false;
12212 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12213 }
12214
12215 return false;
12216}
12217
12218bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12219 const SCEV *LHS,
12220 const SCEV *RHS,
12221 const SCEV *FoundLHS,
12222 const SCEV *FoundRHS) {
12223 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12224 return false;
12225
12226 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12227 if (!AddRecLHS)
12228 return false;
12229
12230 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12231 if (!AddRecFoundLHS)
12232 return false;
12233
12234 // We'd like to let SCEV reason about control dependencies, so we constrain
12235 // both the inequalities to be about add recurrences on the same loop. This
12236 // way we can use isLoopEntryGuardedByCond later.
12237
12238 const Loop *L = AddRecFoundLHS->getLoop();
12239 if (L != AddRecLHS->getLoop())
12240 return false;
12241
12242 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12243 //
12244 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12245 // ... (2)
12246 //
12247 // Informal proof for (2), assuming (1) [*]:
12248 //
12249 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12250 //
12251 // Then
12252 //
12253 // FoundLHS s< FoundRHS s< INT_MIN - C
12254 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12255 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12256 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12257 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12258 // <=> FoundLHS + C s< FoundRHS + C
12259 //
12260 // [*]: (1) can be proved by ruling out overflow.
12261 //
12262 // [**]: This can be proved by analyzing all the four possibilities:
12263 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12264 // (A s>= 0, B s>= 0).
12265 //
12266 // Note:
12267 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12268 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12269 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12270 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12271 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12272 // C)".
12273
12274 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12275 if (!LDiff)
12276 return false;
12277 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12278 if (!RDiff || *LDiff != *RDiff)
12279 return false;
12280
12281 if (LDiff->isMinValue())
12282 return true;
12283
12284 APInt FoundRHSLimit;
12285
12286 if (Pred == CmpInst::ICMP_ULT) {
12287 FoundRHSLimit = -(*RDiff);
12288 } else {
12289 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12290 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12291 }
12292
12293 // Try to prove (1) or (2), as needed.
12294 return isAvailableAtLoopEntry(FoundRHS, L) &&
12295 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12296 getConstant(FoundRHSLimit));
12297}
12298
12299bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12300 const SCEV *RHS, const SCEV *FoundLHS,
12301 const SCEV *FoundRHS, unsigned Depth) {
12302 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12303
12304 auto ClearOnExit = make_scope_exit([&]() {
12305 if (LPhi) {
12306 bool Erased = PendingMerges.erase(LPhi);
12307 assert(Erased && "Failed to erase LPhi!");
12308 (void)Erased;
12309 }
12310 if (RPhi) {
12311 bool Erased = PendingMerges.erase(RPhi);
12312 assert(Erased && "Failed to erase RPhi!");
12313 (void)Erased;
12314 }
12315 });
12316
12317 // Find respective Phis and check that they are not being pending.
12318 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12319 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12320 if (!PendingMerges.insert(Phi).second)
12321 return false;
12322 LPhi = Phi;
12323 }
12324 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12325 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12326 // If we detect a loop of Phi nodes being processed by this method, for
12327 // example:
12328 //
12329 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12330 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12331 //
12332 // we don't want to deal with a case that complex, so return conservative
12333 // answer false.
12334 if (!PendingMerges.insert(Phi).second)
12335 return false;
12336 RPhi = Phi;
12337 }
12338
12339 // If none of LHS, RHS is a Phi, nothing to do here.
12340 if (!LPhi && !RPhi)
12341 return false;
12342
12343 // If there is a SCEVUnknown Phi we are interested in, make it left.
12344 if (!LPhi) {
12345 std::swap(LHS, RHS);
12346 std::swap(FoundLHS, FoundRHS);
12347 std::swap(LPhi, RPhi);
12349 }
12350
12351 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12352 const BasicBlock *LBB = LPhi->getParent();
12353 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12354
12355 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12356 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12357 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12358 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12359 };
12360
12361 if (RPhi && RPhi->getParent() == LBB) {
12362 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12363 // If we compare two Phis from the same block, and for each entry block
12364 // the predicate is true for incoming values from this block, then the
12365 // predicate is also true for the Phis.
12366 for (const BasicBlock *IncBB : predecessors(LBB)) {
12367 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12368 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12369 if (!ProvedEasily(L, R))
12370 return false;
12371 }
12372 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12373 // Case two: RHS is also a Phi from the same basic block, and it is an
12374 // AddRec. It means that there is a loop which has both AddRec and Unknown
12375 // PHIs, for it we can compare incoming values of AddRec from above the loop
12376 // and latch with their respective incoming values of LPhi.
12377 // TODO: Generalize to handle loops with many inputs in a header.
12378 if (LPhi->getNumIncomingValues() != 2) return false;
12379
12380 auto *RLoop = RAR->getLoop();
12381 auto *Predecessor = RLoop->getLoopPredecessor();
12382 assert(Predecessor && "Loop with AddRec with no predecessor?");
12383 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12384 if (!ProvedEasily(L1, RAR->getStart()))
12385 return false;
12386 auto *Latch = RLoop->getLoopLatch();
12387 assert(Latch && "Loop with AddRec with no latch?");
12388 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12389 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12390 return false;
12391 } else {
12392 // In all other cases go over inputs of LHS and compare each of them to RHS,
12393 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12394 // At this point RHS is either a non-Phi, or it is a Phi from some block
12395 // different from LBB.
12396 for (const BasicBlock *IncBB : predecessors(LBB)) {
12397 // Check that RHS is available in this block.
12398 if (!dominates(RHS, IncBB))
12399 return false;
12400 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12401 // Make sure L does not refer to a value from a potentially previous
12402 // iteration of a loop.
12403 if (!properlyDominates(L, LBB))
12404 return false;
12405 if (!ProvedEasily(L, RHS))
12406 return false;
12407 }
12408 }
12409 return true;
12410}
12411
12412bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12413 const SCEV *LHS,
12414 const SCEV *RHS,
12415 const SCEV *FoundLHS,
12416 const SCEV *FoundRHS) {
12417 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12418 // sure that we are dealing with same LHS.
12419 if (RHS == FoundRHS) {
12420 std::swap(LHS, RHS);
12421 std::swap(FoundLHS, FoundRHS);
12423 }
12424 if (LHS != FoundLHS)
12425 return false;
12426
12427 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12428 if (!SUFoundRHS)
12429 return false;
12430
12431 Value *Shiftee, *ShiftValue;
12432
12433 using namespace PatternMatch;
12434 if (match(SUFoundRHS->getValue(),
12435 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12436 auto *ShifteeS = getSCEV(Shiftee);
12437 // Prove one of the following:
12438 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12439 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12440 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12441 // ---> LHS <s RHS
12442 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12443 // ---> LHS <=s RHS
12444 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12445 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12446 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12447 if (isKnownNonNegative(ShifteeS))
12448 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12449 }
12450
12451 return false;
12452}
12453
12454bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12455 const SCEV *RHS,
12456 const SCEV *FoundLHS,
12457 const SCEV *FoundRHS,
12458 const Instruction *CtxI) {
12459 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12460 return true;
12461
12462 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12463 return true;
12464
12465 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12466 return true;
12467
12468 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12469 CtxI))
12470 return true;
12471
12472 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12473 FoundLHS, FoundRHS);
12474}
12475
12476/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12477template <typename MinMaxExprType>
12478static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12479 const SCEV *Candidate) {
12480 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12481 if (!MinMaxExpr)
12482 return false;
12483
12484 return is_contained(MinMaxExpr->operands(), Candidate);
12485}
12486
12488 CmpPredicate Pred, const SCEV *LHS,
12489 const SCEV *RHS) {
12490 // If both sides are affine addrecs for the same loop, with equal
12491 // steps, and we know the recurrences don't wrap, then we only
12492 // need to check the predicate on the starting values.
12493
12494 if (!ICmpInst::isRelational(Pred))
12495 return false;
12496
12497 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12498 if (!LAR)
12499 return false;
12500 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12501 if (!RAR)
12502 return false;
12503 if (LAR->getLoop() != RAR->getLoop())
12504 return false;
12505 if (!LAR->isAffine() || !RAR->isAffine())
12506 return false;
12507
12508 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12509 return false;
12510
12513 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12514 return false;
12515
12516 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12517}
12518
12519/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12520/// expression?
12522 const SCEV *LHS, const SCEV *RHS) {
12523 switch (Pred) {
12524 default:
12525 return false;
12526
12527 case ICmpInst::ICMP_SGE:
12528 std::swap(LHS, RHS);
12529 [[fallthrough]];
12530 case ICmpInst::ICMP_SLE:
12531 return
12532 // min(A, ...) <= A
12533 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12534 // A <= max(A, ...)
12535 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12536
12537 case ICmpInst::ICMP_UGE:
12538 std::swap(LHS, RHS);
12539 [[fallthrough]];
12540 case ICmpInst::ICMP_ULE:
12541 return
12542 // min(A, ...) <= A
12543 // FIXME: what about umin_seq?
12544 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12545 // A <= max(A, ...)
12546 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12547 }
12548
12549 llvm_unreachable("covered switch fell through?!");
12550}
12551
12552bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12553 const SCEV *RHS,
12554 const SCEV *FoundLHS,
12555 const SCEV *FoundRHS,
12556 unsigned Depth) {
12559 "LHS and RHS have different sizes?");
12560 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12561 getTypeSizeInBits(FoundRHS->getType()) &&
12562 "FoundLHS and FoundRHS have different sizes?");
12563 // We want to avoid hurting the compile time with analysis of too big trees.
12565 return false;
12566
12567 // We only want to work with GT comparison so far.
12568 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12570 std::swap(LHS, RHS);
12571 std::swap(FoundLHS, FoundRHS);
12572 }
12573
12574 // For unsigned, try to reduce it to corresponding signed comparison.
12575 if (Pred == ICmpInst::ICMP_UGT)
12576 // We can replace unsigned predicate with its signed counterpart if all
12577 // involved values are non-negative.
12578 // TODO: We could have better support for unsigned.
12579 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12580 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12581 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12582 // use this fact to prove that LHS and RHS are non-negative.
12583 const SCEV *MinusOne = getMinusOne(LHS->getType());
12584 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12585 FoundRHS) &&
12586 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12587 FoundRHS))
12588 Pred = ICmpInst::ICMP_SGT;
12589 }
12590
12591 if (Pred != ICmpInst::ICMP_SGT)
12592 return false;
12593
12594 auto GetOpFromSExt = [&](const SCEV *S) {
12595 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12596 return Ext->getOperand();
12597 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12598 // the constant in some cases.
12599 return S;
12600 };
12601
12602 // Acquire values from extensions.
12603 auto *OrigLHS = LHS;
12604 auto *OrigFoundLHS = FoundLHS;
12605 LHS = GetOpFromSExt(LHS);
12606 FoundLHS = GetOpFromSExt(FoundLHS);
12607
12608 // Is the SGT predicate can be proved trivially or using the found context.
12609 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12610 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12611 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12612 FoundRHS, Depth + 1);
12613 };
12614
12615 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12616 // We want to avoid creation of any new non-constant SCEV. Since we are
12617 // going to compare the operands to RHS, we should be certain that we don't
12618 // need any size extensions for this. So let's decline all cases when the
12619 // sizes of types of LHS and RHS do not match.
12620 // TODO: Maybe try to get RHS from sext to catch more cases?
12622 return false;
12623
12624 // Should not overflow.
12625 if (!LHSAddExpr->hasNoSignedWrap())
12626 return false;
12627
12628 auto *LL = LHSAddExpr->getOperand(0);
12629 auto *LR = LHSAddExpr->getOperand(1);
12630 auto *MinusOne = getMinusOne(RHS->getType());
12631
12632 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12633 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12634 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12635 };
12636 // Try to prove the following rule:
12637 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12638 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12639 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12640 return true;
12641 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12642 Value *LL, *LR;
12643 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12644
12645 using namespace llvm::PatternMatch;
12646
12647 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12648 // Rules for division.
12649 // We are going to perform some comparisons with Denominator and its
12650 // derivative expressions. In general case, creating a SCEV for it may
12651 // lead to a complex analysis of the entire graph, and in particular it
12652 // can request trip count recalculation for the same loop. This would
12653 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12654 // this, we only want to create SCEVs that are constants in this section.
12655 // So we bail if Denominator is not a constant.
12656 if (!isa<ConstantInt>(LR))
12657 return false;
12658
12659 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12660
12661 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12662 // then a SCEV for the numerator already exists and matches with FoundLHS.
12663 auto *Numerator = getExistingSCEV(LL);
12664 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12665 return false;
12666
12667 // Make sure that the numerator matches with FoundLHS and the denominator
12668 // is positive.
12669 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12670 return false;
12671
12672 auto *DTy = Denominator->getType();
12673 auto *FRHSTy = FoundRHS->getType();
12674 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12675 // One of types is a pointer and another one is not. We cannot extend
12676 // them properly to a wider type, so let us just reject this case.
12677 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12678 // to avoid this check.
12679 return false;
12680
12681 // Given that:
12682 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12683 auto *WTy = getWiderType(DTy, FRHSTy);
12684 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12685 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12686
12687 // Try to prove the following rule:
12688 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12689 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12690 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12691 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12692 if (isKnownNonPositive(RHS) &&
12693 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12694 return true;
12695
12696 // Try to prove the following rule:
12697 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12698 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12699 // If we divide it by Denominator > 2, then:
12700 // 1. If FoundLHS is negative, then the result is 0.
12701 // 2. If FoundLHS is non-negative, then the result is non-negative.
12702 // Anyways, the result is non-negative.
12703 auto *MinusOne = getMinusOne(WTy);
12704 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12705 if (isKnownNegative(RHS) &&
12706 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12707 return true;
12708 }
12709 }
12710
12711 // If our expression contained SCEVUnknown Phis, and we split it down and now
12712 // need to prove something for them, try to prove the predicate for every
12713 // possible incoming values of those Phis.
12714 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12715 return true;
12716
12717 return false;
12718}
12719
12720static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12721 const SCEV *RHS) {
12722 // zext x u<= sext x, sext x s<= zext x
12723 const SCEV *Op;
12724 switch (Pred) {
12725 case ICmpInst::ICMP_SGE:
12726 std::swap(LHS, RHS);
12727 [[fallthrough]];
12728 case ICmpInst::ICMP_SLE: {
12729 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12730 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12732 }
12733 case ICmpInst::ICMP_UGE:
12734 std::swap(LHS, RHS);
12735 [[fallthrough]];
12736 case ICmpInst::ICMP_ULE: {
12737 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12738 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12740 }
12741 default:
12742 return false;
12743 };
12744 llvm_unreachable("unhandled case");
12745}
12746
12747bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12748 const SCEV *LHS,
12749 const SCEV *RHS) {
12750 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12751 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12752 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12753 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12754 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12755}
12756
12757bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12758 const SCEV *LHS,
12759 const SCEV *RHS,
12760 const SCEV *FoundLHS,
12761 const SCEV *FoundRHS) {
12762 switch (Pred) {
12763 default:
12764 llvm_unreachable("Unexpected CmpPredicate value!");
12765 case ICmpInst::ICMP_EQ:
12766 case ICmpInst::ICMP_NE:
12767 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12768 return true;
12769 break;
12770 case ICmpInst::ICMP_SLT:
12771 case ICmpInst::ICMP_SLE:
12772 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12773 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12774 return true;
12775 break;
12776 case ICmpInst::ICMP_SGT:
12777 case ICmpInst::ICMP_SGE:
12778 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12779 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12780 return true;
12781 break;
12782 case ICmpInst::ICMP_ULT:
12783 case ICmpInst::ICMP_ULE:
12784 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12785 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12786 return true;
12787 break;
12788 case ICmpInst::ICMP_UGT:
12789 case ICmpInst::ICMP_UGE:
12790 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12791 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12792 return true;
12793 break;
12794 }
12795
12796 // Maybe it can be proved via operations?
12797 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12798 return true;
12799
12800 return false;
12801}
12802
12803bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12804 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12805 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12806 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12807 // The restriction on `FoundRHS` be lifted easily -- it exists only to
12808 // reduce the compile time impact of this optimization.
12809 return false;
12810
12811 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12812 if (!Addend)
12813 return false;
12814
12815 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12816
12817 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12818 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12819 ConstantRange FoundLHSRange =
12820 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12821
12822 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12823 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12824
12825 // We can also compute the range of values for `LHS` that satisfy the
12826 // consequent, "`LHS` `Pred` `RHS`":
12827 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12828 // The antecedent implies the consequent if every value of `LHS` that
12829 // satisfies the antecedent also satisfies the consequent.
12830 return LHSRange.icmp(Pred, ConstRHS);
12831}
12832
12833bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12834 bool IsSigned) {
12835 assert(isKnownPositive(Stride) && "Positive stride expected!");
12836
12837 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12838 const SCEV *One = getOne(Stride->getType());
12839
12840 if (IsSigned) {
12841 APInt MaxRHS = getSignedRangeMax(RHS);
12843 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12844
12845 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12846 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12847 }
12848
12849 APInt MaxRHS = getUnsignedRangeMax(RHS);
12850 APInt MaxValue = APInt::getMaxValue(BitWidth);
12851 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12852
12853 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12854 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12855}
12856
12857bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12858 bool IsSigned) {
12859
12860 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12861 const SCEV *One = getOne(Stride->getType());
12862
12863 if (IsSigned) {
12864 APInt MinRHS = getSignedRangeMin(RHS);
12866 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12867
12868 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12869 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12870 }
12871
12872 APInt MinRHS = getUnsignedRangeMin(RHS);
12873 APInt MinValue = APInt::getMinValue(BitWidth);
12874 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12875
12876 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12877 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12878}
12879
12881 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12882 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12883 // expression fixes the case of N=0.
12884 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12885 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12886 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12887}
12888
12889const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12890 const SCEV *Stride,
12891 const SCEV *End,
12892 unsigned BitWidth,
12893 bool IsSigned) {
12894 // The logic in this function assumes we can represent a positive stride.
12895 // If we can't, the backedge-taken count must be zero.
12896 if (IsSigned && BitWidth == 1)
12897 return getZero(Stride->getType());
12898
12899 // This code below only been closely audited for negative strides in the
12900 // unsigned comparison case, it may be correct for signed comparison, but
12901 // that needs to be established.
12902 if (IsSigned && isKnownNegative(Stride))
12903 return getCouldNotCompute();
12904
12905 // Calculate the maximum backedge count based on the range of values
12906 // permitted by Start, End, and Stride.
12907 APInt MinStart =
12908 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12909
12910 APInt MinStride =
12911 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12912
12913 // We assume either the stride is positive, or the backedge-taken count
12914 // is zero. So force StrideForMaxBECount to be at least one.
12915 APInt One(BitWidth, 1);
12916 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12917 : APIntOps::umax(One, MinStride);
12918
12919 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12920 : APInt::getMaxValue(BitWidth);
12921 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12922
12923 // Although End can be a MAX expression we estimate MaxEnd considering only
12924 // the case End = RHS of the loop termination condition. This is safe because
12925 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12926 // taken count.
12927 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12928 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12929
12930 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12931 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12932 : APIntOps::umax(MaxEnd, MinStart);
12933
12934 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12935 getConstant(StrideForMaxBECount) /* Step */);
12936}
12937
12939ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12940 const Loop *L, bool IsSigned,
12941 bool ControlsOnlyExit, bool AllowPredicates) {
12943
12944 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12945 bool PredicatedIV = false;
12946 if (!IV) {
12947 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12948 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12949 if (AR && AR->getLoop() == L && AR->isAffine()) {
12950 auto canProveNUW = [&]() {
12951 // We can use the comparison to infer no-wrap flags only if it fully
12952 // controls the loop exit.
12953 if (!ControlsOnlyExit)
12954 return false;
12955
12956 if (!isLoopInvariant(RHS, L))
12957 return false;
12958
12959 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12960 // We need the sequence defined by AR to strictly increase in the
12961 // unsigned integer domain for the logic below to hold.
12962 return false;
12963
12964 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12965 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12966 // If RHS <=u Limit, then there must exist a value V in the sequence
12967 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12968 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12969 // overflow occurs. This limit also implies that a signed comparison
12970 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12971 // the high bits on both sides must be zero.
12972 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12973 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12974 Limit = Limit.zext(OuterBitWidth);
12975 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12976 };
12977 auto Flags = AR->getNoWrapFlags();
12978 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12979 Flags = setFlags(Flags, SCEV::FlagNUW);
12980
12981 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12982 if (AR->hasNoUnsignedWrap()) {
12983 // Emulate what getZeroExtendExpr would have done during construction
12984 // if we'd been able to infer the fact just above at that time.
12985 const SCEV *Step = AR->getStepRecurrence(*this);
12986 Type *Ty = ZExt->getType();
12987 auto *S = getAddRecExpr(
12988 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12989 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12990 IV = dyn_cast<SCEVAddRecExpr>(S);
12991 }
12992 }
12993 }
12994 }
12995
12996
12997 if (!IV && AllowPredicates) {
12998 // Try to make this an AddRec using runtime tests, in the first X
12999 // iterations of this loop, where X is the SCEV expression found by the
13000 // algorithm below.
13001 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13002 PredicatedIV = true;
13003 }
13004
13005 // Avoid weird loops
13006 if (!IV || IV->getLoop() != L || !IV->isAffine())
13007 return getCouldNotCompute();
13008
13009 // A precondition of this method is that the condition being analyzed
13010 // reaches an exiting branch which dominates the latch. Given that, we can
13011 // assume that an increment which violates the nowrap specification and
13012 // produces poison must cause undefined behavior when the resulting poison
13013 // value is branched upon and thus we can conclude that the backedge is
13014 // taken no more often than would be required to produce that poison value.
13015 // Note that a well defined loop can exit on the iteration which violates
13016 // the nowrap specification if there is another exit (either explicit or
13017 // implicit/exceptional) which causes the loop to execute before the
13018 // exiting instruction we're analyzing would trigger UB.
13019 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13020 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13022
13023 const SCEV *Stride = IV->getStepRecurrence(*this);
13024
13025 bool PositiveStride = isKnownPositive(Stride);
13026
13027 // Avoid negative or zero stride values.
13028 if (!PositiveStride) {
13029 // We can compute the correct backedge taken count for loops with unknown
13030 // strides if we can prove that the loop is not an infinite loop with side
13031 // effects. Here's the loop structure we are trying to handle -
13032 //
13033 // i = start
13034 // do {
13035 // A[i] = i;
13036 // i += s;
13037 // } while (i < end);
13038 //
13039 // The backedge taken count for such loops is evaluated as -
13040 // (max(end, start + stride) - start - 1) /u stride
13041 //
13042 // The additional preconditions that we need to check to prove correctness
13043 // of the above formula is as follows -
13044 //
13045 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13046 // NoWrap flag).
13047 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13048 // no side effects within the loop)
13049 // c) loop has a single static exit (with no abnormal exits)
13050 //
13051 // Precondition a) implies that if the stride is negative, this is a single
13052 // trip loop. The backedge taken count formula reduces to zero in this case.
13053 //
13054 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13055 // then a zero stride means the backedge can't be taken without executing
13056 // undefined behavior.
13057 //
13058 // The positive stride case is the same as isKnownPositive(Stride) returning
13059 // true (original behavior of the function).
13060 //
13061 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13063 return getCouldNotCompute();
13064
13065 if (!isKnownNonZero(Stride)) {
13066 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13067 // if it might eventually be greater than start and if so, on which
13068 // iteration. We can't even produce a useful upper bound.
13069 if (!isLoopInvariant(RHS, L))
13070 return getCouldNotCompute();
13071
13072 // We allow a potentially zero stride, but we need to divide by stride
13073 // below. Since the loop can't be infinite and this check must control
13074 // the sole exit, we can infer the exit must be taken on the first
13075 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13076 // we know the numerator in the divides below must be zero, so we can
13077 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13078 // and produce the right result.
13079 // FIXME: Handle the case where Stride is poison?
13080 auto wouldZeroStrideBeUB = [&]() {
13081 // Proof by contradiction. Suppose the stride were zero. If we can
13082 // prove that the backedge *is* taken on the first iteration, then since
13083 // we know this condition controls the sole exit, we must have an
13084 // infinite loop. We can't have a (well defined) infinite loop per
13085 // check just above.
13086 // Note: The (Start - Stride) term is used to get the start' term from
13087 // (start' + stride,+,stride). Remember that we only care about the
13088 // result of this expression when stride == 0 at runtime.
13089 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13090 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13091 };
13092 if (!wouldZeroStrideBeUB()) {
13093 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13094 }
13095 }
13096 } else if (!NoWrap) {
13097 // Avoid proven overflow cases: this will ensure that the backedge taken
13098 // count will not generate any unsigned overflow.
13099 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13100 return getCouldNotCompute();
13101 }
13102
13103 // On all paths just preceding, we established the following invariant:
13104 // IV can be assumed not to overflow up to and including the exiting
13105 // iteration. We proved this in one of two ways:
13106 // 1) We can show overflow doesn't occur before the exiting iteration
13107 // 1a) canIVOverflowOnLT, and b) step of one
13108 // 2) We can show that if overflow occurs, the loop must execute UB
13109 // before any possible exit.
13110 // Note that we have not yet proved RHS invariant (in general).
13111
13112 const SCEV *Start = IV->getStart();
13113
13114 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13115 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13116 // Use integer-typed versions for actual computation; we can't subtract
13117 // pointers in general.
13118 const SCEV *OrigStart = Start;
13119 const SCEV *OrigRHS = RHS;
13120 if (Start->getType()->isPointerTy()) {
13121 Start = getLosslessPtrToIntExpr(Start);
13122 if (isa<SCEVCouldNotCompute>(Start))
13123 return Start;
13124 }
13125 if (RHS->getType()->isPointerTy()) {
13127 if (isa<SCEVCouldNotCompute>(RHS))
13128 return RHS;
13129 }
13130
13131 const SCEV *End = nullptr, *BECount = nullptr,
13132 *BECountIfBackedgeTaken = nullptr;
13133 if (!isLoopInvariant(RHS, L)) {
13134 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13135 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13136 RHSAddRec->getNoWrapFlags()) {
13137 // The structure of loop we are trying to calculate backedge count of:
13138 //
13139 // left = left_start
13140 // right = right_start
13141 //
13142 // while(left < right){
13143 // ... do something here ...
13144 // left += s1; // stride of left is s1 (s1 > 0)
13145 // right += s2; // stride of right is s2 (s2 < 0)
13146 // }
13147 //
13148
13149 const SCEV *RHSStart = RHSAddRec->getStart();
13150 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13151
13152 // If Stride - RHSStride is positive and does not overflow, we can write
13153 // backedge count as ->
13154 // ceil((End - Start) /u (Stride - RHSStride))
13155 // Where, End = max(RHSStart, Start)
13156
13157 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13158 if (isKnownNegative(RHSStride) &&
13159 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13160 RHSStride)) {
13161
13162 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13163 if (isKnownPositive(Denominator)) {
13164 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13165 : getUMaxExpr(RHSStart, Start);
13166
13167 // We can do this because End >= Start, as End = max(RHSStart, Start)
13168 const SCEV *Delta = getMinusSCEV(End, Start);
13169
13170 BECount = getUDivCeilSCEV(Delta, Denominator);
13171 BECountIfBackedgeTaken =
13172 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13173 }
13174 }
13175 }
13176 if (BECount == nullptr) {
13177 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13178 // given the start, stride and max value for the end bound of the
13179 // loop (RHS), and the fact that IV does not overflow (which is
13180 // checked above).
13181 const SCEV *MaxBECount = computeMaxBECountForLT(
13182 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13183 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13184 MaxBECount, false /*MaxOrZero*/, Predicates);
13185 }
13186 } else {
13187 // We use the expression (max(End,Start)-Start)/Stride to describe the
13188 // backedge count, as if the backedge is taken at least once
13189 // max(End,Start) is End and so the result is as above, and if not
13190 // max(End,Start) is Start so we get a backedge count of zero.
13191 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13192 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13193 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13194 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13195 // Can we prove max(RHS,Start) > Start - Stride?
13196 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13197 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13198 // In this case, we can use a refined formula for computing backedge
13199 // taken count. The general formula remains:
13200 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13201 // We want to use the alternate formula:
13202 // "((End - 1) - (Start - Stride)) /u Stride"
13203 // Let's do a quick case analysis to show these are equivalent under
13204 // our precondition that max(RHS,Start) > Start - Stride.
13205 // * For RHS <= Start, the backedge-taken count must be zero.
13206 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13207 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13208 // "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
13209 // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13210 // reducing this to the stride of 1 case.
13211 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13212 // Stride".
13213 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13214 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13215 // "((RHS - (Start - Stride) - 1) /u Stride".
13216 // Our preconditions trivially imply no overflow in that form.
13217 const SCEV *MinusOne = getMinusOne(Stride->getType());
13218 const SCEV *Numerator =
13219 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13220 BECount = getUDivExpr(Numerator, Stride);
13221 }
13222
13223 if (!BECount) {
13224 auto canProveRHSGreaterThanEqualStart = [&]() {
13225 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13226 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13227 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13228
13229 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13230 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13231 return true;
13232
13233 // (RHS > Start - 1) implies RHS >= Start.
13234 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13235 // "Start - 1" doesn't overflow.
13236 // * For signed comparison, if Start - 1 does overflow, it's equal
13237 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13238 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13239 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13240 //
13241 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13242 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13243 auto *StartMinusOne =
13244 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13245 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13246 };
13247
13248 // If we know that RHS >= Start in the context of loop, then we know
13249 // that max(RHS, Start) = RHS at this point.
13250 if (canProveRHSGreaterThanEqualStart()) {
13251 End = RHS;
13252 } else {
13253 // If RHS < Start, the backedge will be taken zero times. So in
13254 // general, we can write the backedge-taken count as:
13255 //
13256 // RHS >= Start ? ceil((RHS - Start) / Stride) : 0
13257 //
13258 // We convert it to the following to make it more convenient for SCEV:
13259 //
13260 // ceil((max(RHS, Start) - Start) / Stride)
13261 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13262
13263 // See what would happen if we assume the backedge is taken. This is
13264 // used to compute MaxBECount.
13265 BECountIfBackedgeTaken =
13266 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13267 }
13268
13269 // At this point, we know:
13270 //
13271 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13272 // 2. The index variable doesn't overflow.
13273 //
13274 // Therefore, we know N exists such that
13275 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13276 // doesn't overflow.
13277 //
13278 // Using this information, try to prove whether the addition in
13279 // "(End - Start) + (Stride - 1)" has unsigned overflow.
13280 const SCEV *One = getOne(Stride->getType());
13281 bool MayAddOverflow = [&] {
13282 if (isKnownToBeAPowerOfTwo(Stride)) {
13283 // Suppose Stride is a power of two, and Start/End are unsigned
13284 // integers. Let UMAX be the largest representable unsigned
13285 // integer.
13286 //
13287 // By the preconditions of this function, we know
13288 // "(Start + Stride * N) >= End", and this doesn't overflow.
13289 // As a formula:
13290 //
13291 // End <= (Start + Stride * N) <= UMAX
13292 //
13293 // Subtracting Start from all the terms:
13294 //
13295 // End - Start <= Stride * N <= UMAX - Start
13296 //
13297 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13298 //
13299 // End - Start <= Stride * N <= UMAX
13300 //
13301 // Stride * N is a multiple of Stride. Therefore,
13302 //
13303 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13304 //
13305 // Since Stride is a power of two, UMAX + 1 is divisible by
13306 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13307 // write:
13308 //
13309 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13310 //
13311 // Dropping the middle term:
13312 //
13313 // End - Start <= UMAX - (Stride - 1)
13314 //
13315 // Adding Stride - 1 to both sides:
13316 //
13317 // (End - Start) + (Stride - 1) <= UMAX
13318 //
13319 // In other words, the addition doesn't have unsigned overflow.
13320 //
13321 // A similar proof works if we treat Start/End as signed values.
13322 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13323 // to use signed max instead of unsigned max. Note that we're
13324 // trying to prove a lack of unsigned overflow in either case.
13325 return false;
13326 }
13327 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13328 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13329 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13330 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13331 // 1 <s End.
13332 //
13333 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13334 // End.
13335 return false;
13336 }
13337 return true;
13338 }();
13339
13340 const SCEV *Delta = getMinusSCEV(End, Start);
13341 if (!MayAddOverflow) {
13342 // floor((D + (S - 1)) / S)
13343 // We prefer this formulation if it's legal because it's fewer
13344 // operations.
13345 BECount =
13346 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13347 } else {
13348 BECount = getUDivCeilSCEV(Delta, Stride);
13349 }
13350 }
13351 }
13352
13353 const SCEV *ConstantMaxBECount;
13354 bool MaxOrZero = false;
13355 if (isa<SCEVConstant>(BECount)) {
13356 ConstantMaxBECount = BECount;
13357 } else if (BECountIfBackedgeTaken &&
13358 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13359 // If we know exactly how many times the backedge will be taken if it's
13360 // taken at least once, then the backedge count will either be that or
13361 // zero.
13362 ConstantMaxBECount = BECountIfBackedgeTaken;
13363 MaxOrZero = true;
13364 } else {
13365 ConstantMaxBECount = computeMaxBECountForLT(
13366 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13367 }
13368
13369 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13370 !isa<SCEVCouldNotCompute>(BECount))
13371 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13372
13373 const SCEV *SymbolicMaxBECount =
13374 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13375 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13376 Predicates);
13377}
13378
13379ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13380 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13381 bool ControlsOnlyExit, bool AllowPredicates) {
13383 // We handle only IV > Invariant
13384 if (!isLoopInvariant(RHS, L))
13385 return getCouldNotCompute();
13386
13387 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13388 if (!IV && AllowPredicates)
13389 // Try to make this an AddRec using runtime tests, in the first X
13390 // iterations of this loop, where X is the SCEV expression found by the
13391 // algorithm below.
13392 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13393
13394 // Avoid weird loops
13395 if (!IV || IV->getLoop() != L || !IV->isAffine())
13396 return getCouldNotCompute();
13397
13398 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13399 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13401
13402 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13403
13404 // Avoid negative or zero stride values
13405 if (!isKnownPositive(Stride))
13406 return getCouldNotCompute();
13407
13408 // Avoid proven overflow cases: this will ensure that the backedge taken count
13409 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13410 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13411 // behaviors like the case of C language.
13412 if (!Stride->isOne() && !NoWrap)
13413 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13414 return getCouldNotCompute();
13415
13416 const SCEV *Start = IV->getStart();
13417 const SCEV *End = RHS;
13418 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13419 // If we know that Start >= RHS in the context of loop, then we know that
13420 // min(RHS, Start) = RHS at this point.
13422 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13423 End = RHS;
13424 else
13425 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13426 }
13427
13428 if (Start->getType()->isPointerTy()) {
13429 Start = getLosslessPtrToIntExpr(Start);
13430 if (isa<SCEVCouldNotCompute>(Start))
13431 return Start;
13432 }
13433 if (End->getType()->isPointerTy()) {
13435 if (isa<SCEVCouldNotCompute>(End))
13436 return End;
13437 }
13438
13439 // Compute ((Start - End) + (Stride - 1)) / Stride.
13440 // FIXME: This can overflow. Holding off on fixing this for now;
13441 // howManyGreaterThans will hopefully be gone soon.
13442 const SCEV *One = getOne(Stride->getType());
13443 const SCEV *BECount = getUDivExpr(
13444 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13445
13446 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13447 : getUnsignedRangeMax(Start);
13448
13449 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13450 : getUnsignedRangeMin(Stride);
13451
13452 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13453 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13454 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13455
13456 // Although End can be a MIN expression we estimate MinEnd considering only
13457 // the case End = RHS. This is safe because in the other case (Start - End)
13458 // is zero, leading to a zero maximum backedge taken count.
13459 APInt MinEnd =
13460 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13461 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13462
13463 const SCEV *ConstantMaxBECount =
13464 isa<SCEVConstant>(BECount)
13465 ? BECount
13466 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13467 getConstant(MinStride));
13468
13469 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13470 ConstantMaxBECount = BECount;
13471 const SCEV *SymbolicMaxBECount =
13472 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13473
13474 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13475 Predicates);
13476}
13477
13479 ScalarEvolution &SE) const {
13480 if (Range.isFullSet()) // Infinite loop.
13481 return SE.getCouldNotCompute();
13482
13483 // If the start is a non-zero constant, shift the range to simplify things.
13484 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13485 if (!SC->getValue()->isZero()) {
13487 Operands[0] = SE.getZero(SC->getType());
13488 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13489 getNoWrapFlags(FlagNW));
13490 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13491 return ShiftedAddRec->getNumIterationsInRange(
13492 Range.subtract(SC->getAPInt()), SE);
13493 // This is strange and shouldn't happen.
13494 return SE.getCouldNotCompute();
13495 }
13496
13497 // The only time we can solve this is when we have all constant indices.
13498 // Otherwise, we cannot determine the overflow conditions.
13499 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13500 return SE.getCouldNotCompute();
13501
13502 // Okay at this point we know that all elements of the chrec are constants and
13503 // that the start element is zero.
13504
13505 // First check to see if the range contains zero. If not, the first
13506 // iteration exits.
13507 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13508 if (!Range.contains(APInt(BitWidth, 0)))
13509 return SE.getZero(getType());
13510
13511 if (isAffine()) {
13512 // If this is an affine expression then we have this situation:
13513 // Solve {0,+,A} in Range === Ax in Range
13514
13515 // We know that zero is in the range. If A is positive then we know that
13516 // the upper value of the range must be the first possible exit value.
13517 // If A is negative then the lower of the range is the last possible loop
13518 // value. Also note that we already checked for a full range.
13519 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13520 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13521
13522 // The exit value should be (End+A)/A.
13523 APInt ExitVal = (End + A).udiv(A);
13524 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13525
13526 // Evaluate at the exit value. If we really did fall out of the valid
13527 // range, then we computed our trip count, otherwise wrap around or other
13528 // things must have happened.
13529 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13530 if (Range.contains(Val->getValue()))
13531 return SE.getCouldNotCompute(); // Something strange happened
13532
13533 // Ensure that the previous value is in the range.
13536 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13537 "Linear scev computation is off in a bad way!");
13538 return SE.getConstant(ExitValue);
13539 }
13540
13541 if (isQuadratic()) {
13542 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13543 return SE.getConstant(*S);
13544 }
13545
13546 return SE.getCouldNotCompute();
13547}
13548
13549const SCEVAddRecExpr *
13551 assert(getNumOperands() > 1 && "AddRec with zero step?");
13552 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13553 // but in this case we cannot guarantee that the value returned will be an
13554 // AddRec because SCEV does not have a fixed point where it stops
13555 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13556 // may happen if we reach arithmetic depth limit while simplifying. So we
13557 // construct the returned value explicitly.
13559 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13560 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13561 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13562 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13563 // We know that the last operand is not a constant zero (otherwise it would
13564 // have been popped out earlier). This guarantees us that if the result has
13565 // the same last operand, then it will also not be popped out, meaning that
13566 // the returned value will be an AddRec.
13567 const SCEV *Last = getOperand(getNumOperands() - 1);
13568 assert(!Last->isZero() && "Recurrency with zero step?");
13569 Ops.push_back(Last);
13570 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13572}
13573
13574// Return true when S contains at least an undef value.
13576 return SCEVExprContains(S, [](const SCEV *S) {
13577 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13578 return isa<UndefValue>(SU->getValue());
13579 return false;
13580 });
13581}
13582
13583// Return true when S contains a value that is a nullptr.
13585 return SCEVExprContains(S, [](const SCEV *S) {
13586 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13587 return SU->getValue() == nullptr;
13588 return false;
13589 });
13590}
13591
13592/// Return the size of an element read or written by Inst.
13594 Type *Ty;
13595 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13596 Ty = Store->getValueOperand()->getType();
13597 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13598 Ty = Load->getType();
13599 else
13600 return nullptr;
13601
13603 return getSizeOfExpr(ETy, Ty);
13604}
13605
13606//===----------------------------------------------------------------------===//
13607// SCEVCallbackVH Class Implementation
13608//===----------------------------------------------------------------------===//
13609
13610void ScalarEvolution::SCEVCallbackVH::deleted() {
13611 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13612 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13613 SE->ConstantEvolutionLoopExitValue.erase(PN);
13614 SE->eraseValueFromMap(getValPtr());
13615 // this now dangles!
13616}
13617
13618void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13619 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13620
13621 // Forget all the expressions associated with users of the old value,
13622 // so that future queries will recompute the expressions using the new
13623 // value.
13624 SE->forgetValue(getValPtr());
13625 // this now dangles!
13626}
13627
13628ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13629 : CallbackVH(V), SE(se) {}
13630
13631//===----------------------------------------------------------------------===//
13632// ScalarEvolution Class Implementation
13633//===----------------------------------------------------------------------===//
13634
13637 LoopInfo &LI)
13638 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13639 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13640 LoopDispositions(64), BlockDispositions(64) {
13641 // To use guards for proving predicates, we need to scan every instruction in
13642 // relevant basic blocks, and not just terminators. Doing this is a waste of
13643 // time if the IR does not actually contain any calls to
13644 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13645 //
13646 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13647 // to _add_ guards to the module when there weren't any before, and wants
13648 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13649 // efficient in lieu of being smart in that rather obscure case.
13650
13651 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13652 F.getParent(), Intrinsic::experimental_guard);
13653 HasGuards = GuardDecl && !GuardDecl->use_empty();
13654}
13655
13657 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13658 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13659 ValueExprMap(std::move(Arg.ValueExprMap)),
13660 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13661 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13662 PendingMerges(std::move(Arg.PendingMerges)),
13663 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13664 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13665 PredicatedBackedgeTakenCounts(
13666 std::move(Arg.PredicatedBackedgeTakenCounts)),
13667 BECountUsers(std::move(Arg.BECountUsers)),
13668 ConstantEvolutionLoopExitValue(
13669 std::move(Arg.ConstantEvolutionLoopExitValue)),
13670 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13671 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13672 LoopDispositions(std::move(Arg.LoopDispositions)),
13673 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13674 BlockDispositions(std::move(Arg.BlockDispositions)),
13675 SCEVUsers(std::move(Arg.SCEVUsers)),
13676 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13677 SignedRanges(std::move(Arg.SignedRanges)),
13678 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13679 UniquePreds(std::move(Arg.UniquePreds)),
13680 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13681 LoopUsers(std::move(Arg.LoopUsers)),
13682 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13683 FirstUnknown(Arg.FirstUnknown) {
13684 Arg.FirstUnknown = nullptr;
13685}
13686
13688 // Iterate through all the SCEVUnknown instances and call their
13689 // destructors, so that they release their references to their values.
13690 for (SCEVUnknown *U = FirstUnknown; U;) {
13691 SCEVUnknown *Tmp = U;
13692 U = U->Next;
13693 Tmp->~SCEVUnknown();
13694 }
13695 FirstUnknown = nullptr;
13696
13697 ExprValueMap.clear();
13698 ValueExprMap.clear();
13699 HasRecMap.clear();
13700 BackedgeTakenCounts.clear();
13701 PredicatedBackedgeTakenCounts.clear();
13702
13703 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13704 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13705 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13706 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13707 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13708}
13709
13711 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13712}
13713
13714/// When printing a top-level SCEV for trip counts, it's helpful to include
13715/// a type for constants which are otherwise hard to disambiguate.
13716static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13717 if (isa<SCEVConstant>(S))
13718 OS << *S->getType() << " ";
13719 OS << *S;
13720}
13721
13723 const Loop *L) {
13724 // Print all inner loops first
13725 for (Loop *I : *L)
13726 PrintLoopInfo(OS, SE, I);
13727
13728 OS << "Loop ";
13729 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13730 OS << ": ";
13731
13732 SmallVector<BasicBlock *, 8> ExitingBlocks;
13733 L->getExitingBlocks(ExitingBlocks);
13734 if (ExitingBlocks.size() != 1)
13735 OS << "<multiple exits> ";
13736
13737 auto *BTC = SE->getBackedgeTakenCount(L);
13738 if (!isa<SCEVCouldNotCompute>(BTC)) {
13739 OS << "backedge-taken count is ";
13741 } else
13742 OS << "Unpredictable backedge-taken count.";
13743 OS << "\n";
13744
13745 if (ExitingBlocks.size() > 1)
13746 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13747 OS << " exit count for " << ExitingBlock->getName() << ": ";
13748 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13750 if (isa<SCEVCouldNotCompute>(EC)) {
13751 // Retry with predicates.
13753 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13754 if (!isa<SCEVCouldNotCompute>(EC)) {
13755 OS << "\n predicated exit count for " << ExitingBlock->getName()
13756 << ": ";
13758 OS << "\n Predicates:\n";
13759 for (const auto *P : Predicates)
13760 P->print(OS, 4);
13761 }
13762 }
13763 OS << "\n";
13764 }
13765
13766 OS << "Loop ";
13767 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13768 OS << ": ";
13769
13770 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13771 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13772 OS << "constant max backedge-taken count is ";
13773 PrintSCEVWithTypeHint(OS, ConstantBTC);
13775 OS << ", actual taken count either this or zero.";
13776 } else {
13777 OS << "Unpredictable constant max backedge-taken count. ";
13778 }
13779
13780 OS << "\n"
13781 "Loop ";
13782 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13783 OS << ": ";
13784
13785 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13786 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13787 OS << "symbolic max backedge-taken count is ";
13788 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13790 OS << ", actual taken count either this or zero.";
13791 } else {
13792 OS << "Unpredictable symbolic max backedge-taken count. ";
13793 }
13794 OS << "\n";
13795
13796 if (ExitingBlocks.size() > 1)
13797 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13798 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13799 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13801 PrintSCEVWithTypeHint(OS, ExitBTC);
13802 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13803 // Retry with predicates.
13805 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13807 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13808 OS << "\n predicated symbolic max exit count for "
13809 << ExitingBlock->getName() << ": ";
13810 PrintSCEVWithTypeHint(OS, ExitBTC);
13811 OS << "\n Predicates:\n";
13812 for (const auto *P : Predicates)
13813 P->print(OS, 4);
13814 }
13815 }
13816 OS << "\n";
13817 }
13818
13820 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13821 if (PBT != BTC) {
13822 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13823 OS << "Loop ";
13824 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13825 OS << ": ";
13826 if (!isa<SCEVCouldNotCompute>(PBT)) {
13827 OS << "Predicated backedge-taken count is ";
13829 } else
13830 OS << "Unpredictable predicated backedge-taken count.";
13831 OS << "\n";
13832 OS << " Predicates:\n";
13833 for (const auto *P : Preds)
13834 P->print(OS, 4);
13835 }
13836 Preds.clear();
13837
13838 auto *PredConstantMax =
13840 if (PredConstantMax != ConstantBTC) {
13841 assert(!Preds.empty() &&
13842 "different predicated constant max BTC but no predicates");
13843 OS << "Loop ";
13844 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13845 OS << ": ";
13846 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13847 OS << "Predicated constant max backedge-taken count is ";
13848 PrintSCEVWithTypeHint(OS, PredConstantMax);
13849 } else
13850 OS << "Unpredictable predicated constant max backedge-taken count.";
13851 OS << "\n";
13852 OS << " Predicates:\n";
13853 for (const auto *P : Preds)
13854 P->print(OS, 4);
13855 }
13856 Preds.clear();
13857
13858 auto *PredSymbolicMax =
13860 if (SymbolicBTC != PredSymbolicMax) {
13861 assert(!Preds.empty() &&
13862 "Different predicated symbolic max BTC, but no predicates");
13863 OS << "Loop ";
13864 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13865 OS << ": ";
13866 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13867 OS << "Predicated symbolic max backedge-taken count is ";
13868 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13869 } else
13870 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13871 OS << "\n";
13872 OS << " Predicates:\n";
13873 for (const auto *P : Preds)
13874 P->print(OS, 4);
13875 }
13876
13878 OS << "Loop ";
13879 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13880 OS << ": ";
13881 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13882 }
13883}
13884
13885namespace llvm {
13887 switch (LD) {
13889 OS << "Variant";
13890 break;
13892 OS << "Invariant";
13893 break;
13895 OS << "Computable";
13896 break;
13897 }
13898 return OS;
13899}
13900
13902 switch (BD) {
13904 OS << "DoesNotDominate";
13905 break;
13907 OS << "Dominates";
13908 break;
13910 OS << "ProperlyDominates";
13911 break;
13912 }
13913 return OS;
13914}
13915} // namespace llvm
13916
13918 // ScalarEvolution's implementation of the print method is to print
13919 // out SCEV values of all instructions that are interesting. Doing
13920 // this potentially causes it to create new SCEV objects though,
13921 // which technically conflicts with the const qualifier. This isn't
13922 // observable from outside the class though, so casting away the
13923 // const isn't dangerous.
13924 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13925
13926 if (ClassifyExpressions) {
13927 OS << "Classifying expressions for: ";
13928 F.printAsOperand(OS, /*PrintType=*/false);
13929 OS << "\n";
13930 for (Instruction &I : instructions(F))
13931 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13932 OS << I << '\n';
13933 OS << " --> ";
13934 const SCEV *SV = SE.getSCEV(&I);
13935 SV->print(OS);
13936 if (!isa<SCEVCouldNotCompute>(SV)) {
13937 OS << " U: ";
13938 SE.getUnsignedRange(SV).print(OS);
13939 OS << " S: ";
13940 SE.getSignedRange(SV).print(OS);
13941 }
13942
13943 const Loop *L = LI.getLoopFor(I.getParent());
13944
13945 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13946 if (AtUse != SV) {
13947 OS << " --> ";
13948 AtUse->print(OS);
13949 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13950 OS << " U: ";
13951 SE.getUnsignedRange(AtUse).print(OS);
13952 OS << " S: ";
13953 SE.getSignedRange(AtUse).print(OS);
13954 }
13955 }
13956
13957 if (L) {
13958 OS << "\t\t" "Exits: ";
13959 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13960 if (!SE.isLoopInvariant(ExitValue, L)) {
13961 OS << "<<Unknown>>";
13962 } else {
13963 OS << *ExitValue;
13964 }
13965
13966 bool First = true;
13967 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13968 if (First) {
13969 OS << "\t\t" "LoopDispositions: { ";
13970 First = false;
13971 } else {
13972 OS << ", ";
13973 }
13974
13975 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13976 OS << ": " << SE.getLoopDisposition(SV, Iter);
13977 }
13978
13979 for (const auto *InnerL : depth_first(L)) {
13980 if (InnerL == L)
13981 continue;
13982 if (First) {
13983 OS << "\t\t" "LoopDispositions: { ";
13984 First = false;
13985 } else {
13986 OS << ", ";
13987 }
13988
13989 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13990 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13991 }
13992
13993 OS << " }";
13994 }
13995
13996 OS << "\n";
13997 }
13998 }
13999
14000 OS << "Determining loop execution counts for: ";
14001 F.printAsOperand(OS, /*PrintType=*/false);
14002 OS << "\n";
14003 for (Loop *I : LI)
14004 PrintLoopInfo(OS, &SE, I);
14005}
14006
14009 auto &Values = LoopDispositions[S];
14010 for (auto &V : Values) {
14011 if (V.getPointer() == L)
14012 return V.getInt();
14013 }
14014 Values.emplace_back(L, LoopVariant);
14015 LoopDisposition D = computeLoopDisposition(S, L);
14016 auto &Values2 = LoopDispositions[S];
14017 for (auto &V : llvm::reverse(Values2)) {
14018 if (V.getPointer() == L) {
14019 V.setInt(D);
14020 break;
14021 }
14022 }
14023 return D;
14024}
14025
14027ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14028 switch (S->getSCEVType()) {
14029 case scConstant:
14030 case scVScale:
14031 return LoopInvariant;
14032 case scAddRecExpr: {
14033 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14034
14035 // If L is the addrec's loop, it's computable.
14036 if (AR->getLoop() == L)
14037 return LoopComputable;
14038
14039 // Add recurrences are never invariant in the function-body (null loop).
14040 if (!L)
14041 return LoopVariant;
14042
14043 // Everything that is not defined at loop entry is variant.
14044 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14045 return LoopVariant;
14046 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14047 " dominate the contained loop's header?");
14048
14049 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14050 if (AR->getLoop()->contains(L))
14051 return LoopInvariant;
14052
14053 // This recurrence is variant w.r.t. L if any of its operands
14054 // are variant.
14055 for (const auto *Op : AR->operands())
14056 if (!isLoopInvariant(Op, L))
14057 return LoopVariant;
14058
14059 // Otherwise it's loop-invariant.
14060 return LoopInvariant;
14061 }
14062 case scTruncate:
14063 case scZeroExtend:
14064 case scSignExtend:
14065 case scPtrToInt:
14066 case scAddExpr:
14067 case scMulExpr:
14068 case scUDivExpr:
14069 case scUMaxExpr:
14070 case scSMaxExpr:
14071 case scUMinExpr:
14072 case scSMinExpr:
14073 case scSequentialUMinExpr: {
14074 bool HasVarying = false;
14075 for (const auto *Op : S->operands()) {
14077 if (D == LoopVariant)
14078 return LoopVariant;
14079 if (D == LoopComputable)
14080 HasVarying = true;
14081 }
14082 return HasVarying ? LoopComputable : LoopInvariant;
14083 }
14084 case scUnknown:
14085 // All non-instruction values are loop invariant. All instructions are loop
14086 // invariant if they are not contained in the specified loop.
14087 // Instructions are never considered invariant in the function body
14088 // (null loop) because they are defined within the "loop".
14089 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14090 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14091 return LoopInvariant;
14092 case scCouldNotCompute:
14093 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14094 }
14095 llvm_unreachable("Unknown SCEV kind!");
14096}
14097
14099 return getLoopDisposition(S, L) == LoopInvariant;
14100}
14101
14103 return getLoopDisposition(S, L) == LoopComputable;
14104}
14105
14108 auto &Values = BlockDispositions[S];
14109 for (auto &V : Values) {
14110 if (V.getPointer() == BB)
14111 return V.getInt();
14112 }
14113 Values.emplace_back(BB, DoesNotDominateBlock);
14114 BlockDisposition D = computeBlockDisposition(S, BB);
14115 auto &Values2 = BlockDispositions[S];
14116 for (auto &V : llvm::reverse(Values2)) {
14117 if (V.getPointer() == BB) {
14118 V.setInt(D);
14119 break;
14120 }
14121 }
14122 return D;
14123}
14124
14126ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14127 switch (S->getSCEVType()) {
14128 case scConstant:
14129 case scVScale:
14131 case scAddRecExpr: {
14132 // This uses a "dominates" query instead of "properly dominates" query
14133 // to test for proper dominance too, because the instruction which
14134 // produces the addrec's value is a PHI, and a PHI effectively properly
14135 // dominates its entire containing block.
14136 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14137 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14138 return DoesNotDominateBlock;
14139
14140 // Fall through into SCEVNAryExpr handling.
14141 [[fallthrough]];
14142 }
14143 case scTruncate:
14144 case scZeroExtend:
14145 case scSignExtend:
14146 case scPtrToInt:
14147 case scAddExpr:
14148 case scMulExpr:
14149 case scUDivExpr:
14150 case scUMaxExpr:
14151 case scSMaxExpr:
14152 case scUMinExpr:
14153 case scSMinExpr:
14154 case scSequentialUMinExpr: {
14155 bool Proper = true;
14156 for (const SCEV *NAryOp : S->operands()) {
14158 if (D == DoesNotDominateBlock)
14159 return DoesNotDominateBlock;
14160 if (D == DominatesBlock)
14161 Proper = false;
14162 }
14163 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14164 }
14165 case scUnknown:
14166 if (Instruction *I =
14167 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14168 if (I->getParent() == BB)
14169 return DominatesBlock;
14170 if (DT.properlyDominates(I->getParent(), BB))
14172 return DoesNotDominateBlock;
14173 }
14175 case scCouldNotCompute:
14176 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14177 }
14178 llvm_unreachable("Unknown SCEV kind!");
14179}
14180
14181bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14182 return getBlockDisposition(S, BB) >= DominatesBlock;
14183}
14184
14187}
14188
14189bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14190 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14191}
14192
14193void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14194 bool Predicated) {
14195 auto &BECounts =
14196 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14197 auto It = BECounts.find(L);
14198 if (It != BECounts.end()) {
14199 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14200 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14201 if (!isa<SCEVConstant>(S)) {
14202 auto UserIt = BECountUsers.find(S);
14203 assert(UserIt != BECountUsers.end());
14204 UserIt->second.erase({L, Predicated});
14205 }
14206 }
14207 }
14208 BECounts.erase(It);
14209 }
14210}
14211
14212void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14213 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14214 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14215
14216 while (!Worklist.empty()) {
14217 const SCEV *Curr = Worklist.pop_back_val();
14218 auto Users = SCEVUsers.find(Curr);
14219 if (Users != SCEVUsers.end())
14220 for (const auto *User : Users->second)
14221 if (ToForget.insert(User).second)
14222 Worklist.push_back(User);
14223 }
14224
14225 for (const auto *S : ToForget)
14226 forgetMemoizedResultsImpl(S);
14227
14228 for (auto I = PredicatedSCEVRewrites.begin();
14229 I != PredicatedSCEVRewrites.end();) {
14230 std::pair<const SCEV *, const Loop *> Entry = I->first;
14231 if (ToForget.count(Entry.first))
14232 PredicatedSCEVRewrites.erase(I++);
14233 else
14234 ++I;
14235 }
14236}
14237
14238void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14239 LoopDispositions.erase(S);
14240 BlockDispositions.erase(S);
14241 UnsignedRanges.erase(S);
14242 SignedRanges.erase(S);
14243 HasRecMap.erase(S);
14244 ConstantMultipleCache.erase(S);
14245
14246 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14247 UnsignedWrapViaInductionTried.erase(AR);
14248 SignedWrapViaInductionTried.erase(AR);
14249 }
14250
14251 auto ExprIt = ExprValueMap.find(S);
14252 if (ExprIt != ExprValueMap.end()) {
14253 for (Value *V : ExprIt->second) {
14254 auto ValueIt = ValueExprMap.find_as(V);
14255 if (ValueIt != ValueExprMap.end())
14256 ValueExprMap.erase(ValueIt);
14257 }
14258 ExprValueMap.erase(ExprIt);
14259 }
14260
14261 auto ScopeIt = ValuesAtScopes.find(S);
14262 if (ScopeIt != ValuesAtScopes.end()) {
14263 for (const auto &Pair : ScopeIt->second)
14264 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14265 llvm::erase(ValuesAtScopesUsers[Pair.second],
14266 std::make_pair(Pair.first, S));
14267 ValuesAtScopes.erase(ScopeIt);
14268 }
14269
14270 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14271 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14272 for (const auto &Pair : ScopeUserIt->second)
14273 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14274 ValuesAtScopesUsers.erase(ScopeUserIt);
14275 }
14276
14277 auto BEUsersIt = BECountUsers.find(S);
14278 if (BEUsersIt != BECountUsers.end()) {
14279 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14280 auto Copy = BEUsersIt->second;
14281 for (const auto &Pair : Copy)
14282 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14283 BECountUsers.erase(BEUsersIt);
14284 }
14285
14286 auto FoldUser = FoldCacheUser.find(S);
14287 if (FoldUser != FoldCacheUser.end())
14288 for (auto &KV : FoldUser->second)
14289 FoldCache.erase(KV);
14290 FoldCacheUser.erase(S);
14291}
14292
14293void
14294ScalarEvolution::getUsedLoops(const SCEV *S,
14295 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14296 struct FindUsedLoops {
14297 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14298 : LoopsUsed(LoopsUsed) {}
14300 bool follow(const SCEV *S) {
14301 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14302 LoopsUsed.insert(AR->getLoop());
14303 return true;
14304 }
14305
14306 bool isDone() const { return false; }
14307 };
14308
14309 FindUsedLoops F(LoopsUsed);
14311}
14312
14313void ScalarEvolution::getReachableBlocks(
14316 Worklist.push_back(&F.getEntryBlock());
14317 while (!Worklist.empty()) {
14318 BasicBlock *BB = Worklist.pop_back_val();
14319 if (!Reachable.insert(BB).second)
14320 continue;
14321
14322 Value *Cond;
14323 BasicBlock *TrueBB, *FalseBB;
14324 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14325 m_BasicBlock(FalseBB)))) {
14326 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14327 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14328 continue;
14329 }
14330
14331 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14332 const SCEV *L = getSCEV(Cmp->getOperand(0));
14333 const SCEV *R = getSCEV(Cmp->getOperand(1));
14334 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14335 Worklist.push_back(TrueBB);
14336 continue;
14337 }
14338 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14339 R)) {
14340 Worklist.push_back(FalseBB);
14341 continue;
14342 }
14343 }
14344 }
14345
14346 append_range(Worklist, successors(BB));
14347 }
14348}
14349
14351 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14352 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14353
14354 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14355
14356 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14357 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14358 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14359
14360 const SCEV *visitConstant(const SCEVConstant *Constant) {
14361 return SE.getConstant(Constant->getAPInt());
14362 }
14363
14364 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14365 return SE.getUnknown(Expr->getValue());
14366 }
14367
14368 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14369 return SE.getCouldNotCompute();
14370 }
14371 };
14372
14373 SCEVMapper SCM(SE2);
14374 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14375 SE2.getReachableBlocks(ReachableBlocks, F);
14376
14377 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14378 if (containsUndefs(Old) || containsUndefs(New)) {
14379 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14380 // not propagate undef aggressively). This means we can (and do) fail
14381 // verification in cases where a transform makes a value go from "undef"
14382 // to "undef+1" (say). The transform is fine, since in both cases the
14383 // result is "undef", but SCEV thinks the value increased by 1.
14384 return nullptr;
14385 }
14386
14387 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14388 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14389 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14390 return nullptr;
14391
14392 return Delta;
14393 };
14394
14395 while (!LoopStack.empty()) {
14396 auto *L = LoopStack.pop_back_val();
14397 llvm::append_range(LoopStack, *L);
14398
14399 // Only verify BECounts in reachable loops. For an unreachable loop,
14400 // any BECount is legal.
14401 if (!ReachableBlocks.contains(L->getHeader()))
14402 continue;
14403
14404 // Only verify cached BECounts. Computing new BECounts may change the
14405 // results of subsequent SCEV uses.
14406 auto It = BackedgeTakenCounts.find(L);
14407 if (It == BackedgeTakenCounts.end())
14408 continue;
14409
14410 auto *CurBECount =
14411 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14412 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14413
14414 if (CurBECount == SE2.getCouldNotCompute() ||
14415 NewBECount == SE2.getCouldNotCompute()) {
14416 // NB! This situation is legal, but is very suspicious -- whatever pass
14417 // changes the loop to make a trip count go from could not compute to
14418 // computable or vice-versa *should have* invalidated SCEV. However, we
14419 // choose not to assert here (for now) since we don't want false
14420 // positives.
14421 continue;
14422 }
14423
14424 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14425 SE.getTypeSizeInBits(NewBECount->getType()))
14426 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14427 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14428 SE.getTypeSizeInBits(NewBECount->getType()))
14429 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14430
14431 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14432 if (Delta && !Delta->isZero()) {
14433 dbgs() << "Trip Count for " << *L << " Changed!\n";
14434 dbgs() << "Old: " << *CurBECount << "\n";
14435 dbgs() << "New: " << *NewBECount << "\n";
14436 dbgs() << "Delta: " << *Delta << "\n";
14437 std::abort();
14438 }
14439 }
14440
14441 // Collect all valid loops currently in LoopInfo.
14442 SmallPtrSet<Loop *, 32> ValidLoops;
14443 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14444 while (!Worklist.empty()) {
14445 Loop *L = Worklist.pop_back_val();
14446 if (ValidLoops.insert(L).second)
14447 Worklist.append(L->begin(), L->end());
14448 }
14449 for (const auto &KV : ValueExprMap) {
14450#ifndef NDEBUG
14451 // Check for SCEV expressions referencing invalid/deleted loops.
14452 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14453 assert(ValidLoops.contains(AR->getLoop()) &&
14454 "AddRec references invalid loop");
14455 }
14456#endif
14457
14458 // Check that the value is also part of the reverse map.
14459 auto It = ExprValueMap.find(KV.second);
14460 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14461 dbgs() << "Value " << *KV.first
14462 << " is in ValueExprMap but not in ExprValueMap\n";
14463 std::abort();
14464 }
14465
14466 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14467 if (!ReachableBlocks.contains(I->getParent()))
14468 continue;
14469 const SCEV *OldSCEV = SCM.visit(KV.second);
14470 const SCEV *NewSCEV = SE2.getSCEV(I);
14471 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14472 if (Delta && !Delta->isZero()) {
14473 dbgs() << "SCEV for value " << *I << " changed!\n"
14474 << "Old: " << *OldSCEV << "\n"
14475 << "New: " << *NewSCEV << "\n"
14476 << "Delta: " << *Delta << "\n";
14477 std::abort();
14478 }
14479 }
14480 }
14481
14482 for (const auto &KV : ExprValueMap) {
14483 for (Value *V : KV.second) {
14484 auto It = ValueExprMap.find_as(V);
14485 if (It == ValueExprMap.end()) {
14486 dbgs() << "Value " << *V
14487 << " is in ExprValueMap but not in ValueExprMap\n";
14488 std::abort();
14489 }
14490 if (It->second != KV.first) {
14491 dbgs() << "Value " << *V << " mapped to " << *It->second
14492 << " rather than " << *KV.first << "\n";
14493 std::abort();
14494 }
14495 }
14496 }
14497
14498 // Verify integrity of SCEV users.
14499 for (const auto &S : UniqueSCEVs) {
14500 for (const auto *Op : S.operands()) {
14501 // We do not store dependencies of constants.
14502 if (isa<SCEVConstant>(Op))
14503 continue;
14504 auto It = SCEVUsers.find(Op);
14505 if (It != SCEVUsers.end() && It->second.count(&S))
14506 continue;
14507 dbgs() << "Use of operand " << *Op << " by user " << S
14508 << " is not being tracked!\n";
14509 std::abort();
14510 }
14511 }
14512
14513 // Verify integrity of ValuesAtScopes users.
14514 for (const auto &ValueAndVec : ValuesAtScopes) {
14515 const SCEV *Value = ValueAndVec.first;
14516 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14517 const Loop *L = LoopAndValueAtScope.first;
14518 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14519 if (!isa<SCEVConstant>(ValueAtScope)) {
14520 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14521 if (It != ValuesAtScopesUsers.end() &&
14522 is_contained(It->second, std::make_pair(L, Value)))
14523 continue;
14524 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14525 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14526 std::abort();
14527 }
14528 }
14529 }
14530
14531 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14532 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14533 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14534 const Loop *L = LoopAndValue.first;
14535 const SCEV *Value = LoopAndValue.second;
14536 assert(!isa<SCEVConstant>(Value));
14537 auto It = ValuesAtScopes.find(Value);
14538 if (It != ValuesAtScopes.end() &&
14539 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14540 continue;
14541 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14542 << *ValueAtScope << " missing in ValuesAtScopes\n";
14543 std::abort();
14544 }
14545 }
14546
14547 // Verify integrity of BECountUsers.
14548 auto VerifyBECountUsers = [&](bool Predicated) {
14549 auto &BECounts =
14550 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14551 for (const auto &LoopAndBEInfo : BECounts) {
14552 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14553 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14554 if (!isa<SCEVConstant>(S)) {
14555 auto UserIt = BECountUsers.find(S);
14556 if (UserIt != BECountUsers.end() &&
14557 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14558 continue;
14559 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14560 << " missing from BECountUsers\n";
14561 std::abort();
14562 }
14563 }
14564 }
14565 }
14566 };
14567 VerifyBECountUsers(/* Predicated */ false);
14568 VerifyBECountUsers(/* Predicated */ true);
14569
14570 // Verify integrity of the loop disposition cache.
14571 for (auto &[S, Values] : LoopDispositions) {
14572 for (auto [Loop, CachedDisposition] : Values) {
14573 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14574 if (CachedDisposition != RecomputedDisposition) {
14575 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14576 << " is incorrect: cached " << CachedDisposition << ", actual "
14577 << RecomputedDisposition << "\n";
14578 std::abort();
14579 }
14580 }
14581 }
14582
14583 // Verify integrity of the block disposition cache.
14584 for (auto &[S, Values] : BlockDispositions) {
14585 for (auto [BB, CachedDisposition] : Values) {
14586 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14587 if (CachedDisposition != RecomputedDisposition) {
14588 dbgs() << "Cached disposition of " << *S << " for block %"
14589 << BB->getName() << " is incorrect: cached " << CachedDisposition
14590 << ", actual " << RecomputedDisposition << "\n";
14591 std::abort();
14592 }
14593 }
14594 }
14595
14596 // Verify FoldCache/FoldCacheUser caches.
14597 for (auto [FoldID, Expr] : FoldCache) {
14598 auto I = FoldCacheUser.find(Expr);
14599 if (I == FoldCacheUser.end()) {
14600 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14601 << "!\n";
14602 std::abort();
14603 }
14604 if (!is_contained(I->second, FoldID)) {
14605 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14606 std::abort();
14607 }
14608 }
14609 for (auto [Expr, IDs] : FoldCacheUser) {
14610 for (auto &FoldID : IDs) {
14611 auto I = FoldCache.find(FoldID);
14612 if (I == FoldCache.end()) {
14613 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14614 << "!\n";
14615 std::abort();
14616 }
14617 if (I->second != Expr) {
14618 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14619 << *I->second << " != " << *Expr << "!\n";
14620 std::abort();
14621 }
14622 }
14623 }
14624
14625 // Verify that ConstantMultipleCache computations are correct. We check that
14626 // cached multiples and recomputed multiples are multiples of each other to
14627 // verify correctness. It is possible that a recomputed multiple is different
14628 // from the cached multiple due to strengthened no wrap flags or changes in
14629 // KnownBits computations.
14630 for (auto [S, Multiple] : ConstantMultipleCache) {
14631 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14632 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14633 Multiple.urem(RecomputedMultiple) != 0 &&
14634 RecomputedMultiple.urem(Multiple) != 0)) {
14635 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14636 << *S << " : Computed " << RecomputedMultiple
14637 << " but cache contains " << Multiple << "!\n";
14638 std::abort();
14639 }
14640 }
14641}
14642
14644 Function &F, const PreservedAnalyses &PA,
14646 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14647 // of its dependencies is invalidated.
14648 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14649 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14650 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14652 Inv.invalidate<LoopAnalysis>(F, PA);
14653}
14654
14655AnalysisKey ScalarEvolutionAnalysis::Key;
14656
14659 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14660 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14661 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14662 auto &LI = AM.getResult<LoopAnalysis>(F);
14663 return ScalarEvolution(F, TLI, AC, DT, LI);
14664}
14665
14669 return PreservedAnalyses::all();
14670}
14671
14674 // For compatibility with opt's -analyze feature under legacy pass manager
14675 // which was not ported to NPM. This keeps tests using
14676 // update_analyze_test_checks.py working.
14677 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14678 << F.getName() << "':\n";
14680 return PreservedAnalyses::all();
14681}
14682
14684 "Scalar Evolution Analysis", false, true)
14690 "Scalar Evolution Analysis", false, true)
14691
14693
14696}
14697
14699 SE.reset(new ScalarEvolution(
14700 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14701 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14702 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14703 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14704 return false;
14705}
14706
14708
14710 SE->print(OS);
14711}
14712
14714 if (!VerifySCEV)
14715 return;
14716
14717 SE->verify();
14718}
14719
14721 AU.setPreservesAll();
14726}
14727
14729 const SCEV *RHS) {
14731}
14732
14733const SCEVPredicate *
14735 const SCEV *LHS, const SCEV *RHS) {
14737 assert(LHS->getType() == RHS->getType() &&
14738 "Type mismatch between LHS and RHS");
14739 // Unique this node based on the arguments
14740 ID.AddInteger(SCEVPredicate::P_Compare);
14741 ID.AddInteger(Pred);
14742 ID.AddPointer(LHS);
14743 ID.AddPointer(RHS);
14744 void *IP = nullptr;
14745 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14746 return S;
14747 SCEVComparePredicate *Eq = new (SCEVAllocator)
14748 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14749 UniquePreds.InsertNode(Eq, IP);
14750 return Eq;
14751}
14752
14754 const SCEVAddRecExpr *AR,
14757 // Unique this node based on the arguments
14758 ID.AddInteger(SCEVPredicate::P_Wrap);
14759 ID.AddPointer(AR);
14760 ID.AddInteger(AddedFlags);
14761 void *IP = nullptr;
14762 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14763 return S;
14764 auto *OF = new (SCEVAllocator)
14765 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14766 UniquePreds.InsertNode(OF, IP);
14767 return OF;
14768}
14769
14770namespace {
14771
14772class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14773public:
14774
14775 /// Rewrites \p S in the context of a loop L and the SCEV predication
14776 /// infrastructure.
14777 ///
14778 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14779 /// equivalences present in \p Pred.
14780 ///
14781 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14782 /// \p NewPreds such that the result will be an AddRecExpr.
14783 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14785 const SCEVPredicate *Pred) {
14786 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14787 return Rewriter.visit(S);
14788 }
14789
14790 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14791 if (Pred) {
14792 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14793 for (const auto *Pred : U->getPredicates())
14794 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14795 if (IPred->getLHS() == Expr &&
14796 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14797 return IPred->getRHS();
14798 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14799 if (IPred->getLHS() == Expr &&
14800 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14801 return IPred->getRHS();
14802 }
14803 }
14804 return convertToAddRecWithPreds(Expr);
14805 }
14806
14807 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14808 const SCEV *Operand = visit(Expr->getOperand());
14809 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14810 if (AR && AR->getLoop() == L && AR->isAffine()) {
14811 // This couldn't be folded because the operand didn't have the nuw
14812 // flag. Add the nusw flag as an assumption that we could make.
14813 const SCEV *Step = AR->getStepRecurrence(SE);
14814 Type *Ty = Expr->getType();
14815 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14816 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14817 SE.getSignExtendExpr(Step, Ty), L,
14818 AR->getNoWrapFlags());
14819 }
14820 return SE.getZeroExtendExpr(Operand, Expr->getType());
14821 }
14822
14823 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14824 const SCEV *Operand = visit(Expr->getOperand());
14825 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14826 if (AR && AR->getLoop() == L && AR->isAffine()) {
14827 // This couldn't be folded because the operand didn't have the nsw
14828 // flag. Add the nssw flag as an assumption that we could make.
14829 const SCEV *Step = AR->getStepRecurrence(SE);
14830 Type *Ty = Expr->getType();
14831 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14832 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14833 SE.getSignExtendExpr(Step, Ty), L,
14834 AR->getNoWrapFlags());
14835 }
14836 return SE.getSignExtendExpr(Operand, Expr->getType());
14837 }
14838
14839private:
14840 explicit SCEVPredicateRewriter(
14841 const Loop *L, ScalarEvolution &SE,
14843 const SCEVPredicate *Pred)
14844 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14845
14846 bool addOverflowAssumption(const SCEVPredicate *P) {
14847 if (!NewPreds) {
14848 // Check if we've already made this assumption.
14849 return Pred && Pred->implies(P, SE);
14850 }
14851 NewPreds->push_back(P);
14852 return true;
14853 }
14854
14855 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14857 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14858 return addOverflowAssumption(A);
14859 }
14860
14861 // If \p Expr represents a PHINode, we try to see if it can be represented
14862 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14863 // to add this predicate as a runtime overflow check, we return the AddRec.
14864 // If \p Expr does not meet these conditions (is not a PHI node, or we
14865 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14866 // return \p Expr.
14867 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14868 if (!isa<PHINode>(Expr->getValue()))
14869 return Expr;
14870 std::optional<
14871 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14872 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14873 if (!PredicatedRewrite)
14874 return Expr;
14875 for (const auto *P : PredicatedRewrite->second){
14876 // Wrap predicates from outer loops are not supported.
14877 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14878 if (L != WP->getExpr()->getLoop())
14879 return Expr;
14880 }
14881 if (!addOverflowAssumption(P))
14882 return Expr;
14883 }
14884 return PredicatedRewrite->first;
14885 }
14886
14888 const SCEVPredicate *Pred;
14889 const Loop *L;
14890};
14891
14892} // end anonymous namespace
14893
14894const SCEV *
14896 const SCEVPredicate &Preds) {
14897 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14898}
14899
14901 const SCEV *S, const Loop *L,
14904 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14905 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14906
14907 if (!AddRec)
14908 return nullptr;
14909
14910 // Since the transformation was successful, we can now transfer the SCEV
14911 // predicates.
14912 Preds.append(TransformPreds.begin(), TransformPreds.end());
14913
14914 return AddRec;
14915}
14916
14917/// SCEV predicates
14919 SCEVPredicateKind Kind)
14920 : FastID(ID), Kind(Kind) {}
14921
14923 const ICmpInst::Predicate Pred,
14924 const SCEV *LHS, const SCEV *RHS)
14925 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14926 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14927 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14928}
14929
14931 ScalarEvolution &SE) const {
14932 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14933
14934 if (!Op)
14935 return false;
14936
14937 if (Pred != ICmpInst::ICMP_EQ)
14938 return false;
14939
14940 return Op->LHS == LHS && Op->RHS == RHS;
14941}
14942
14943bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14944
14946 if (Pred == ICmpInst::ICMP_EQ)
14947 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14948 else
14949 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14950 << *RHS << "\n";
14951
14952}
14953
14955 const SCEVAddRecExpr *AR,
14956 IncrementWrapFlags Flags)
14957 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14958
14959const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14960
14962 ScalarEvolution &SE) const {
14963 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14964 if (!Op || setFlags(Flags, Op->Flags) != Flags)
14965 return false;
14966
14967 if (Op->AR == AR)
14968 return true;
14969
14970 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14972 return false;
14973
14974 const SCEV *Start = AR->getStart();
14975 const SCEV *OpStart = Op->AR->getStart();
14976 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
14977 return false;
14978
14979 const SCEV *Step = AR->getStepRecurrence(SE);
14980 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14981 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14982 return false;
14983
14984 // If both steps are positive, this implies N, if N's start and step are
14985 // ULE/SLE (for NSUW/NSSW) than this'.
14986 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
14987 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14988 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14989
14990 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14991 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
14992 : SE.getNoopOrSignExtend(OpStart, WiderTy);
14993 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
14994 : SE.getNoopOrSignExtend(Start, WiderTy);
14996 return SE.isKnownPredicate(Pred, OpStep, Step) &&
14997 SE.isKnownPredicate(Pred, OpStart, Start);
14998}
14999
15001 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15002 IncrementWrapFlags IFlags = Flags;
15003
15004 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15005 IFlags = clearFlags(IFlags, IncrementNSSW);
15006
15007 return IFlags == IncrementAnyWrap;
15008}
15009
15011 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15013 OS << "<nusw>";
15015 OS << "<nssw>";
15016 OS << "\n";
15017}
15018
15021 ScalarEvolution &SE) {
15022 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15023 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15024
15025 // We can safely transfer the NSW flag as NSSW.
15026 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15027 ImpliedFlags = IncrementNSSW;
15028
15029 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15030 // If the increment is positive, the SCEV NUW flag will also imply the
15031 // WrapPredicate NUSW flag.
15032 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15033 if (Step->getValue()->getValue().isNonNegative())
15034 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15035 }
15036
15037 return ImpliedFlags;
15038}
15039
15040/// Union predicates don't get cached so create a dummy set ID for it.
15042 ScalarEvolution &SE)
15043 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15044 for (const auto *P : Preds)
15045 add(P, SE);
15046}
15047
15049 return all_of(Preds,
15050 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15051}
15052
15054 ScalarEvolution &SE) const {
15055 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15056 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15057 return this->implies(I, SE);
15058 });
15059
15060 return any_of(Preds,
15061 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15062}
15063
15065 for (const auto *Pred : Preds)
15066 Pred->print(OS, Depth);
15067}
15068
15069void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15070 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15071 for (const auto *Pred : Set->Preds)
15072 add(Pred, SE);
15073 return;
15074 }
15075
15076 // Only add predicate if it is not already implied by this union predicate.
15077 if (implies(N, SE))
15078 return;
15079
15080 // Build a new vector containing the current predicates, except the ones that
15081 // are implied by the new predicate N.
15083 for (auto *P : Preds) {
15084 if (N->implies(P, SE))
15085 continue;
15086 PrunedPreds.push_back(P);
15087 }
15088 Preds = std::move(PrunedPreds);
15089 Preds.push_back(N);
15090}
15091
15093 Loop &L)
15094 : SE(SE), L(L) {
15096 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15097}
15098
15101 for (const auto *Op : Ops)
15102 // We do not expect that forgetting cached data for SCEVConstants will ever
15103 // open any prospects for sharpening or introduce any correctness issues,
15104 // so we don't bother storing their dependencies.
15105 if (!isa<SCEVConstant>(Op))
15106 SCEVUsers[Op].insert(User);
15107}
15108
15110 const SCEV *Expr = SE.getSCEV(V);
15111 RewriteEntry &Entry = RewriteMap[Expr];
15112
15113 // If we already have an entry and the version matches, return it.
15114 if (Entry.second && Generation == Entry.first)
15115 return Entry.second;
15116
15117 // We found an entry but it's stale. Rewrite the stale entry
15118 // according to the current predicate.
15119 if (Entry.second)
15120 Expr = Entry.second;
15121
15122 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15123 Entry = {Generation, NewSCEV};
15124
15125 return NewSCEV;
15126}
15127
15129 if (!BackedgeCount) {
15131 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15132 for (const auto *P : Preds)
15133 addPredicate(*P);
15134 }
15135 return BackedgeCount;
15136}
15137
15139 if (!SymbolicMaxBackedgeCount) {
15141 SymbolicMaxBackedgeCount =
15143 for (const auto *P : Preds)
15144 addPredicate(*P);
15145 }
15146 return SymbolicMaxBackedgeCount;
15147}
15148
15150 if (!SmallConstantMaxTripCount) {
15152 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15153 for (const auto *P : Preds)
15154 addPredicate(*P);
15155 }
15156 return *SmallConstantMaxTripCount;
15157}
15158
15160 if (Preds->implies(&Pred, SE))
15161 return;
15162
15163 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15164 NewPreds.push_back(&Pred);
15165 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15166 updateGeneration();
15167}
15168
15170 return *Preds;
15171}
15172
15173void PredicatedScalarEvolution::updateGeneration() {
15174 // If the generation number wrapped recompute everything.
15175 if (++Generation == 0) {
15176 for (auto &II : RewriteMap) {
15177 const SCEV *Rewritten = II.second.second;
15178 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15179 }
15180 }
15181}
15182
15185 const SCEV *Expr = getSCEV(V);
15186 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15187
15188 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15189
15190 // Clear the statically implied flags.
15191 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15192 addPredicate(*SE.getWrapPredicate(AR, Flags));
15193
15194 auto II = FlagsMap.insert({V, Flags});
15195 if (!II.second)
15196 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15197}
15198
15201 const SCEV *Expr = getSCEV(V);
15202 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15203
15205 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15206
15207 auto II = FlagsMap.find(V);
15208
15209 if (II != FlagsMap.end())
15210 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15211
15213}
15214
15216 const SCEV *Expr = this->getSCEV(V);
15218 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15219
15220 if (!New)
15221 return nullptr;
15222
15223 for (const auto *P : NewPreds)
15224 addPredicate(*P);
15225
15226 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15227 return New;
15228}
15229
15232 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15233 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15234 SE)),
15235 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15236 for (auto I : Init.FlagsMap)
15237 FlagsMap.insert(I);
15238}
15239
15241 // For each block.
15242 for (auto *BB : L.getBlocks())
15243 for (auto &I : *BB) {
15244 if (!SE.isSCEVable(I.getType()))
15245 continue;
15246
15247 auto *Expr = SE.getSCEV(&I);
15248 auto II = RewriteMap.find(Expr);
15249
15250 if (II == RewriteMap.end())
15251 continue;
15252
15253 // Don't print things that are not interesting.
15254 if (II->second.second == Expr)
15255 continue;
15256
15257 OS.indent(Depth) << "[PSE]" << I << ":\n";
15258 OS.indent(Depth + 2) << *Expr << "\n";
15259 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15260 }
15261}
15262
15263// Match the mathematical pattern A - (A / B) * B, where A and B can be
15264// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15265// for URem with constant power-of-2 second operands.
15266// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15267// 4, A / B becomes X / 8).
15268bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15269 const SCEV *&RHS) {
15270 if (Expr->getType()->isPointerTy())
15271 return false;
15272
15273 // Try to match 'zext (trunc A to iB) to iY', which is used
15274 // for URem with constant power-of-2 second operands. Make sure the size of
15275 // the operand A matches the size of the whole expressions.
15276 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15277 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15278 LHS = Trunc->getOperand();
15279 // Bail out if the type of the LHS is larger than the type of the
15280 // expression for now.
15281 if (getTypeSizeInBits(LHS->getType()) >
15282 getTypeSizeInBits(Expr->getType()))
15283 return false;
15284 if (LHS->getType() != Expr->getType())
15285 LHS = getZeroExtendExpr(LHS, Expr->getType());
15287 << getTypeSizeInBits(Trunc->getType()));
15288 return true;
15289 }
15290 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15291 if (Add == nullptr || Add->getNumOperands() != 2)
15292 return false;
15293
15294 const SCEV *A = Add->getOperand(1);
15295 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15296
15297 if (Mul == nullptr)
15298 return false;
15299
15300 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15301 // (SomeExpr + (-(SomeExpr / B) * B)).
15302 if (Expr == getURemExpr(A, B)) {
15303 LHS = A;
15304 RHS = B;
15305 return true;
15306 }
15307 return false;
15308 };
15309
15310 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15311 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15312 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15313 MatchURemWithDivisor(Mul->getOperand(2));
15314
15315 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15316 if (Mul->getNumOperands() == 2)
15317 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15318 MatchURemWithDivisor(Mul->getOperand(0)) ||
15319 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15320 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15321 return false;
15322}
15323
15326 BasicBlock *Header = L->getHeader();
15327 BasicBlock *Pred = L->getLoopPredecessor();
15328 LoopGuards Guards(SE);
15329 if (!Pred)
15330 return Guards;
15332 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15333 return Guards;
15334}
15335
15336void ScalarEvolution::LoopGuards::collectFromPHI(
15338 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15340 unsigned Depth) {
15341 if (!SE.isSCEVable(Phi.getType()))
15342 return;
15343
15344 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15345 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15346 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15347 if (!VisitedBlocks.insert(InBlock).second)
15348 return {nullptr, scCouldNotCompute};
15349 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15350 if (Inserted)
15351 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15352 Depth + 1);
15353 auto &RewriteMap = G->second.RewriteMap;
15354 if (RewriteMap.empty())
15355 return {nullptr, scCouldNotCompute};
15356 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15357 if (S == RewriteMap.end())
15358 return {nullptr, scCouldNotCompute};
15359 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15360 if (!SM)
15361 return {nullptr, scCouldNotCompute};
15362 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15363 return {C0, SM->getSCEVType()};
15364 return {nullptr, scCouldNotCompute};
15365 };
15366 auto MergeMinMaxConst = [](MinMaxPattern P1,
15367 MinMaxPattern P2) -> MinMaxPattern {
15368 auto [C1, T1] = P1;
15369 auto [C2, T2] = P2;
15370 if (!C1 || !C2 || T1 != T2)
15371 return {nullptr, scCouldNotCompute};
15372 switch (T1) {
15373 case scUMaxExpr:
15374 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15375 case scSMaxExpr:
15376 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15377 case scUMinExpr:
15378 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15379 case scSMinExpr:
15380 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15381 default:
15382 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15383 }
15384 };
15385 auto P = GetMinMaxConst(0);
15386 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15387 if (!P.first)
15388 break;
15389 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15390 }
15391 if (P.first) {
15392 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15393 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15394 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15395 Guards.RewriteMap.insert({LHS, RHS});
15396 }
15397}
15398
15399void ScalarEvolution::LoopGuards::collectFromBlock(
15401 const BasicBlock *Block, const BasicBlock *Pred,
15402 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15403 SmallVector<const SCEV *> ExprsToRewrite;
15404 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15405 const SCEV *RHS,
15407 &RewriteMap) {
15408 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15409 // replacement SCEV which isn't directly implied by the structure of that
15410 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15411 // legal. See the scoping rules for flags in the header to understand why.
15412
15413 // If LHS is a constant, apply information to the other expression.
15414 if (isa<SCEVConstant>(LHS)) {
15415 std::swap(LHS, RHS);
15417 }
15418
15419 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15420 // create this form when combining two checks of the form (X u< C2 + C1) and
15421 // (X >=u C1).
15422 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15423 &ExprsToRewrite]() {
15424 const SCEVConstant *C1;
15425 const SCEVUnknown *LHSUnknown;
15426 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15427 if (!match(LHS,
15428 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15429 !C2)
15430 return false;
15431
15432 auto ExactRegion =
15434 .sub(C1->getAPInt());
15435
15436 // Bail out, unless we have a non-wrapping, monotonic range.
15437 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15438 return false;
15439 auto I = RewriteMap.find(LHSUnknown);
15440 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15441 RewriteMap[LHSUnknown] = SE.getUMaxExpr(
15442 SE.getConstant(ExactRegion.getUnsignedMin()),
15443 SE.getUMinExpr(RewrittenLHS,
15444 SE.getConstant(ExactRegion.getUnsignedMax())));
15445 ExprsToRewrite.push_back(LHSUnknown);
15446 return true;
15447 };
15448 if (MatchRangeCheckIdiom())
15449 return;
15450
15451 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15452 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15453 // the non-constant operand and in \p LHS the constant operand.
15454 auto IsMinMaxSCEVWithNonNegativeConstant =
15455 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15456 const SCEV *&RHS) {
15457 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15458 if (MinMax->getNumOperands() != 2)
15459 return false;
15460 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15461 if (C->getAPInt().isNegative())
15462 return false;
15463 SCTy = MinMax->getSCEVType();
15464 LHS = MinMax->getOperand(0);
15465 RHS = MinMax->getOperand(1);
15466 return true;
15467 }
15468 }
15469 return false;
15470 };
15471
15472 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15473 // constant, and returns their APInt in ExprVal and in DivisorVal.
15474 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15475 APInt &ExprVal, APInt &DivisorVal) {
15476 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15477 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15478 if (!ConstExpr || !ConstDivisor)
15479 return false;
15480 ExprVal = ConstExpr->getAPInt();
15481 DivisorVal = ConstDivisor->getAPInt();
15482 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15483 };
15484
15485 // Return a new SCEV that modifies \p Expr to the closest number divides by
15486 // \p Divisor and greater or equal than Expr.
15487 // For now, only handle constant Expr and Divisor.
15488 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15489 const SCEV *Divisor) {
15490 APInt ExprVal;
15491 APInt DivisorVal;
15492 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15493 return Expr;
15494 APInt Rem = ExprVal.urem(DivisorVal);
15495 if (!Rem.isZero())
15496 // return the SCEV: Expr + Divisor - Expr % Divisor
15497 return SE.getConstant(ExprVal + DivisorVal - Rem);
15498 return Expr;
15499 };
15500
15501 // Return a new SCEV that modifies \p Expr to the closest number divides by
15502 // \p Divisor and less or equal than Expr.
15503 // For now, only handle constant Expr and Divisor.
15504 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15505 const SCEV *Divisor) {
15506 APInt ExprVal;
15507 APInt DivisorVal;
15508 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15509 return Expr;
15510 APInt Rem = ExprVal.urem(DivisorVal);
15511 // return the SCEV: Expr - Expr % Divisor
15512 return SE.getConstant(ExprVal - Rem);
15513 };
15514
15515 // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15516 // recursively. This is done by aligning up/down the constant value to the
15517 // Divisor.
15518 std::function<const SCEV *(const SCEV *, const SCEV *)>
15519 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15520 const SCEV *Divisor) {
15521 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15522 SCEVTypes SCTy;
15523 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15524 MinMaxRHS))
15525 return MinMaxExpr;
15526 auto IsMin =
15527 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15528 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15529 "Expected non-negative operand!");
15530 auto *DivisibleExpr =
15531 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15532 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15534 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15535 return SE.getMinMaxExpr(SCTy, Ops);
15536 };
15537
15538 // If we have LHS == 0, check if LHS is computing a property of some unknown
15539 // SCEV %v which we can rewrite %v to express explicitly.
15540 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15541 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15542 // explicitly express that.
15543 const SCEV *URemLHS = nullptr;
15544 const SCEV *URemRHS = nullptr;
15545 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15546 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15547 auto I = RewriteMap.find(LHSUnknown);
15548 const SCEV *RewrittenLHS =
15549 I != RewriteMap.end() ? I->second : LHSUnknown;
15550 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15551 const auto *Multiple =
15552 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15553 RewriteMap[LHSUnknown] = Multiple;
15554 ExprsToRewrite.push_back(LHSUnknown);
15555 return;
15556 }
15557 }
15558 }
15559
15560 // Do not apply information for constants or if RHS contains an AddRec.
15561 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15562 return;
15563
15564 // If RHS is SCEVUnknown, make sure the information is applied to it.
15565 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15566 std::swap(LHS, RHS);
15568 }
15569
15570 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15571 // and \p FromRewritten are the same (i.e. there has been no rewrite
15572 // registered for \p From), then puts this value in the list of rewritten
15573 // expressions.
15574 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15575 const SCEV *To) {
15576 if (From == FromRewritten)
15577 ExprsToRewrite.push_back(From);
15578 RewriteMap[From] = To;
15579 };
15580
15581 // Checks whether \p S has already been rewritten. In that case returns the
15582 // existing rewrite because we want to chain further rewrites onto the
15583 // already rewritten value. Otherwise returns \p S.
15584 auto GetMaybeRewritten = [&](const SCEV *S) {
15585 auto I = RewriteMap.find(S);
15586 return I != RewriteMap.end() ? I->second : S;
15587 };
15588
15589 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15590 // \p Expr. The check is done recuresively on \p Expr, which is assumed to
15591 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15592 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15593 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15594 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15595 // DividesBy.
15596 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15597 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15598 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15599 if (Mul->getNumOperands() != 2)
15600 return false;
15601 auto *MulLHS = Mul->getOperand(0);
15602 auto *MulRHS = Mul->getOperand(1);
15603 if (isa<SCEVConstant>(MulLHS))
15604 std::swap(MulLHS, MulRHS);
15605 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15606 if (Div->getOperand(1) == MulRHS) {
15607 DividesBy = MulRHS;
15608 return true;
15609 }
15610 }
15611 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15612 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15613 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15614 return false;
15615 };
15616
15617 // Return true if Expr known to divide by \p DividesBy.
15618 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15619 [&](const SCEV *Expr, const SCEV *DividesBy) {
15620 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15621 return true;
15622 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15623 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15624 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15625 return false;
15626 };
15627
15628 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15629 const SCEV *DividesBy = nullptr;
15630 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15631 // Check that the whole expression is divided by DividesBy
15632 DividesBy =
15633 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15634
15635 // Collect rewrites for LHS and its transitive operands based on the
15636 // condition.
15637 // For min/max expressions, also apply the guard to its operands:
15638 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15639 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15640 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15641 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15642
15643 // We cannot express strict predicates in SCEV, so instead we replace them
15644 // with non-strict ones against plus or minus one of RHS depending on the
15645 // predicate.
15646 const SCEV *One = SE.getOne(RHS->getType());
15647 switch (Predicate) {
15648 case CmpInst::ICMP_ULT:
15649 if (RHS->getType()->isPointerTy())
15650 return;
15651 RHS = SE.getUMaxExpr(RHS, One);
15652 [[fallthrough]];
15653 case CmpInst::ICMP_SLT: {
15654 RHS = SE.getMinusSCEV(RHS, One);
15655 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15656 break;
15657 }
15658 case CmpInst::ICMP_UGT:
15659 case CmpInst::ICMP_SGT:
15660 RHS = SE.getAddExpr(RHS, One);
15661 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15662 break;
15663 case CmpInst::ICMP_ULE:
15664 case CmpInst::ICMP_SLE:
15665 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15666 break;
15667 case CmpInst::ICMP_UGE:
15668 case CmpInst::ICMP_SGE:
15669 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15670 break;
15671 default:
15672 break;
15673 }
15674
15675 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15677
15678 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15679 append_range(Worklist, S->operands());
15680 };
15681
15682 while (!Worklist.empty()) {
15683 const SCEV *From = Worklist.pop_back_val();
15684 if (isa<SCEVConstant>(From))
15685 continue;
15686 if (!Visited.insert(From).second)
15687 continue;
15688 const SCEV *FromRewritten = GetMaybeRewritten(From);
15689 const SCEV *To = nullptr;
15690
15691 switch (Predicate) {
15692 case CmpInst::ICMP_ULT:
15693 case CmpInst::ICMP_ULE:
15694 To = SE.getUMinExpr(FromRewritten, RHS);
15695 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15696 EnqueueOperands(UMax);
15697 break;
15698 case CmpInst::ICMP_SLT:
15699 case CmpInst::ICMP_SLE:
15700 To = SE.getSMinExpr(FromRewritten, RHS);
15701 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15702 EnqueueOperands(SMax);
15703 break;
15704 case CmpInst::ICMP_UGT:
15705 case CmpInst::ICMP_UGE:
15706 To = SE.getUMaxExpr(FromRewritten, RHS);
15707 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15708 EnqueueOperands(UMin);
15709 break;
15710 case CmpInst::ICMP_SGT:
15711 case CmpInst::ICMP_SGE:
15712 To = SE.getSMaxExpr(FromRewritten, RHS);
15713 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15714 EnqueueOperands(SMin);
15715 break;
15716 case CmpInst::ICMP_EQ:
15717 if (isa<SCEVConstant>(RHS))
15718 To = RHS;
15719 break;
15720 case CmpInst::ICMP_NE:
15721 if (match(RHS, m_scev_Zero())) {
15722 const SCEV *OneAlignedUp =
15723 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15724 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15725 }
15726 break;
15727 default:
15728 break;
15729 }
15730
15731 if (To)
15732 AddRewrite(From, FromRewritten, To);
15733 }
15734 };
15735
15737 // First, collect information from assumptions dominating the loop.
15738 for (auto &AssumeVH : SE.AC.assumptions()) {
15739 if (!AssumeVH)
15740 continue;
15741 auto *AssumeI = cast<CallInst>(AssumeVH);
15742 if (!SE.DT.dominates(AssumeI, Block))
15743 continue;
15744 Terms.emplace_back(AssumeI->getOperand(0), true);
15745 }
15746
15747 // Second, collect information from llvm.experimental.guards dominating the loop.
15748 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15749 SE.F.getParent(), Intrinsic::experimental_guard);
15750 if (GuardDecl)
15751 for (const auto *GU : GuardDecl->users())
15752 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15753 if (Guard->getFunction() == Block->getParent() &&
15754 SE.DT.dominates(Guard, Block))
15755 Terms.emplace_back(Guard->getArgOperand(0), true);
15756
15757 // Third, collect conditions from dominating branches. Starting at the loop
15758 // predecessor, climb up the predecessor chain, as long as there are
15759 // predecessors that can be found that have unique successors leading to the
15760 // original header.
15761 // TODO: share this logic with isLoopEntryGuardedByCond.
15762 unsigned NumCollectedConditions = 0;
15763 VisitedBlocks.insert(Block);
15764 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15765 for (; Pair.first;
15766 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15767 VisitedBlocks.insert(Pair.second);
15768 const BranchInst *LoopEntryPredicate =
15769 dyn_cast<BranchInst>(Pair.first->getTerminator());
15770 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15771 continue;
15772
15773 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15774 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15775 NumCollectedConditions++;
15776
15777 // If we are recursively collecting guards stop after 2
15778 // conditions to limit compile-time impact for now.
15779 if (Depth > 0 && NumCollectedConditions == 2)
15780 break;
15781 }
15782 // Finally, if we stopped climbing the predecessor chain because
15783 // there wasn't a unique one to continue, try to collect conditions
15784 // for PHINodes by recursively following all of their incoming
15785 // blocks and try to merge the found conditions to build a new one
15786 // for the Phi.
15787 if (Pair.second->hasNPredecessorsOrMore(2) &&
15790 for (auto &Phi : Pair.second->phis())
15791 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15792 }
15793
15794 // Now apply the information from the collected conditions to
15795 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15796 // earliest conditions is processed first. This ensures the SCEVs with the
15797 // shortest dependency chains are constructed first.
15798 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15799 SmallVector<Value *, 8> Worklist;
15801 Worklist.push_back(Term);
15802 while (!Worklist.empty()) {
15803 Value *Cond = Worklist.pop_back_val();
15804 if (!Visited.insert(Cond).second)
15805 continue;
15806
15807 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15808 auto Predicate =
15809 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15810 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15811 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15812 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15813 continue;
15814 }
15815
15816 Value *L, *R;
15817 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15818 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15819 Worklist.push_back(L);
15820 Worklist.push_back(R);
15821 }
15822 }
15823 }
15824
15825 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15826 // the replacement expressions are contained in the ranges of the replaced
15827 // expressions.
15828 Guards.PreserveNUW = true;
15829 Guards.PreserveNSW = true;
15830 for (const SCEV *Expr : ExprsToRewrite) {
15831 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15832 Guards.PreserveNUW &=
15833 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15834 Guards.PreserveNSW &=
15835 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15836 }
15837
15838 // Now that all rewrite information is collect, rewrite the collected
15839 // expressions with the information in the map. This applies information to
15840 // sub-expressions.
15841 if (ExprsToRewrite.size() > 1) {
15842 for (const SCEV *Expr : ExprsToRewrite) {
15843 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15844 Guards.RewriteMap.erase(Expr);
15845 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15846 }
15847 }
15848}
15849
15851 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15852 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15853 /// replacement is loop invariant in the loop of the AddRec.
15854 class SCEVLoopGuardRewriter
15855 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15857
15859
15860 public:
15861 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15862 const ScalarEvolution::LoopGuards &Guards)
15863 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15864 if (Guards.PreserveNUW)
15865 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15866 if (Guards.PreserveNSW)
15867 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15868 }
15869
15870 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15871
15872 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15873 auto I = Map.find(Expr);
15874 if (I == Map.end())
15875 return Expr;
15876 return I->second;
15877 }
15878
15879 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15880 auto I = Map.find(Expr);
15881 if (I == Map.end()) {
15882 // If we didn't find the extact ZExt expr in the map, check if there's
15883 // an entry for a smaller ZExt we can use instead.
15884 Type *Ty = Expr->getType();
15885 const SCEV *Op = Expr->getOperand(0);
15886 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15887 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15888 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15889 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15890 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15891 auto I = Map.find(NarrowExt);
15892 if (I != Map.end())
15893 return SE.getZeroExtendExpr(I->second, Ty);
15894 Bitwidth = Bitwidth / 2;
15895 }
15896
15898 Expr);
15899 }
15900 return I->second;
15901 }
15902
15903 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15904 auto I = Map.find(Expr);
15905 if (I == Map.end())
15907 Expr);
15908 return I->second;
15909 }
15910
15911 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15912 auto I = Map.find(Expr);
15913 if (I == Map.end())
15915 return I->second;
15916 }
15917
15918 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15919 auto I = Map.find(Expr);
15920 if (I == Map.end())
15922 return I->second;
15923 }
15924
15925 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15927 bool Changed = false;
15928 for (const auto *Op : Expr->operands()) {
15929 Operands.push_back(
15931 Changed |= Op != Operands.back();
15932 }
15933 // We are only replacing operands with equivalent values, so transfer the
15934 // flags from the original expression.
15935 return !Changed ? Expr
15936 : SE.getAddExpr(Operands,
15938 Expr->getNoWrapFlags(), FlagMask));
15939 }
15940
15941 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15943 bool Changed = false;
15944 for (const auto *Op : Expr->operands()) {
15945 Operands.push_back(
15947 Changed |= Op != Operands.back();
15948 }
15949 // We are only replacing operands with equivalent values, so transfer the
15950 // flags from the original expression.
15951 return !Changed ? Expr
15952 : SE.getMulExpr(Operands,
15954 Expr->getNoWrapFlags(), FlagMask));
15955 }
15956 };
15957
15958 if (RewriteMap.empty())
15959 return Expr;
15960
15961 SCEVLoopGuardRewriter Rewriter(SE, *this);
15962 return Rewriter.visit(Expr);
15963}
15964
15965const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15966 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
15967}
15968
15970 const LoopGuards &Guards) {
15971 return Guards.rewrite(Expr);
15972}
@ Poison
static const LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
basic Basic Alias true
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:622
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
Generic implementation of equivalence classes through the use Tarjan's efficient union-find algorithm...
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:557
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:39
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:261
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
Class for arbitrary precision integers.
Definition: APInt.h:78
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1945
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1547
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:986
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1520
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1392
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:612
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1007
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1492
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:910
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:206
APInt abs() const
Get the absolute value.
Definition: APInt.h:1773
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition: APInt.h:1201
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:466
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1640
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1468
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:219
unsigned countTrailingZeros() const
Definition: APInt.h:1626
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:356
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:827
APInt multiplicativeInverse() const
Definition: APInt.cpp:1248
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1150
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:959
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:873
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1221
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:49
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:292
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:310
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:231
iterator end() const
Definition: ArrayRef.h:157
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
iterator begin() const
Definition: ArrayRef.h:156
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:461
const Instruction & front() const
Definition: BasicBlock.h:484
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:481
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:220
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:240
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:370
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:946
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:673
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:702
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:703
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:697
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:696
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:700
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:698
@ ICMP_EQ
equal
Definition: InstrTypes.h:694
@ ICMP_NE
not equal
Definition: InstrTypes.h:695
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:701
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:699
bool isSigned() const
Definition: InstrTypes.h:928
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:825
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:940
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:869
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:787
bool isUnsigned() const
Definition: InstrTypes.h:934
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:924
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
Definition: CmpPredicate.h:22
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2632
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2293
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1267
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2638
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2626
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2279
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:208
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:873
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:157
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:148
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:880
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:42
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:709
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:851
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:754
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:878
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:617
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition: DenseMap.h:226
bool erase(const KeyT &Val)
Definition: DenseMap.h:321
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:176
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:152
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:147
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:327
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
const BasicBlock & getEntryBlock() const
Definition: Function.h:821
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:731
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:657
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:407
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:404
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
Definition: DerivedTypes.h:42
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
An instruction for reading from memory.
Definition: Instructions.h:176
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:39
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1073
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:32
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:42
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:77
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:110
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:104
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:686
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1878
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:121
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:458
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:132
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:181
size_type size() const
Definition: SmallSet.h:170
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void reserve(size_type N)
Definition: SmallVector.h:663
iterator erase(const_iterator CI)
Definition: SmallVector.h:737
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:805
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:567
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:596
TypeSize getSizeInBits() const
Definition: DataLayout.h:576
Class to represent struct types.
Definition: DerivedTypes.h:218
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:264
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:252
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:237
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:288
Use & Op()
Definition: User.h:192
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5144
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2217
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2222
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2227
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2785
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2232
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:771
@ Entry
Definition: COFF.h:844
@ Exit
Definition: COFF.h:845
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
Function * getDeclarationIfExists(Module *M, ID id, ArrayRef< Type * > Tys, FunctionType *FT=nullptr)
This version supports overloaded intrinsics.
Definition: Intrinsics.cpp:747
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:885
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:832
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bind_ty< const SCEVConstant > m_SCEVConstant(const SCEVConstant *&V)
bind_ty< const SCEV > m_SCEV(const SCEV *&V)
Match a SCEV, capturing it if we match.
SCEVBinaryExpr_match< SCEVAddExpr, Op0_t, Op1_t > m_scev_Add(const Op0_t &Op0, const Op1_t &Op1)
bool match(const SCEV *S, const Pattern &P)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:390
void stable_sort(R &&Range)
Definition: STLExtras.h:2037
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:256
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if its even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7301
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2115
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2107
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1162
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1152
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:256
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:261
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:2014
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:303
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:217
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1873
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1945
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1903
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
Incoming for lane maks phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:293
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:100
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:43
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:188
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:137
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:121
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:97
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.