ScalarEvolution.cpp
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
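// For example, the canonical induction variable of a loop like
// 'for (i = 0; i != n; ++i)' is represented by the add recurrence {0,+,1},
// while a value this analysis cannot reason about is wrapped in a SCEVUnknown.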
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
345 case scSequentialUMinExpr:
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
376 case scCouldNotCompute:
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403 case scSequentialUMinExpr:
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435 case scSequentialUMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
465
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
486ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
487 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
503const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
504 SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
511SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
512 const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
522SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
523 SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
529 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
536 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
543 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
548void SCEVUnknown::deleted() {
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
610 GlobalValue::isInternalLinkage(LT));
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
735 case scSequentialUMinExpr: {
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768 LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
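///
/// For example, with a Fold that adds APInts, an IsIdentity that tests for
/// zero and an IsAbsorber that never matches, the operand list (%a, 2, %b, 3)
/// folds to (5, %a, %b); had the folded constant been the identity 0, it
/// would simply have been dropped.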
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume, K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906// but this formula uses less than W + K bits. Also, the first formula requires
907 // a division step, whereas this formula only requires multiplies and shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
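 // For a concrete illustration, take K = 3 and W = 8: K! = 6 = 2^1 * 3, so
 // T = 1 and the odd part of K! is 3. The product It*(It-1)*(It-2) is computed
 // at W + T = 9 bits, divided by 2^T = 2, truncated back to 8 bits, and
 // multiplied by the multiplicative inverse of 3 modulo 2^8, which is 171
 // (3 * 171 = 513 = 2*256 + 1).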
922
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
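///
/// For example, the recurrence {0,+,1,+,1} evaluates at iteration It to
/// 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2 = It*(It+1)/2.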
980const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
981 ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987 const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008 unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operand during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1018 FoldingSetNodeID ID;
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1039 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
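 // For example, a ptrtoint of the pointer-typed expression (4 + %ptr) becomes
 // (4 + (ptrtoint %ptr)), leaving the SCEVUnknown leaf %ptr as the only
 // pointer-typed node.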
1070
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1077 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1097 SmallVector<const SCEV *, 2> Operands;
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1107 SmallVector<const SCEV *, 2> Operands;
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1131const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142 unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1150 FoldingSetNodeID ID;
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1163 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1167 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1171 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
1186 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1188 SmallVector<const SCEV *, 4> Operands;
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1194 isa<SCEVTruncateExpr>(S))
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked at the beginning that ID is not in the cache, it is
1206 // possible that ID was inserted into the cache during the recursion and the
1207 // modifications above. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 SmallVector<const SCEV *, 4> Operands;
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
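// For example, for an 8-bit recurrence whose step has a signed-range maximum
// of 4, the limit is 127 - 4 = 123 with predicate slt: any value x with
// x <s 123 satisfies x + step <= 122 + 4 = 126, which cannot signed-overflow.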
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1244 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1245 SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1249 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1250 SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
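// For example, for an 8-bit recurrence whose step has an unsigned-range
// maximum of 4, the limit is 0 - 4 = 252 with predicate ult: any value x with
// x <u 252 satisfies x + step <= 251 + 4 = 255, which cannot unsigned-wrap.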
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259 ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1264 return SE->getConstant(APInt::getMinValue(BitWidth) -
1265 SE->getUnsignedRangeMax(Step));
1266}
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1302 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note, that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1363 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1365 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1366 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
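//
// For example, with BitWidth = 8, C = 0x53 and four known trailing zeros in
// (x + y + ...), D is the low four bits of C, i.e. 0x03: C - D = 0x50 keeps
// four trailing zeros, and adding D (< 2^4) back to a multiple of 2^4 can
// never wrap the 8-bit width.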
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502 const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1522static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1523 const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1534 const ScalarEvolution::FoldID &ID, const SCEV *S,
1537 &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1569 if (!isa<SCEVZeroExtendExpr>(S))
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575 unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1591 FoldingSetNodeID ID;
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1606 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1610 ConstantRange CR = getUnsignedRange(X);
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1622 if (auto *AR = dyn_cast<SCEVAddRecExpr>(Op))
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1633 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the latter case, the analysis code will
1644 // cope with a conservative value, and it will take care to purge
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1660 SCEV::FlagAnyWrap, Depth + 1);
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1662 SCEV::FlagAnyWrap,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1672 SCEV::FlagAnyWrap, Depth + 1),
1673 SCEV::FlagAnyWrap, Depth + 1);
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1689 SCEV::FlagAnyWrap, Depth + 1),
1690 SCEV::FlagAnyWrap, Depth + 1);
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but
1719 // it's one of two possible causes for a change which was
1720 // reverted. Be conservative for the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1730 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731 getSignedRangeMin(Step));
1732 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1733 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
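// Illustrative example (not part of the original source): for the i8 addrec
// {7,+,4}, Step has two trailing zero bits, so D keeps the low two bits of C,
// i.e. D = 7 & 3 = 3, and zext({7,+,4}) is rewritten as 3 + zext({4,+,4});
// adding D back cannot wrap because the residual addrec only takes values
// whose low two bits are zero.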
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1759 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (matchURem(Op, LHS, RHS))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not unsign overflow then we can, by definition,
1791 // commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Often address arithmetics contain expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
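// Illustrative example (not part of the original source): in zext(5 + 4*X)
// the non-constant part 4*X has two trailing zero bits, so D = 5 & 3 = 1 and
// the expression becomes 1 + zext(4 + 4*X), a form in which two addresses that
// differ only in the small constant can be compared directly.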
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1811 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1814 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not unsign overflow then we can, by definition,
1824 // commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
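// Illustrative instance (not part of the original source): with K = 3, N = 8
// and M = 32 this turns zext(8 * (trunc X to i8)) to i32 into
// 8 * (zext(trunc X to i5) to i32).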
1843 if (SM->getNumOperands() == 2)
1844 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845 if (MulLHS->getAPInt().isPowerOf2())
1846 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848 MulLHS->getAPInt().logBase2();
1849 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850 return getMulExpr(
1851 getZeroExtendExpr(MulLHS, Ty),
1852 getZeroExtendExpr(
1853 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(MinMax))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
1888
1889const SCEV *
1890ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1903 if (!isa<SCEVSignExtendExpr>(S))
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not sign overflow then we can, by definition,
1962 // commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1983 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1986 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 // in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly casted to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2034 SCEV::FlagAnyWrap, Depth + 1);
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2036 SCEV::FlagAnyWrap,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1),
2047 SCEV::FlagAnyWrap, Depth + 1);
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2063 SCEV::FlagAnyWrap, Depth + 1),
2064 SCEV::FlagAnyWrap, Depth + 1);
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but
2090 // it's one of two possible causes for a change which was
2091 // reverted. Be conservative for the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2110 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(MinMax))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
2150
2151const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
2169const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217/// the given scale, update the given map. This is a helper function for
2218/// getAddRecExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228/// This will allow getAddRecExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239static bool
2240CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
2302
2303bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 // TODO: Support mul.
2342 if (BinOp == Instruction::Mul)
2343 return false;
2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2345 // TODO: Lift this limitation.
2346 if (!RHSC)
2347 return false;
2348 APInt C = RHSC->getAPInt();
2349 unsigned NumBits = C.getBitWidth();
2350 bool IsSub = (BinOp == Instruction::Sub);
2351 bool IsNegativeConst = (Signed && C.isNegative());
2352 // Compute the direction and magnitude by which we need to check overflow.
2353 bool OverflowDown = IsSub ^ IsNegativeConst;
2354 APInt Magnitude = C;
2355 if (IsNegativeConst) {
2356 if (C == APInt::getSignedMinValue(NumBits))
2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2358 // want to deal with that.
2359 return false;
2360 Magnitude = -C;
2361 }
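// Illustrative example (not part of the original source): for an unsigned
// LHS - 5, OverflowDown is true and Magnitude is 5, so the code below must
// prove 0 + 5 <= LHS; for a signed LHS + (-5) the constant is negated to
// Magnitude = 5 and the bound to prove becomes SINT_MIN + 5 <= LHS.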
2362
2363 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2364 if (OverflowDown) {
2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2367 : APInt::getMinValue(NumBits);
2368 APInt Limit = Min + Magnitude;
2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2370 } else {
2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2373 : APInt::getMaxValue(NumBits);
2374 APInt Limit = Max - Magnitude;
2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2376 }
2377}
2378
2379std::optional<SCEV::NoWrapFlags>
2380ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2381 const OverflowingBinaryOperator *OBO) {
2382 // It cannot be done any better.
2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2384 return std::nullopt;
2385
2386 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2387
2388 if (OBO->hasNoUnsignedWrap())
2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390 if (OBO->hasNoSignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2392
2393 bool Deduced = false;
2394
2395 if (OBO->getOpcode() != Instruction::Add &&
2396 OBO->getOpcode() != Instruction::Sub &&
2397 OBO->getOpcode() != Instruction::Mul)
2398 return std::nullopt;
2399
2400 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2401 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2402
2403 const Instruction *CtxI =
2404 isa<Instruction>(OBO) ? dyn_cast<Instruction>(OBO) : nullptr;
2405 if (!OBO->hasNoUnsignedWrap() &&
2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2407 /* Signed */ false, LHS, RHS, CtxI)) {
2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409 Deduced = true;
2410 }
2411
2412 if (!OBO->hasNoSignedWrap() &&
2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2414 /* Signed */ true, LHS, RHS, CtxI)) {
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416 Deduced = true;
2417 }
2418
2419 if (Deduced)
2420 return Flags;
2421 return std::nullopt;
2422}
2423
2424// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2425// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2426// can't-overflow flags for the operation if possible.
2427static SCEV::NoWrapFlags
2428StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2429 const ArrayRef<const SCEV *> Ops,
2430 SCEV::NoWrapFlags Flags) {
2431 using namespace std::placeholders;
2432
2433 using OBO = OverflowingBinaryOperator;
2434
2435 bool CanAnalyze =
2436 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2437 (void)CanAnalyze;
2438 assert(CanAnalyze && "don't call from other places!");
2439
2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2441 SCEV::NoWrapFlags SignOrUnsignWrap =
2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2443
2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2445 auto IsKnownNonNegative = [&](const SCEV *S) {
2446 return SE->isKnownNonNegative(S);
2447 };
2448
2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2450 Flags =
2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2452
2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 if (SignOrUnsignWrap != SignOrUnsignMask &&
2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2457 isa<SCEVConstant>(Ops[0])) {
2458
2459 auto Opcode = [&] {
2460 switch (Type) {
2461 case scAddExpr:
2462 return Instruction::Add;
2463 case scMulExpr:
2464 return Instruction::Mul;
2465 default:
2466 llvm_unreachable("Unexpected SCEV op.");
2467 }
2468 }();
2469
2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2471
2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2475 Opcode, C, OBO::NoSignedWrap);
2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2478 }
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoUnsignedWrap);
2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2486 }
2487 }
2488
2489 // <0,+,nonnegative><nw> is also nuw
2490 // TODO: Add corresponding nsw case
2491 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2494 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2495
2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2497 if (!ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Type == scMulExpr &&
2498 Ops.size() == 2) {
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2500 if (UDiv->getOperand(1) == Ops[1])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2503 if (UDiv->getOperand(1) == Ops[0])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 }
2506
2507 return Flags;
2508}
2509
2510bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2512}
2513
2514/// Get a canonical add expression, or something simpler if possible.
2515const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2516 SCEV::NoWrapFlags OrigFlags,
2517 unsigned Depth) {
2518 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2519 "only nuw or nsw allowed");
2520 assert(!Ops.empty() && "Cannot get empty add!");
2521 if (Ops.size() == 1) return Ops[0];
2522#ifndef NDEBUG
2523 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2524 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2525 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2526 "SCEVAddExpr operand types don't match!");
2527 unsigned NumPtrs = count_if(
2528 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2529 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2530#endif
2531
2532 const SCEV *Folded = constantFoldAndGroupOps(
2533 *this, LI, DT, Ops,
2534 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2535 [](const APInt &C) { return C.isZero(); }, // identity
2536 [](const APInt &C) { return false; }); // absorber
2537 if (Folded)
2538 return Folded;
2539
2540 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2541
2542 // Delay expensive flag strengthening until necessary.
2543 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2544 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2545 };
2546
2547 // Limit recursion calls depth.
2548 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2549 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2550
2551 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2552 // Don't strengthen flags if we have no new information.
2553 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2554 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2555 Add->setNoWrapFlags(ComputeFlags(Ops));
2556 return S;
2557 }
2558
2559 // Okay, check to see if the same value occurs in the operand list more than
2560 // once. If so, merge them together into a multiply expression. Since we
2561 // sorted the list, these values are required to be adjacent.
2562 Type *Ty = Ops[0]->getType();
2563 bool FoundMatch = false;
2564 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2565 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2566 // Scan ahead to count how many equal operands there are.
2567 unsigned Count = 2;
2568 while (i+Count != e && Ops[i+Count] == Ops[i])
2569 ++Count;
2570 // Merge the values into a multiply.
2571 const SCEV *Scale = getConstant(Ty, Count);
2572 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2573 if (Ops.size() == Count)
2574 return Mul;
2575 Ops[i] = Mul;
2576 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2577 --i; e -= Count - 1;
2578 FoundMatch = true;
2579 }
2580 if (FoundMatch)
2581 return getAddExpr(Ops, OrigFlags, Depth + 1);
2582
2583 // Check for truncates. If all the operands are truncated from the same
2584 // type, see if factoring out the truncate would permit the result to be
2585 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2586 // if the contents of the resulting outer trunc fold to something simple.
2587 auto FindTruncSrcType = [&]() -> Type * {
2588 // We're ultimately looking to fold an addrec of truncs and muls of only
2589 // constants and truncs, so if we find any other types of SCEV
2590 // as operands of the addrec then we bail and return nullptr here.
2591 // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2593 return T->getOperand()->getType();
2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2597 return T->getOperand()->getType();
2598 }
2599 return nullptr;
2600 };
2601 if (auto *SrcType = FindTruncSrcType()) {
2602 SmallVector<const SCEV *, 8> LargeOps;
2603 bool Ok = true;
2604 // Check all the operands to see if they can be represented in the
2605 // source type of the truncate.
2606 for (const SCEV *Op : Ops) {
2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2608 if (T->getOperand()->getType() != SrcType) {
2609 Ok = false;
2610 break;
2611 }
2612 LargeOps.push_back(T->getOperand());
2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2616 SmallVector<const SCEV *, 8> LargeMulOps;
2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2618 if (const SCEVTruncateExpr *T =
2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeMulOps.push_back(T->getOperand());
2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else {
2628 Ok = false;
2629 break;
2630 }
2631 }
2632 if (Ok)
2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2634 } else {
2635 Ok = false;
2636 break;
2637 }
2638 }
2639 if (Ok) {
2640 // Evaluate the expression in the larger type.
2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2642 // If it folds to something simple, use it. Otherwise, don't.
2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2644 return getTruncateExpr(Fold, Ty);
2645 }
2646 }
2647
2648 if (Ops.size() == 2) {
2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2650 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2651 // C1).
2652 const SCEV *A = Ops[0];
2653 const SCEV *B = Ops[1];
2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2655 auto *C = dyn_cast<SCEVConstant>(A);
2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2658 auto C2 = C->getAPInt();
2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2660
2661 APInt ConstAdd = C1 + C2;
2662 auto AddFlags = AddExpr->getNoWrapFlags();
2663 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2665 ConstAdd.ule(C1)) {
2666 PreservedFlags =
2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2668 }
2669
2670 // Adding a constant with the same sign and small magnitude is NSW, if the
2671 // original AddExpr was NSW.
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2673 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2674 ConstAdd.abs().ule(C1.abs())) {
2675 PreservedFlags =
2676 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2677 }
2678
2679 if (PreservedFlags != SCEV::FlagAnyWrap) {
2680 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2681 NewOps[0] = getConstant(ConstAdd);
2682 return getAddExpr(NewOps, PreservedFlags);
2683 }
2684 }
2685
2686 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2687 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2688 const SCEVAddExpr *InnerAdd;
2689 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2690 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2691 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2692 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2693 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2694 SCEV::FlagAnyWrap),
2695 SCEV::FlagNUW)) {
2696 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2697 }
2698 }
2699 }
2700
2701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
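// Illustrative example (not part of the original source): with X = 17 and
// Y = 5, (-1 * (17 urem 5)) + 17 = -2 + 17 = 15, which equals
// 5 * (17 /u 5) = 5 * 3, the canonical form produced below.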
2702 if (Ops.size() == 2) {
2703 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704 if (Mul && Mul->getNumOperands() == 2 &&
2705 Mul->getOperand(0)->isAllOnesValue()) {
2706 const SCEV *X;
2707 const SCEV *Y;
2708 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709 return getMulExpr(Y, getUDivExpr(X, Y));
2710 }
2711 }
2712 }
2713
2714 // Skip past any other cast SCEVs.
2715 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2716 ++Idx;
2717
2718 // If there are add operands they would be next.
2719 if (Idx < Ops.size()) {
2720 bool DeletedAdd = false;
2721 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2722 // common NUW flag for expression after inlining. Other flags cannot be
2723 // preserved, because they may depend on the original order of operations.
2724 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2725 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2726 if (Ops.size() > AddOpsInlineThreshold ||
2727 Add->getNumOperands() > AddOpsInlineThreshold)
2728 break;
2729 // If we have an add, expand the add operands onto the end of the operands
2730 // list.
2731 Ops.erase(Ops.begin()+Idx);
2732 append_range(Ops, Add->operands());
2733 DeletedAdd = true;
2734 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2735 }
2736
2737 // If we deleted at least one add, we added operands to the end of the list,
2738 // and they are not necessarily sorted. Recurse to resort and resimplify
2739 // any operands we just acquired.
2740 if (DeletedAdd)
2741 return getAddExpr(Ops, CommonFlags, Depth + 1);
2742 }
2743
2744 // Skip over the add expression until we get to a multiply.
2745 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2746 ++Idx;
2747
2748 // Check to see if there are any folding opportunities present with
2749 // operands multiplied by constant values.
2750 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 uint64_t BitWidth = getTypeSizeInBits(Ty);
2752 SmallDenseMap<const SCEV *, APInt, 16> M;
2753 SmallVector<const SCEV *, 8> NewOps;
2754 APInt AccumulatedConstant(BitWidth, 0);
2755 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2756 Ops, APInt(BitWidth, 1), *this)) {
2757 struct APIntCompare {
2758 bool operator()(const APInt &LHS, const APInt &RHS) const {
2759 return LHS.ult(RHS);
2760 }
2761 };
2762
2763 // Some interesting folding opportunity is present, so it's worthwhile to
2764 // re-generate the operands list. Group the operands by constant scale,
2765 // to avoid multiplying by the same constant scale multiple times.
2766 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2767 for (const SCEV *NewOp : NewOps)
2768 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2769 // Re-generate the operands list.
2770 Ops.clear();
2771 if (AccumulatedConstant != 0)
2772 Ops.push_back(getConstant(AccumulatedConstant));
2773 for (auto &MulOp : MulOpLists) {
2774 if (MulOp.first == 1) {
2775 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2776 } else if (MulOp.first != 0) {
2777 Ops.push_back(getMulExpr(
2778 getConstant(MulOp.first),
2779 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2780 SCEV::FlagAnyWrap, Depth + 1));
2781 }
2782 }
2783 if (Ops.empty())
2784 return getZero(Ty);
2785 if (Ops.size() == 1)
2786 return Ops[0];
2787 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2788 }
2789 }
2790
2791 // If we are adding something to a multiply expression, make sure the
2792 // something is not already an operand of the multiply. If so, merge it into
2793 // the multiply.
2794 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2795 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2796 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2797 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2798 if (isa<SCEVConstant>(MulOpSCEV))
2799 continue;
2800 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2801 if (MulOpSCEV == Ops[AddOp]) {
2802 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2803 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2804 if (Mul->getNumOperands() != 2) {
2805 // If the multiply has more than two operands, we must get the
2806 // Y*Z term.
2807 SmallVector<const SCEV *, 4> MulOps(
2808 Mul->operands().take_front(MulOp));
2809 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2810 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2811 }
2812 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2813 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2814 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2815 SCEV::FlagAnyWrap, Depth + 1);
2816 if (Ops.size() == 2) return OuterMul;
2817 if (AddOp < Idx) {
2818 Ops.erase(Ops.begin()+AddOp);
2819 Ops.erase(Ops.begin()+Idx-1);
2820 } else {
2821 Ops.erase(Ops.begin()+Idx);
2822 Ops.erase(Ops.begin()+AddOp-1);
2823 }
2824 Ops.push_back(OuterMul);
2825 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2826 }
2827
2828 // Check this multiply against other multiplies being added together.
2829 for (unsigned OtherMulIdx = Idx+1;
2830 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 ++OtherMulIdx) {
2832 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2833 // If MulOp occurs in OtherMul, we can fold the two multiplies
2834 // together.
2835 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2836 OMulOp != e; ++OMulOp)
2837 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2838 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2839 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2840 if (Mul->getNumOperands() != 2) {
2841 SmallVector<const SCEV *, 4> MulOps(
2842 Mul->operands().take_front(MulOp));
2843 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2844 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2845 }
2846 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2847 if (OtherMul->getNumOperands() != 2) {
2848 SmallVector<const SCEV *, 4> MulOps(
2849 OtherMul->operands().take_front(OMulOp));
2850 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2851 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2852 }
2853 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2854 const SCEV *InnerMulSum =
2855 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2856 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2857 SCEV::FlagAnyWrap, Depth + 1);
2858 if (Ops.size() == 2) return OuterMul;
2859 Ops.erase(Ops.begin()+Idx);
2860 Ops.erase(Ops.begin()+OtherMulIdx-1);
2861 Ops.push_back(OuterMul);
2862 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2863 }
2864 }
2865 }
2866 }
2867
2868 // If there are any add recurrences in the operands list, see if any other
2869 // added values are loop invariant. If so, we can fold them into the
2870 // recurrence.
2871 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2872 ++Idx;
2873
2874 // Scan over all recurrences, trying to fold loop invariants into them.
2875 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2876 // Scan all of the other operands to this add and add them to the vector if
2877 // they are loop invariant w.r.t. the recurrence.
2878 SmallVector<const SCEV *, 8> LIOps;
2879 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2880 const Loop *AddRecLoop = AddRec->getLoop();
2881 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2882 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2883 LIOps.push_back(Ops[i]);
2884 Ops.erase(Ops.begin()+i);
2885 --i; --e;
2886 }
2887
2888 // If we found some loop invariants, fold them into the recurrence.
2889 if (!LIOps.empty()) {
2890 // Compute nowrap flags for the addition of the loop-invariant ops and
2891 // the addrec. Temporarily push it as an operand for that purpose. These
2892 // flags are valid in the scope of the addrec only.
2893 LIOps.push_back(AddRec);
2894 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2895 LIOps.pop_back();
2896
2897 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2898 LIOps.push_back(AddRec->getStart());
2899
2900 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2901
2902 // It is not in general safe to propagate flags valid on an add within
2903 // the addrec scope to one outside it. We must prove that the inner
2904 // scope is guaranteed to execute if the outer one does to be able to
2905 // safely propagate. We know the program is undefined if poison is
2906 // produced on the inner scoped addrec. We also know that *for this use*
2907 // the outer scoped add can't overflow (because of the flags we just
2908 // computed for the inner scoped add) without the program being undefined.
2909 // Proving that entry to the outer scope necessitates entry to the inner
2910 // scope, thus proves the program undefined if the flags would be violated
2911 // in the outer scope.
2912 SCEV::NoWrapFlags AddFlags = Flags;
2913 if (AddFlags != SCEV::FlagAnyWrap) {
2914 auto *DefI = getDefiningScopeBound(LIOps);
2915 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2916 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2917 AddFlags = SCEV::FlagAnyWrap;
2918 }
2919 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2920
2921 // Build the new addrec. Propagate the NUW and NSW flags if both the
2922 // outer add and the inner addrec are guaranteed to have no overflow.
2923 // Always propagate NW.
2924 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2925 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2926
2927 // If all of the other operands were loop invariant, we are done.
2928 if (Ops.size() == 1) return NewRec;
2929
2930 // Otherwise, add the folded AddRec by the non-invariant parts.
2931 for (unsigned i = 0;; ++i)
2932 if (Ops[i] == AddRec) {
2933 Ops[i] = NewRec;
2934 break;
2935 }
2936 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2937 }
2938
2939 // Okay, if there weren't any loop invariants to be folded, check to see if
2940 // there are multiple AddRec's with the same loop induction variable being
2941 // added together. If so, we can fold them.
2942 for (unsigned OtherIdx = Idx+1;
2943 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2944 ++OtherIdx) {
2945 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2946 // so that the 1st found AddRecExpr is dominated by all others.
2947 assert(DT.dominates(
2948 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2949 AddRec->getLoop()->getHeader()) &&
2950 "AddRecExprs are not sorted in reverse dominance order?");
2951 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2952 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2953 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2954 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 ++OtherIdx) {
2956 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2957 if (OtherAddRec->getLoop() == AddRecLoop) {
2958 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2959 i != e; ++i) {
2960 if (i >= AddRecOps.size()) {
2961 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2962 break;
2963 }
2964 SmallVector<const SCEV *, 2> TwoOps = {
2965 AddRecOps[i], OtherAddRec->getOperand(i)};
2966 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2969 }
2970 }
2971 // Step size has changed, so we cannot guarantee no self-wraparound.
2972 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2973 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2974 }
2975 }
2976
2977 // Otherwise couldn't fold anything into this recurrence. Move onto the
2978 // next one.
2979 }
2980
2981 // Okay, it looks like we really DO need an add expr. Check to see if we
2982 // already have one, otherwise create a new one.
2983 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2984}
2985
2986const SCEV *
2987ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2988 SCEV::NoWrapFlags Flags) {
2989 FoldingSetNodeID ID;
2990 ID.AddInteger(scAddExpr);
2991 for (const SCEV *Op : Ops)
2992 ID.AddPointer(Op);
2993 void *IP = nullptr;
2994 SCEVAddExpr *S =
2995 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2996 if (!S) {
2997 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2998 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2999 S = new (SCEVAllocator)
3000 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3001 UniqueSCEVs.InsertNode(S, IP);
3002 registerUser(S, Ops);
3003 }
3004 S->setNoWrapFlags(Flags);
3005 return S;
3006}
3007
3008const SCEV *
3009ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3010 const Loop *L, SCEV::NoWrapFlags Flags) {
3011 FoldingSetNodeID ID;
3012 ID.AddInteger(scAddRecExpr);
3013 for (const SCEV *Op : Ops)
3014 ID.AddPointer(Op);
3015 ID.AddPointer(L);
3016 void *IP = nullptr;
3017 SCEVAddRecExpr *S =
3018 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3019 if (!S) {
3020 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3021 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3022 S = new (SCEVAllocator)
3023 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3024 UniqueSCEVs.InsertNode(S, IP);
3025 LoopUsers[L].push_back(S);
3026 registerUser(S, Ops);
3027 }
3028 setNoWrapFlags(S, Flags);
3029 return S;
3030}
3031
3032const SCEV *
3033ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3034 SCEV::NoWrapFlags Flags) {
3035 FoldingSetNodeID ID;
3036 ID.AddInteger(scMulExpr);
3037 for (const SCEV *Op : Ops)
3038 ID.AddPointer(Op);
3039 void *IP = nullptr;
3040 SCEVMulExpr *S =
3041 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3042 if (!S) {
3043 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3044 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3045 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3046 O, Ops.size());
3047 UniqueSCEVs.InsertNode(S, IP);
3048 registerUser(S, Ops);
3049 }
3050 S->setNoWrapFlags(Flags);
3051 return S;
3052}
3053
3054static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3055 uint64_t k = i*j;
3056 if (j > 1 && k / j != i) Overflow = true;
3057 return k;
3058}
3059
3060/// Compute the result of "n choose k", the binomial coefficient. If an
3061/// intermediate computation overflows, Overflow will be set and the return will
3062/// be garbage. Overflow is not cleared on absence of overflow.
3063static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3064 // We use the multiplicative formula:
3065 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3066 // At each iteration, we take the n-th term of the numerator and divide by the
3067 // (k-n)th term of the denominator. This division will always produce an
3068 // integral result, and helps reduce the chance of overflow in the
3069 // intermediate computations. However, we can still overflow even when the
3070 // final result would fit.
3071
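// Illustrative trace (not part of the original source): Choose(6, 2) computes
// r = 1*6 = 6, r /= 1 -> 6, then r = 6*5 = 30, r /= 2 -> 15, the expected
// binomial coefficient C(6,2) = 15.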
3072 if (n == 0 || n == k) return 1;
3073 if (k > n) return 0;
3074
3075 if (k > n/2)
3076 k = n-k;
3077
3078 uint64_t r = 1;
3079 for (uint64_t i = 1; i <= k; ++i) {
3080 r = umul_ov(r, n-(i-1), Overflow);
3081 r /= i;
3082 }
3083 return r;
3084}
3085
3086/// Determine if any of the operands in this SCEV are a constant or if
3087/// any of the add or multiply expressions in this SCEV contain a constant.
3088static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3089 struct FindConstantInAddMulChain {
3090 bool FoundConstant = false;
3091
3092 bool follow(const SCEV *S) {
3093 FoundConstant |= isa<SCEVConstant>(S);
3094 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3095 }
3096
3097 bool isDone() const {
3098 return FoundConstant;
3099 }
3100 };
3101
3102 FindConstantInAddMulChain F;
3103 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3104 ST.visitAll(StartExpr);
3105 return F.FoundConstant;
3106}
3107
3108/// Get a canonical multiply expression, or something simpler if possible.
3109const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3110 SCEV::NoWrapFlags OrigFlags,
3111 unsigned Depth) {
3112 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3113 "only nuw or nsw allowed");
3114 assert(!Ops.empty() && "Cannot get empty mul!");
3115 if (Ops.size() == 1) return Ops[0];
3116#ifndef NDEBUG
3117 Type *ETy = Ops[0]->getType();
3118 assert(!ETy->isPointerTy());
3119 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3120 assert(Ops[i]->getType() == ETy &&
3121 "SCEVMulExpr operand types don't match!");
3122#endif
3123
3124 const SCEV *Folded = constantFoldAndGroupOps(
3125 *this, LI, DT, Ops,
3126 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3127 [](const APInt &C) { return C.isOne(); }, // identity
3128 [](const APInt &C) { return C.isZero(); }); // absorber
3129 if (Folded)
3130 return Folded;
3131
3132 // Delay expensive flag strengthening until necessary.
3133 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3134 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3135 };
3136
3137 // Limit recursion calls depth.
3138 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3139 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3140
3141 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3142 // Don't strengthen flags if we have no new information.
3143 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3144 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3145 Mul->setNoWrapFlags(ComputeFlags(Ops));
3146 return S;
3147 }
3148
3149 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3150 if (Ops.size() == 2) {
3151 // C1*(C2+V) -> C1*C2 + C1*V
3152 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3153 // If any of Add's ops are Adds or Muls with a constant, apply this
3154 // transformation as well.
3155 //
3156 // TODO: There are some cases where this transformation is not
3157 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3158 // this transformation should be narrowed down.
3159 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3160 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3161 SCEV::FlagAnyWrap, Depth + 1);
3162 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3163 SCEV::FlagAnyWrap, Depth + 1);
3164 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3165 }
3166
3167 if (Ops[0]->isAllOnesValue()) {
3168 // If we have a mul by -1 of an add, try distributing the -1 among the
3169 // add operands.
3170 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3171 SmallVector<const SCEV *, 4> NewOps;
3172 bool AnyFolded = false;
3173 for (const SCEV *AddOp : Add->operands()) {
3174 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3175 Depth + 1);
3176 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3177 NewOps.push_back(Mul);
3178 }
3179 if (AnyFolded)
3180 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3181 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3182 // Negation preserves a recurrence's no self-wrap property.
3183 SmallVector<const SCEV *, 4> Operands;
3184 for (const SCEV *AddRecOp : AddRec->operands())
3185 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3186 Depth + 1));
3187 // Let M be the minimum representable signed value. AddRec with nsw
3188 // multiplied by -1 can have signed overflow if and only if it takes a
3189 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3190 // maximum signed value. In all other cases signed overflow is
3191 // impossible.
3192 auto FlagsMask = SCEV::FlagNW;
3193 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3194 auto MinInt =
3195 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3196 if (getSignedRangeMin(AddRec) != MinInt)
3197 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3198 }
3199 return getAddRecExpr(Operands, AddRec->getLoop(),
3200 AddRec->getNoWrapFlags(FlagsMask));
3201 }
3202 }
3203
3204 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3205 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3206 const SCEVAddExpr *InnerAdd;
3207 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3208 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3209 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3210 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3211 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3212 SCEV::FlagAnyWrap),
3213 SCEV::FlagNUW)) {
3214 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3215 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3216 };
3217 }
3218
3219 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3220 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3221 // of C1, fold to (D /u (C2 /u C1)).
3222 const SCEV *D;
3223 APInt C1V = LHSC->getAPInt();
3224 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3225 // as -1 * 1, as it won't enable additional folds.
3226 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3227 C1V = C1V.abs();
3228 const SCEVConstant *C2;
3229 if (C1V.isPowerOf2() &&
3230 match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3231 C2->getAPInt().isPowerOf2() &&
3232 C1V.logBase2() <= getMinTrailingZeros(D)) {
3233 const SCEV *NewMul = nullptr;
3234 if (C1V.uge(C2->getAPInt())) {
3235 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3236 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3237 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3238 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3239 }
3240 if (NewMul)
3241 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3242 }
3243 }
3244 }
3245
3246 // Skip over the add expression until we get to a multiply.
3247 unsigned Idx = 0;
3248 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3249 ++Idx;
3250
3251 // If there are mul operands inline them all into this expression.
3252 if (Idx < Ops.size()) {
3253 bool DeletedMul = false;
3254 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3255 if (Ops.size() > MulOpsInlineThreshold)
3256 break;
3257 // If we have a mul, expand the mul operands onto the end of the
3258 // operands list.
3259 Ops.erase(Ops.begin()+Idx);
3260 append_range(Ops, Mul->operands());
3261 DeletedMul = true;
3262 }
3263
3264 // If we deleted at least one mul, we added operands to the end of the
3265 // list, and they are not necessarily sorted. Recurse to resort and
3266 // resimplify any operands we just acquired.
3267 if (DeletedMul)
3268 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3269 }
3270
3271 // If there are any add recurrences in the operands list, see if any other
3272 // added values are loop invariant. If so, we can fold them into the
3273 // recurrence.
3274 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3275 ++Idx;
3276
3277 // Scan over all recurrences, trying to fold loop invariants into them.
3278 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3279 // Scan all of the other operands to this mul and add them to the vector
3280 // if they are loop invariant w.r.t. the recurrence.
3281 SmallVector<const SCEV *, 8> LIOps;
3282 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3283 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3284 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3285 LIOps.push_back(Ops[i]);
3286 Ops.erase(Ops.begin()+i);
3287 --i; --e;
3288 }
3289
3290 // If we found some loop invariants, fold them into the recurrence.
3291 if (!LIOps.empty()) {
3292 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3293 SmallVector<const SCEV *, 4> NewOps;
3294 NewOps.reserve(AddRec->getNumOperands());
3295 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3296
3297 // If both the mul and addrec are nuw, we can preserve nuw.
3298 // If both the mul and addrec are nsw, we can only preserve nsw if either
3299 // a) they are also nuw, or
3300 // b) all multiplications of addrec operands with scale are nsw.
3301 SCEV::NoWrapFlags Flags =
3302 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3303
3304 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3305 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3306 SCEV::FlagAnyWrap, Depth + 1));
3307
3308 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3309 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3310 Instruction::Mul, getSignedRange(Scale),
3311 OverflowingBinaryOperator::NoSignedWrap);
3312 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3313 Flags = clearFlags(Flags, SCEV::FlagNSW);
3314 }
3315 }
3316
3317 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3318
3319 // If all of the other operands were loop invariant, we are done.
3320 if (Ops.size() == 1) return NewRec;
3321
3322 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3323 for (unsigned i = 0;; ++i)
3324 if (Ops[i] == AddRec) {
3325 Ops[i] = NewRec;
3326 break;
3327 }
3328 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3329 }
3330
3331 // Okay, if there weren't any loop invariants to be folded, check to see
3332 // if there are multiple AddRec's with the same loop induction variable
3333 // being multiplied together. If so, we can fold them.
3334
3335 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3336 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3337 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3338 // ]]],+,...up to x=2n}.
3339 // Note that the arguments to choose() are always integers with values
3340 // known at compile time, never SCEV objects.
3341 //
3342 // The implementation avoids pointless extra computations when the two
3343 // addrec's are of different length (mathematically, it's equivalent to
3344 // an infinite stream of zeros on the right).
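// Illustrative example (editorial sketch, not from the original source): for
// two affine recurrences in the same loop the general formula above reduces
// to
//   {a,+,b}<L> * {c,+,d}<L>  =  {a*c, +, a*d + b*c + b*d, +, 2*b*d}<L>
// because (a + b*i) * (c + d*i) = a*c + (a*d + b*c)*i + b*d*i^2, and i^2 is
// the chain of recurrences {0,+,1,+,2}<L>.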
3345 bool OpsModified = false;
3346 for (unsigned OtherIdx = Idx+1;
3347 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3348 ++OtherIdx) {
3349 const SCEVAddRecExpr *OtherAddRec =
3350 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3351 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3352 continue;
3353
3354 // Limit max number of arguments to avoid creation of unreasonably big
3355 // SCEVAddRecs with very complex operands.
3356 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3357 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3358 continue;
3359
3360 bool Overflow = false;
3361 Type *Ty = AddRec->getType();
3362 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3363 SmallVector<const SCEV *, 7> AddRecOps;
3364 for (int x = 0, xe = AddRec->getNumOperands() +
3365 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3366 SmallVector<const SCEV *, 7> SumOps;
3367 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3368 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3369 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3370 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3371 z < ze && !Overflow; ++z) {
3372 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3373 uint64_t Coeff;
3374 if (LargerThan64Bits)
3375 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3376 else
3377 Coeff = Coeff1*Coeff2;
3378 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3379 const SCEV *Term1 = AddRec->getOperand(y-z);
3380 const SCEV *Term2 = OtherAddRec->getOperand(z);
3381 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3382 SCEV::FlagAnyWrap, Depth + 1));
3383 }
3384 }
3385 if (SumOps.empty())
3386 SumOps.push_back(getZero(Ty));
3387 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3388 }
3389 if (!Overflow) {
3390 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3391 SCEV::FlagAnyWrap);
3392 if (Ops.size() == 2) return NewAddRec;
3393 Ops[Idx] = NewAddRec;
3394 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3395 OpsModified = true;
3396 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3397 if (!AddRec)
3398 break;
3399 }
3400 }
3401 if (OpsModified)
3402 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3403
3404 // Otherwise couldn't fold anything into this recurrence. Move onto the
3405 // next one.
3406 }
3407
3408 // Okay, it looks like we really DO need a mul expr. Check to see if we
3409 // already have one, otherwise create a new one.
3410 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3411}
3412
3413/// Represents an unsigned remainder expression based on unsigned division.
3414const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3415 const SCEV *RHS) {
3416 assert(getEffectiveSCEVType(LHS->getType()) ==
3417 getEffectiveSCEVType(RHS->getType()) &&
3418 "SCEVURemExpr operand types don't match!");
3419
3420 // Short-circuit easy cases
3421 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3422 // If constant is one, the result is trivial
3423 if (RHSC->getValue()->isOne())
3424 return getZero(LHS->getType()); // X urem 1 --> 0
3425
3426 // If constant is a power of two, fold into a zext(trunc(LHS)).
3427 if (RHSC->getAPInt().isPowerOf2()) {
3428 Type *FullTy = LHS->getType();
3429 Type *TruncTy =
3430 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3431 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3432 }
3433 }
3434
3435 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3436 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3437 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3438 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3439}
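// Illustrative usage (editorial sketch, not from the original source; SE and
// X are hypothetical names for an existing ScalarEvolution and an i32 SCEV):
//   SE.getURemExpr(X, SE.getConstant(X->getType(), 1));  // --> 0
//   SE.getURemExpr(X, SE.getConstant(X->getType(), 8));
//     // --> (zext i3 (trunc i32 X to i3) to i32), since 8 is a power of two
//   // any other divisor falls back to X - ((X /u Y) *<nuw> Y)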
3440
3441/// Get a canonical unsigned division expression, or something simpler if
3442/// possible.
3443const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3444 const SCEV *RHS) {
3445 assert(!LHS->getType()->isPointerTy() &&
3446 "SCEVUDivExpr operand can't be pointer!");
3447 assert(LHS->getType() == RHS->getType() &&
3448 "SCEVUDivExpr operand types don't match!");
3449
3450 FoldingSetNodeID ID;
3451 ID.AddInteger(scUDivExpr);
3452 ID.AddPointer(LHS);
3453 ID.AddPointer(RHS);
3454 void *IP = nullptr;
3455 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3456 return S;
3457
3458 // 0 udiv Y == 0
3459 if (match(LHS, m_scev_Zero()))
3460 return LHS;
3461
3462 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3463 if (RHSC->getValue()->isOne())
3464 return LHS; // X udiv 1 --> x
3465 // If the denominator is zero, the result of the udiv is undefined. Don't
3466 // try to analyze it, because the resolution chosen here may differ from
3467 // the resolution chosen in other parts of the compiler.
3468 if (!RHSC->getValue()->isZero()) {
3469 // Determine if the division can be folded into the operands of
3470 // its operands.
3471 // TODO: Generalize this to non-constants by using known-bits information.
3472 Type *Ty = LHS->getType();
3473 unsigned LZ = RHSC->getAPInt().countl_zero();
3474 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3475 // For non-power-of-two values, effectively round the value up to the
3476 // nearest power of two.
3477 if (!RHSC->getAPInt().isPowerOf2())
3478 ++MaxShiftAmt;
3479 IntegerType *ExtTy =
3480 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3481 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3482 if (const SCEVConstant *Step =
3483 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3484 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3485 const APInt &StepInt = Step->getAPInt();
3486 const APInt &DivInt = RHSC->getAPInt();
3487 if (!StepInt.urem(DivInt) &&
3488 getZeroExtendExpr(AR, ExtTy) ==
3489 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3490 getZeroExtendExpr(Step, ExtTy),
3491 AR->getLoop(), SCEV::FlagAnyWrap)) {
3492 SmallVector<const SCEV *, 4> Operands;
3493 for (const SCEV *Op : AR->operands())
3494 Operands.push_back(getUDivExpr(Op, RHS));
3495 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3496 }
3497 // Get a canonical UDivExpr for a recurrence.
3498 // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3499 // We can currently only fold X%N if X is constant.
3500 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3501 if (StartC && !DivInt.urem(StepInt) &&
3502 getZeroExtendExpr(AR, ExtTy) ==
3503 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3504 getZeroExtendExpr(Step, ExtTy),
3505 AR->getLoop(), SCEV::FlagAnyWrap)) {
3506 const APInt &StartInt = StartC->getAPInt();
3507 const APInt &StartRem = StartInt.urem(StepInt);
3508 if (StartRem != 0) {
3509 const SCEV *NewLHS =
3510 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3511 AR->getLoop(), SCEV::FlagNW);
3512 if (LHS != NewLHS) {
3513 LHS = NewLHS;
3514
3515 // Reset the ID to include the new LHS, and check if it is
3516 // already cached.
3517 ID.clear();
3518 ID.AddInteger(scUDivExpr);
3519 ID.AddPointer(LHS);
3520 ID.AddPointer(RHS);
3521 IP = nullptr;
3522 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3523 return S;
3524 }
3525 }
3526 }
3527 }
3528 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3529 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3530 SmallVector<const SCEV *, 4> Operands;
3531 for (const SCEV *Op : M->operands())
3532 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3533 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3534 // Find an operand that's safely divisible.
3535 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3536 const SCEV *Op = M->getOperand(i);
3537 const SCEV *Div = getUDivExpr(Op, RHSC);
3538 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3539 Operands = SmallVector<const SCEV *, 4>(M->operands());
3540 Operands[i] = Div;
3541 return getMulExpr(Operands);
3542 }
3543 }
3544 }
3545
3546 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3547 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3548 if (auto *DivisorConstant =
3549 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3550 bool Overflow = false;
3551 APInt NewRHS =
3552 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3553 if (Overflow) {
3554 return getConstant(RHSC->getType(), 0, false);
3555 }
3556 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3557 }
3558 }
3559
3560 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3561 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3562 SmallVector<const SCEV *, 4> Operands;
3563 for (const SCEV *Op : A->operands())
3564 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3565 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3566 Operands.clear();
3567 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3568 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3569 if (isa<SCEVUDivExpr>(Op) ||
3570 getMulExpr(Op, RHS) != A->getOperand(i))
3571 break;
3572 Operands.push_back(Op);
3573 }
3574 if (Operands.size() == A->getNumOperands())
3575 return getAddExpr(Operands);
3576 }
3577 }
3578
3579 // Fold if both operands are constant.
3580 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3581 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3582 }
3583 }
3584
3585 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3586 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3587 AE && AE->getNumOperands() == 2) {
3588 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3589 const APInt &NegC = VC->getAPInt();
3590 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3591 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3592 if (MME && MME->getNumOperands() == 2 &&
3593 isa<SCEVConstant>(MME->getOperand(0)) &&
3594 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3595 MME->getOperand(1) == RHS)
3596 return getZero(LHS->getType());
3597 }
3598 }
3599 }
3600
3601 // TODO: Generalize to handle any common factors.
3602 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3603 const SCEV *NewLHS, *NewRHS;
3604 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3605 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3606 return getUDivExpr(NewLHS, NewRHS);
3607
3608 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3609 // changes). Make sure we get a new one.
3610 IP = nullptr;
3611 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3612 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3613 LHS, RHS);
3614 UniqueSCEVs.InsertNode(S, IP);
3615 registerUser(S, {LHS, RHS});
3616 return S;
3617}
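// Illustrative example (editorial sketch, not from the original source): the
// {X,+,N}/C fold above applies when the start and step are divisible by the
// constant divisor and the zero-extension equality check succeeds, e.g.
//   {0,+,4}<%loop> /u 2  -->  {0,+,2}<%loop>
// while a divisor that does not evenly divide the step is left as a plain
// SCEVUDivExpr.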
3618
3619APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3620 APInt A = C1->getAPInt().abs();
3621 APInt B = C2->getAPInt().abs();
3622 uint32_t ABW = A.getBitWidth();
3623 uint32_t BBW = B.getBitWidth();
3624
3625 if (ABW > BBW)
3626 B = B.zext(ABW);
3627 else if (ABW < BBW)
3628 A = A.zext(BBW);
3629
3630 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3631}
3632
3633/// Get a canonical unsigned division expression, or something simpler if
3634/// possible. There is no representation for an exact udiv in SCEV IR, but we
3635/// can attempt to remove factors from the LHS and RHS. We can't do this when
3636/// it's not exact because the udiv may be clearing bits.
3637const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3638 const SCEV *RHS) {
3639 // TODO: we could try to find factors in all sorts of things, but for now we
3640 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3641 // end of this file for inspiration.
3642
3643 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3644 if (!Mul || !Mul->hasNoUnsignedWrap())
3645 return getUDivExpr(LHS, RHS);
3646
3647 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3648 // If the mulexpr multiplies by a constant, then that constant must be the
3649 // first element of the mulexpr.
3650 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3651 if (LHSCst == RHSCst) {
3652 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3653 return getMulExpr(Operands);
3654 }
3655
3656 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3657 // that there's a factor provided by one of the other terms. We need to
3658 // check.
3659 APInt Factor = gcd(LHSCst, RHSCst);
3660 if (!Factor.isIntN(1)) {
3661 LHSCst =
3662 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3663 RHSCst =
3664 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3665 SmallVector<const SCEV *, 2> Operands;
3666 Operands.push_back(LHSCst);
3667 append_range(Operands, Mul->operands().drop_front());
3668 LHS = getMulExpr(Operands);
3669 RHS = RHSCst;
3670 Mul = dyn_cast<SCEVMulExpr>(LHS);
3671 if (!Mul)
3672 return getUDivExactExpr(LHS, RHS);
3673 }
3674 }
3675 }
3676
3677 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3678 if (Mul->getOperand(i) == RHS) {
3679 SmallVector<const SCEV *, 2> Operands;
3680 append_range(Operands, Mul->operands().take_front(i));
3681 append_range(Operands, Mul->operands().drop_front(i + 1));
3682 return getMulExpr(Operands);
3683 }
3684 }
3685
3686 return getUDivExpr(LHS, RHS);
3687}
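// Illustrative example (editorial, hedged sketch of the factor cancelling
// above, not from the original source):
//   udiv exact (mul nuw %a, %b), %b  -->  %a
//   udiv exact (mul nuw 6, %x), 4    -->  (3 * %x) /u 2   via gcd(6, 4) == 2
// When LHS is not a no-unsigned-wrap multiply, the expression is simply
// rebuilt as an ordinary SCEVUDivExpr.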
3688
3689/// Get an add recurrence expression for the specified loop. Simplify the
3690/// expression as much as possible.
3691const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3692 const Loop *L,
3693 SCEV::NoWrapFlags Flags) {
3694 SmallVector<const SCEV *, 4> Operands;
3695 Operands.push_back(Start);
3696 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3697 if (StepChrec->getLoop() == L) {
3698 append_range(Operands, StepChrec->operands());
3699 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3700 }
3701
3702 Operands.push_back(Step);
3703 return getAddRecExpr(Operands, L, Flags);
3704}
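// Illustrative example (editorial, not from the original source): when the
// step is itself a recurrence in the same loop, its operands are appended so
// the result is a single higher-order recurrence, e.g.
//   getAddRecExpr(X, {Y,+,Z}<L>, L)  -->  {X,+,Y,+,Z}<L>
// keeping only the no-self-wrap (NW) portion of the requested flags.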
3705
3706/// Get an add recurrence expression for the specified loop. Simplify the
3707/// expression as much as possible.
3708const SCEV *
3709ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3710 const Loop *L, SCEV::NoWrapFlags Flags) {
3711 if (Operands.size() == 1) return Operands[0];
3712#ifndef NDEBUG
3713 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3714 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3715 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3716 "SCEVAddRecExpr operand types don't match!");
3717 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3718 }
3719 for (const SCEV *Op : Operands)
3721 "SCEVAddRecExpr operand is not available at loop entry!");
3722#endif
3723
3724 if (Operands.back()->isZero()) {
3725 Operands.pop_back();
3726 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3727 }
3728
3729 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3730 // use that information to infer NUW and NSW flags. However, computing a
3731 // BE count requires calling getAddRecExpr, so we may not yet have a
3732 // meaningful BE count at this point (and if we don't, we'd be stuck
3733 // with a SCEVCouldNotCompute as the cached BE count).
3734
3735 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3736
3737 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3738 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3739 const Loop *NestedLoop = NestedAR->getLoop();
3740 if (L->contains(NestedLoop)
3741 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3742 : (!NestedLoop->contains(L) &&
3743 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3744 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3745 Operands[0] = NestedAR->getStart();
3746 // AddRecs require their operands be loop-invariant with respect to their
3747 // loops. Don't perform this transformation if it would break this
3748 // requirement.
3749 bool AllInvariant = all_of(
3750 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3751
3752 if (AllInvariant) {
3753 // Create a recurrence for the outer loop with the same step size.
3754 //
3755 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3756 // inner recurrence has the same property.
3757 SCEV::NoWrapFlags OuterFlags =
3758 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3759
3760 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3761 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3762 return isLoopInvariant(Op, NestedLoop);
3763 });
3764
3765 if (AllInvariant) {
3766 // Ok, both add recurrences are valid after the transformation.
3767 //
3768 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3769 // the outer recurrence has the same property.
3770 SCEV::NoWrapFlags InnerFlags =
3771 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3772 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3773 }
3774 }
3775 // Reset Operands to its original state.
3776 Operands[0] = NestedAR;
3777 }
3778 }
3779
3780 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3781 // already have one, otherwise create a new one.
3782 return getOrCreateAddRecExpr(Operands, L, Flags);
3783}
3784
3785const SCEV *
3786ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3787 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3788 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3789 // getSCEV(Base)->getType() has the same address space as Base->getType()
3790 // because SCEV::getType() preserves the address space.
3791 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3792 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3793 if (NW != GEPNoWrapFlags::none()) {
3794 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3795 // but to do that, we have to ensure that said flag is valid in the entire
3796 // defined scope of the SCEV.
3797 // TODO: non-instructions have global scope. We might be able to prove
3798 // some global scope cases
3799 auto *GEPI = dyn_cast<Instruction>(GEP);
3800 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3801 NW = GEPNoWrapFlags::none();
3802 }
3803
3804 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3805 if (NW.hasNoUnsignedSignedWrap())
3806 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3807 if (NW.hasNoUnsignedWrap())
3808 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3809
3810 Type *CurTy = GEP->getType();
3811 bool FirstIter = true;
3812 SmallVector<const SCEV *, 4> Offsets;
3813 for (const SCEV *IndexExpr : IndexExprs) {
3814 // Compute the (potentially symbolic) offset in bytes for this index.
3815 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3816 // For a struct, add the member offset.
3817 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3818 unsigned FieldNo = Index->getZExtValue();
3819 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3820 Offsets.push_back(FieldOffset);
3821
3822 // Update CurTy to the type of the field at Index.
3823 CurTy = STy->getTypeAtIndex(Index);
3824 } else {
3825 // Update CurTy to its element type.
3826 if (FirstIter) {
3827 assert(isa<PointerType>(CurTy) &&
3828 "The first index of a GEP indexes a pointer");
3829 CurTy = GEP->getSourceElementType();
3830 FirstIter = false;
3831 } else {
3832 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3833 }
3834 // For an array, add the element offset, explicitly scaled.
3835 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3836 // Getelementptr indices are signed.
3837 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3838
3839 // Multiply the index by the element size to compute the element offset.
3840 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3841 Offsets.push_back(LocalOffset);
3842 }
3843 }
3844
3845 // Handle degenerate case of GEP without offsets.
3846 if (Offsets.empty())
3847 return BaseExpr;
3848
3849 // Add the offsets together, assuming nsw if inbounds.
3850 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3851 // Add the base address and the offset. We cannot use the nsw flag, as the
3852 // base address is unsigned. However, if we know that the offset is
3853 // non-negative, we can use nuw.
3854 bool NUW = NW.hasNoUnsignedWrap() ||
3855 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3856 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3857 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3858 assert(BaseExpr->getType() == GEPExpr->getType() &&
3859 "GEP should not change type mid-flight.");
3860 return GEPExpr;
3861}
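// Illustrative example (editorial, hedged; assumes a 64-bit index type and
// i32 elements, so element sizes of 40 and 4 bytes respectively):
//   getelementptr inbounds [10 x i32], ptr %p, i64 %i, i64 %j
//     -->  ((40 * %i) + (4 * %j) + %p)
// with nsw on the offset arithmetic when the GEP's no-wrap flags allow it,
// and nuw on the final add only if the combined offset is known non-negative.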
3862
3863SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3864 ArrayRef<const SCEV *> Ops) {
3865 FoldingSetNodeID ID;
3866 ID.AddInteger(SCEVType);
3867 for (const SCEV *Op : Ops)
3868 ID.AddPointer(Op);
3869 void *IP = nullptr;
3870 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3871}
3872
3873const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3874 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3875 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3876}
3877
3878const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3879 SmallVectorImpl<const SCEV *> &Ops) {
3880 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3881 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3882 if (Ops.size() == 1) return Ops[0];
3883#ifndef NDEBUG
3884 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3885 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3886 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3887 "Operand types don't match!");
3888 assert(Ops[0]->getType()->isPointerTy() ==
3889 Ops[i]->getType()->isPointerTy() &&
3890 "min/max should be consistently pointerish");
3891 }
3892#endif
3893
3894 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3895 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3896
3897 const SCEV *Folded = constantFoldAndGroupOps(
3898 *this, LI, DT, Ops,
3899 [&](const APInt &C1, const APInt &C2) {
3900 switch (Kind) {
3901 case scSMaxExpr:
3902 return APIntOps::smax(C1, C2);
3903 case scSMinExpr:
3904 return APIntOps::smin(C1, C2);
3905 case scUMaxExpr:
3906 return APIntOps::umax(C1, C2);
3907 case scUMinExpr:
3908 return APIntOps::umin(C1, C2);
3909 default:
3910 llvm_unreachable("Unknown SCEV min/max opcode");
3911 }
3912 },
3913 [&](const APInt &C) {
3914 // identity
3915 if (IsMax)
3916 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3917 else
3918 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3919 },
3920 [&](const APInt &C) {
3921 // absorber
3922 if (IsMax)
3923 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3924 else
3925 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3926 });
3927 if (Folded)
3928 return Folded;
3929
3930 // Check if we have created the same expression before.
3931 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3932 return S;
3933 }
3934
3935 // Find the first operation of the same kind
3936 unsigned Idx = 0;
3937 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3938 ++Idx;
3939
3940 // Check to see if one of the operands is of the same kind. If so, expand its
3941 // operands onto our operand list, and recurse to simplify.
3942 if (Idx < Ops.size()) {
3943 bool DeletedAny = false;
3944 while (Ops[Idx]->getSCEVType() == Kind) {
3945 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3946 Ops.erase(Ops.begin()+Idx);
3947 append_range(Ops, SMME->operands());
3948 DeletedAny = true;
3949 }
3950
3951 if (DeletedAny)
3952 return getMinMaxExpr(Kind, Ops);
3953 }
3954
3955 // Okay, check to see if the same value occurs in the operand list twice. If
3956 // so, delete one. Since we sorted the list, these values are required to
3957 // be adjacent.
3958 llvm::CmpInst::Predicate GEPred =
3959 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3960 llvm::CmpInst::Predicate LEPred =
3961 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3962 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3963 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3964 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3965 if (Ops[i] == Ops[i + 1] ||
3966 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3967 // X op Y op Y --> X op Y
3968 // X op Y --> X, if we know X, Y are ordered appropriately
3969 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3970 --i;
3971 --e;
3972 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3973 Ops[i + 1])) {
3974 // X op Y --> Y, if we know X, Y are ordered appropriately
3975 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3976 --i;
3977 --e;
3978 }
3979 }
3980
3981 if (Ops.size() == 1) return Ops[0];
3982
3983 assert(!Ops.empty() && "Reduced smax down to nothing!");
3984
3985 // Okay, it looks like we really DO need an expr. Check to see if we
3986 // already have one, otherwise create a new one.
3987 FoldingSetNodeID ID;
3988 ID.AddInteger(Kind);
3989 for (const SCEV *Op : Ops)
3990 ID.AddPointer(Op);
3991 void *IP = nullptr;
3992 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3993 if (ExistingSCEV)
3994 return ExistingSCEV;
3995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3997 SCEV *S = new (SCEVAllocator)
3998 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3999
4000 UniqueSCEVs.InsertNode(S, IP);
4001 registerUser(S, Ops);
4002 return S;
4003}
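// Illustrative folds performed above (editorial, not from the original
// source):
//   umax(%x, 0)       -->  %x              (0 is the identity for umax)
//   smax(%x, %y, %y)  -->  smax(%x, %y)    (duplicates are adjacent after grouping)
//   smax(%x, %y)      -->  %x              if %x is known >=s %y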
4004
4005namespace {
4006
4007class SCEVSequentialMinMaxDeduplicatingVisitor final
4008 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4009 std::optional<const SCEV *>> {
4010 using RetVal = std::optional<const SCEV *>;
4011 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
4012
4013 ScalarEvolution &SE;
4014 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4015 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4016 SmallPtrSet<const SCEV *, 16> SeenOps;
4017
4018 bool canRecurseInto(SCEVTypes Kind) const {
4019 // We can only recurse into the SCEV expression of the same effective type
4020 // as the type of our root SCEV expression.
4021 return RootKind == Kind || NonSequentialRootKind == Kind;
4022 };
4023
4024 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4026 "Only for min/max expressions.");
4027 SCEVTypes Kind = S->getSCEVType();
4028
4029 if (!canRecurseInto(Kind))
4030 return S;
4031
4032 auto *NAry = cast<SCEVNAryExpr>(S);
4033 SmallVector<const SCEV *> NewOps;
4034 bool Changed = visit(Kind, NAry->operands(), NewOps);
4035
4036 if (!Changed)
4037 return S;
4038 if (NewOps.empty())
4039 return std::nullopt;
4040
4041 return isa<SCEVSequentialMinMaxExpr>(S)
4042 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4043 : SE.getMinMaxExpr(Kind, NewOps);
4044 }
4045
4046 RetVal visit(const SCEV *S) {
4047 // Has the whole operand been seen already?
4048 if (!SeenOps.insert(S).second)
4049 return std::nullopt;
4050 return Base::visit(S);
4051 }
4052
4053public:
4054 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4055 SCEVTypes RootKind)
4056 : SE(SE), RootKind(RootKind),
4057 NonSequentialRootKind(
4058 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4059 RootKind)) {}
4060
4061 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4062 SmallVectorImpl<const SCEV *> &NewOps) {
4063 bool Changed = false;
4064 SmallVector<const SCEV *> Ops;
4065 Ops.reserve(OrigOps.size());
4066
4067 for (const SCEV *Op : OrigOps) {
4068 RetVal NewOp = visit(Op);
4069 if (NewOp != Op)
4070 Changed = true;
4071 if (NewOp)
4072 Ops.emplace_back(*NewOp);
4073 }
4074
4075 if (Changed)
4076 NewOps = std::move(Ops);
4077 return Changed;
4078 }
4079
4080 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4081
4082 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4083
4084 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4085
4086 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4087
4088 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4089
4090 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4091
4092 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4093
4094 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4095
4096 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4097
4098 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4099
4100 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4101 return visitAnyMinMaxExpr(Expr);
4102 }
4103
4104 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4105 return visitAnyMinMaxExpr(Expr);
4106 }
4107
4108 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4109 return visitAnyMinMaxExpr(Expr);
4110 }
4111
4112 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4113 return visitAnyMinMaxExpr(Expr);
4114 }
4115
4116 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4117 return visitAnyMinMaxExpr(Expr);
4118 }
4119
4120 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4121
4122 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4123};
4124
4125} // namespace
4126
4127static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4128 switch (Kind) {
4129 case scConstant:
4130 case scVScale:
4131 case scTruncate:
4132 case scZeroExtend:
4133 case scSignExtend:
4134 case scPtrToInt:
4135 case scAddExpr:
4136 case scMulExpr:
4137 case scUDivExpr:
4138 case scAddRecExpr:
4139 case scUMaxExpr:
4140 case scSMaxExpr:
4141 case scUMinExpr:
4142 case scSMinExpr:
4143 case scUnknown:
4144 // If any operand is poison, the whole expression is poison.
4145 return true;
4146 case scSequentialUMinExpr:
4147 // FIXME: if the *first* operand is poison, the whole expression is poison.
4148 return false; // Pessimistically, say that it does not propagate poison.
4149 case scCouldNotCompute:
4150 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4151 }
4152 llvm_unreachable("Unknown SCEV kind!");
4153}
4154
4155namespace {
4156// The only way poison may be introduced in a SCEV expression is from a
4157// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4158// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4159// introduce poison -- they encode guaranteed, non-speculated knowledge.
4160//
4161// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4162// with the notable exception of umin_seq, where only poison from the first
4163// operand is (unconditionally) propagated.
4164struct SCEVPoisonCollector {
4165 bool LookThroughMaybePoisonBlocking;
4166 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4167 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4168 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4169
4170 bool follow(const SCEV *S) {
4171 if (!LookThroughMaybePoisonBlocking &&
4172 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4173 return false;
4174
4175 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4176 if (!isGuaranteedNotToBePoison(SU->getValue()))
4177 MaybePoison.insert(SU);
4178 }
4179 return true;
4180 }
4181 bool isDone() const { return false; }
4182};
4183} // namespace
4184
4185/// Return true if V is poison given that AssumedPoison is already poison.
4186static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4187 // First collect all SCEVs that might result in AssumedPoison to be poison.
4188 // We need to look through potentially poison-blocking operations here,
4189 // because we want to find all SCEVs that *might* result in poison, not only
4190 // those that are *required* to.
4191 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4192 visitAll(AssumedPoison, PC1);
4193
4194 // AssumedPoison is never poison. As the assumption is false, the implication
4195 // is true. Don't bother walking the other SCEV in this case.
4196 if (PC1.MaybePoison.empty())
4197 return true;
4198
4199 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4200 // as well. We cannot look through potentially poison-blocking operations
4201 // here, as their arguments only *may* make the result poison.
4202 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4203 visitAll(S, PC2);
4204
4205 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4206 // it will also make S poison by being part of PC2.MaybePoison.
4207 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4208}
4209
4210void ScalarEvolution::getPoisonGeneratingValues(
4211 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4212 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4213 visitAll(S, PC);
4214 for (const SCEVUnknown *SU : PC.MaybePoison)
4215 Result.insert(SU->getValue());
4216}
4217
4218bool ScalarEvolution::canReuseInstruction(
4219 const SCEV *S, Instruction *I,
4220 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4221 // If the instruction cannot be poison, it's always safe to reuse.
4222 if (isGuaranteedNotToBePoison(I))
4223 return true;
4224
4225 // Otherwise, it is possible that I is more poisonous than S. Collect the
4226 // poison-contributors of S, and then check whether I has any additional
4227 // poison-contributors. Poison that is contributed through poison-generating
4228 // flags is handled by dropping those flags instead.
4230 getPoisonGeneratingValues(PoisonVals, S);
4231
4232 SmallVector<Value *> Worklist;
4234 Worklist.push_back(I);
4235 while (!Worklist.empty()) {
4236 Value *V = Worklist.pop_back_val();
4237 if (!Visited.insert(V).second)
4238 continue;
4239
4240 // Avoid walking large instruction graphs.
4241 if (Visited.size() > 16)
4242 return false;
4243
4244 // Either the value can't be poison, or the S would also be poison if it
4245 // is.
4246 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4247 continue;
4248
4249 auto *I = dyn_cast<Instruction>(V);
4250 if (!I)
4251 return false;
4252
4253 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4254 // can't replace an arbitrary add with disjoint or, even if we drop the
4255 // flag. We would need to convert the or into an add.
4256 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4257 if (PDI->isDisjoint())
4258 return false;
4259
4260 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4261 // because SCEV currently assumes it can't be poison. Remove this special
4262 // case once we properly model when vscale can be poison.
4263 if (auto *II = dyn_cast<IntrinsicInst>(I);
4264 II && II->getIntrinsicID() == Intrinsic::vscale)
4265 continue;
4266
4267 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4268 return false;
4269
4270 // If the instruction can't create poison, we can recurse to its operands.
4271 if (I->hasPoisonGeneratingAnnotations())
4272 DropPoisonGeneratingInsts.push_back(I);
4273
4274 llvm::append_range(Worklist, I->operands());
4275 }
4276 return true;
4277}
4278
4279const SCEV *
4280ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4281 SmallVectorImpl<const SCEV *> &Ops) {
4282 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4283 "Not a SCEVSequentialMinMaxExpr!");
4284 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4285 if (Ops.size() == 1)
4286 return Ops[0];
4287#ifndef NDEBUG
4288 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4289 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4290 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4291 "Operand types don't match!");
4292 assert(Ops[0]->getType()->isPointerTy() ==
4293 Ops[i]->getType()->isPointerTy() &&
4294 "min/max should be consistently pointerish");
4295 }
4296#endif
4297
4298 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4299 // so we can *NOT* do any kind of sorting of the expressions!
4300
4301 // Check if we have created the same expression before.
4302 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4303 return S;
4304
4305 // FIXME: there are *some* simplifications that we can do here.
4306
4307 // Keep only the first instance of an operand.
4308 {
4309 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4310 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4311 if (Changed)
4312 return getSequentialMinMaxExpr(Kind, Ops);
4313 }
4314
4315 // Check to see if one of the operands is of the same kind. If so, expand its
4316 // operands onto our operand list, and recurse to simplify.
4317 {
4318 unsigned Idx = 0;
4319 bool DeletedAny = false;
4320 while (Idx < Ops.size()) {
4321 if (Ops[Idx]->getSCEVType() != Kind) {
4322 ++Idx;
4323 continue;
4324 }
4325 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4326 Ops.erase(Ops.begin() + Idx);
4327 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4328 SMME->operands().end());
4329 DeletedAny = true;
4330 }
4331
4332 if (DeletedAny)
4333 return getSequentialMinMaxExpr(Kind, Ops);
4334 }
4335
4336 const SCEV *SaturationPoint;
4337 ICmpInst::Predicate Pred;
4338 switch (Kind) {
4339 case scSequentialUMinExpr:
4340 SaturationPoint = getZero(Ops[0]->getType());
4341 Pred = ICmpInst::ICMP_ULE;
4342 break;
4343 default:
4344 llvm_unreachable("Not a sequential min/max type.");
4345 }
4346
4347 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4348 if (!isGuaranteedNotToCauseUB(Ops[i]))
4349 continue;
4350 // We can replace %x umin_seq %y with %x umin %y if either:
4351 // * %y being poison implies %x is also poison.
4352 // * %x cannot be the saturating value (e.g. zero for umin).
4353 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4354 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4355 SaturationPoint)) {
4356 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4357 Ops[i - 1] = getMinMaxExpr(
4358 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4359 SeqOps);
4360 Ops.erase(Ops.begin() + i);
4361 return getSequentialMinMaxExpr(Kind, Ops);
4362 }
4363 // Fold %x umin_seq %y to %x if %x ule %y.
4364 // TODO: We might be able to prove the predicate for a later operand.
4365 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4366 Ops.erase(Ops.begin() + i);
4367 return getSequentialMinMaxExpr(Kind, Ops);
4368 }
4369 }
4370
4371 // Okay, it looks like we really DO need an expr. Check to see if we
4372 // already have one, otherwise create a new one.
4373 FoldingSetNodeID ID;
4374 ID.AddInteger(Kind);
4375 for (const SCEV *Op : Ops)
4376 ID.AddPointer(Op);
4377 void *IP = nullptr;
4378 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4379 if (ExistingSCEV)
4380 return ExistingSCEV;
4381
4382 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4384 SCEV *S = new (SCEVAllocator)
4385 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4386
4387 UniqueSCEVs.InsertNode(S, IP);
4388 registerUser(S, Ops);
4389 return S;
4390}
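// Illustration of the umin_seq folds above (editorial, not from the original
// source):
//   %x umin_seq %y  -->  %x umin %y   if %y being poison implies %x is poison,
//                                     or %x is known non-zero (cannot saturate)
//   %x umin_seq %y  -->  %x           if %x u<= %y is known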
4391
4392const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4393 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4394 return getSMaxExpr(Ops);
4395}
4396
4397const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4398 return getMinMaxExpr(scSMaxExpr, Ops);
4399}
4400
4401const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4402 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4403 return getUMaxExpr(Ops);
4404}
4405
4406const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4407 return getMinMaxExpr(scUMaxExpr, Ops);
4408}
4409
4410const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4411 const SCEV *RHS) {
4412 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4413 return getSMinExpr(Ops);
4414}
4415
4416const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4417 return getMinMaxExpr(scSMinExpr, Ops);
4418}
4419
4420const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4421 bool Sequential) {
4422 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4423 return getUMinExpr(Ops, Sequential);
4424}
4425
4426const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4427 bool Sequential) {
4428 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4429 : getMinMaxExpr(scUMinExpr, Ops);
4430}
4431
4432const SCEV *
4433ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4434 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4435 if (Size.isScalable())
4436 Res = getMulExpr(Res, getVScale(IntTy));
4437 return Res;
4438}
4439
4440const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4441 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4442}
4443
4444const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4445 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4446}
4447
4448const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4449 StructType *STy,
4450 unsigned FieldNo) {
4451 // We can bypass creating a target-independent constant expression and then
4452 // folding it back into a ConstantInt. This is just a compile-time
4453 // optimization.
4454 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4455 assert(!SL->getSizeInBits().isScalable() &&
4456 "Cannot get offset for structure containing scalable vector types");
4457 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4458}
4459
4460const SCEV *ScalarEvolution::getUnknown(Value *V) {
4461 // Don't attempt to do anything other than create a SCEVUnknown object
4462 // here. createSCEV only calls getUnknown after checking for all other
4463 // interesting possibilities, and any other code that calls getUnknown
4464 // is doing so in order to hide a value from SCEV canonicalization.
4465
4466 FoldingSetNodeID ID;
4467 ID.AddInteger(scUnknown);
4468 ID.AddPointer(V);
4469 void *IP = nullptr;
4470 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4471 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4472 "Stale SCEVUnknown in uniquing map!");
4473 return S;
4474 }
4475 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4476 FirstUnknown);
4477 FirstUnknown = cast<SCEVUnknown>(S);
4478 UniqueSCEVs.InsertNode(S, IP);
4479 return S;
4480}
4481
4482//===----------------------------------------------------------------------===//
4483// Basic SCEV Analysis and PHI Idiom Recognition Code
4484//
4485
4486/// Test if values of the given type are analyzable within the SCEV
4487/// framework. This primarily includes integer types, and it can optionally
4488/// include pointer types if the ScalarEvolution class has access to
4489/// target-specific information.
4490bool ScalarEvolution::isSCEVable(Type *Ty) const {
4491 // Integers and pointers are always SCEVable.
4492 return Ty->isIntOrPtrTy();
4493}
4494
4495/// Return the size in bits of the specified type, for which isSCEVable must
4496/// return true.
4497uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4498 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4499 if (Ty->isPointerTy())
4500 return getDataLayout().getIndexTypeSizeInBits(Ty);
4501 return getDataLayout().getTypeSizeInBits(Ty);
4502}
4503
4504/// Return a type with the same bitwidth as the given type and which represents
4505/// how SCEV will treat the given type, for which isSCEVable must return
4506/// true. For pointer types, this is the pointer index sized integer type.
4507Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4508 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4509
4510 if (Ty->isIntegerTy())
4511 return Ty;
4512
4513 // The only other supported type is pointer.
4514 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4515 return getDataLayout().getIndexType(Ty);
4516}
4517
4518Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4519 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4520}
4521
4522bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4523 const SCEV *B) {
4524 /// For a valid use point to exist, the defining scope of one operand
4525 /// must dominate the other.
4526 bool PreciseA, PreciseB;
4527 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4528 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4529 if (!PreciseA || !PreciseB)
4530 // Can't tell.
4531 return false;
4532 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4533 DT.dominates(ScopeB, ScopeA);
4534}
4535
4536const SCEV *ScalarEvolution::getCouldNotCompute() {
4537 return CouldNotCompute.get();
4538}
4539
4540bool ScalarEvolution::checkValidity(const SCEV *S) const {
4541 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4542 auto *SU = dyn_cast<SCEVUnknown>(S);
4543 return SU && SU->getValue() == nullptr;
4544 });
4545
4546 return !ContainsNulls;
4547}
4548
4549bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4550 HasRecMapType::iterator I = HasRecMap.find(S);
4551 if (I != HasRecMap.end())
4552 return I->second;
4553
4554 bool FoundAddRec =
4555 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4556 HasRecMap.insert({S, FoundAddRec});
4557 return FoundAddRec;
4558}
4559
4560/// Return the ValueOffsetPair set for \p S. \p S can be represented
4561/// by the value and offset from any ValueOffsetPair in the set.
4562ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4563 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4564 if (SI == ExprValueMap.end())
4565 return {};
4566 return SI->second.getArrayRef();
4567}
4568
4569/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4570/// cannot be used separately. eraseValueFromMap should be used to remove
4571/// V from ValueExprMap and ExprValueMap at the same time.
4572void ScalarEvolution::eraseValueFromMap(Value *V) {
4573 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4574 if (I != ValueExprMap.end()) {
4575 auto EVIt = ExprValueMap.find(I->second);
4576 bool Removed = EVIt->second.remove(V);
4577 (void) Removed;
4578 assert(Removed && "Value not in ExprValueMap?");
4579 ValueExprMap.erase(I);
4580 }
4581}
4582
4583void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4584 // A recursive query may have already computed the SCEV. It should be
4585 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4586 // inferred nowrap flags.
4587 auto It = ValueExprMap.find_as(V);
4588 if (It == ValueExprMap.end()) {
4589 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4590 ExprValueMap[S].insert(V);
4591 }
4592}
4593
4594/// Return an existing SCEV if it exists, otherwise analyze the expression and
4595/// create a new one.
4596const SCEV *ScalarEvolution::getSCEV(Value *V) {
4597 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4598
4599 if (const SCEV *S = getExistingSCEV(V))
4600 return S;
4601 return createSCEVIter(V);
4602}
4603
4605 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4606
4607 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4608 if (I != ValueExprMap.end()) {
4609 const SCEV *S = I->second;
4610 assert(checkValidity(S) &&
4611 "existing SCEV has not been properly invalidated");
4612 return S;
4613 }
4614 return nullptr;
4615}
4616
4617/// Return a SCEV corresponding to -V = -1*V
4618const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4619 SCEV::NoWrapFlags Flags) {
4620 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4621 return getConstant(
4622 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4623
4624 Type *Ty = V->getType();
4625 Ty = getEffectiveSCEVType(Ty);
4626 return getMulExpr(V, getMinusOne(Ty), Flags);
4627}
4628
4629/// If Expr computes ~A, return A else return nullptr
4630static const SCEV *MatchNotExpr(const SCEV *Expr) {
4631 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4632 if (!Add || Add->getNumOperands() != 2 ||
4633 !Add->getOperand(0)->isAllOnesValue())
4634 return nullptr;
4635
4636 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4637 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4638 !AddRHS->getOperand(0)->isAllOnesValue())
4639 return nullptr;
4640
4641 return AddRHS->getOperand(1);
4642}
4643
4644/// Return a SCEV corresponding to ~V = -1-V
4645const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4646 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4647
4648 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4649 return getConstant(
4650 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4651
4652 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4653 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4654 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4655 SmallVector<const SCEV *, 2> MatchedOperands;
4656 for (const SCEV *Operand : MME->operands()) {
4657 const SCEV *Matched = MatchNotExpr(Operand);
4658 if (!Matched)
4659 return (const SCEV *)nullptr;
4660 MatchedOperands.push_back(Matched);
4661 }
4662 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4663 MatchedOperands);
4664 };
4665 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4666 return Replaced;
4667 }
4668
4669 Type *Ty = V->getType();
4670 Ty = getEffectiveSCEVType(Ty);
4671 return getMinusSCEV(getMinusOne(Ty), V);
4672}
4673
4674const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4675 assert(P->getType()->isPointerTy());
4676
4677 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4678 // The base of an AddRec is the first operand.
4679 SmallVector<const SCEV *> Ops{AddRec->operands()};
4680 Ops[0] = removePointerBase(Ops[0]);
4681 // Don't try to transfer nowrap flags for now. We could in some cases
4682 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4683 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4684 }
4685 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4686 // The base of an Add is the pointer operand.
4687 SmallVector<const SCEV *> Ops{Add->operands()};
4688 const SCEV **PtrOp = nullptr;
4689 for (const SCEV *&AddOp : Ops) {
4690 if (AddOp->getType()->isPointerTy()) {
4691 assert(!PtrOp && "Cannot have multiple pointer ops");
4692 PtrOp = &AddOp;
4693 }
4694 }
4695 *PtrOp = removePointerBase(*PtrOp);
4696 // Don't try to transfer nowrap flags for now. We could in some cases
4697 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4698 return getAddExpr(Ops);
4699 }
4700 // Any other expression must be a pointer base.
4701 return getZero(P->getType());
4702}
4703
4704const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4705 SCEV::NoWrapFlags Flags,
4706 unsigned Depth) {
4707 // Fast path: X - X --> 0.
4708 if (LHS == RHS)
4709 return getZero(LHS->getType());
4710
4711 // If we subtract two pointers with different pointer bases, bail.
4712 // Eventually, we're going to add an assertion to getMulExpr that we
4713 // can't multiply by a pointer.
4714 if (RHS->getType()->isPointerTy()) {
4715 if (!LHS->getType()->isPointerTy() ||
4716 getPointerBase(LHS) != getPointerBase(RHS))
4717 return getCouldNotCompute();
4718 LHS = removePointerBase(LHS);
4719 RHS = removePointerBase(RHS);
4720 }
4721
4722 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4723 // makes it so that we cannot make much use of NUW.
4724 auto AddFlags = SCEV::FlagAnyWrap;
4725 const bool RHSIsNotMinSigned =
4726 !getSignedRangeMin(RHS).isMinSignedValue();
4727 if (hasFlags(Flags, SCEV::FlagNSW)) {
4728 // Let M be the minimum representable signed value. Then (-1)*RHS
4729 // signed-wraps if and only if RHS is M. That can happen even for
4730 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4731 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4732 // (-1)*RHS, we need to prove that RHS != M.
4733 //
4734 // If LHS is non-negative and we know that LHS - RHS does not
4735 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4736 // either by proving that RHS > M or that LHS >= 0.
4737 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4738 AddFlags = SCEV::FlagNSW;
4739 }
4740 }
4741
4742 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4743 // RHS is NSW and LHS >= 0.
4744 //
4745 // The difficulty here is that the NSW flag may have been proven
4746 // relative to a loop that is to be found in a recurrence in LHS and
4747 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4748 // larger scope than intended.
4749 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4750
4751 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4752}
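// Numeric illustration of the NSW reasoning above (editorial, not from the
// original source): with i8 operands, M = -128. If RHS could be -128, then
// (-1) * RHS wraps back to -128, so NSW cannot be transferred to
// LHS + (-1)*RHS; proving RHS != -128, or LHS >= 0 together with a
// non-wrapping LHS - RHS, rules this case out.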
4753
4754const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4755 unsigned Depth) {
4756 Type *SrcTy = V->getType();
4757 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4758 "Cannot truncate or zero extend with non-integer arguments!");
4759 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4760 return V; // No conversion
4761 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4762 return getTruncateExpr(V, Ty, Depth);
4763 return getZeroExtendExpr(V, Ty, Depth);
4764}
4765
4766const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4767 unsigned Depth) {
4768 Type *SrcTy = V->getType();
4769 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4770 "Cannot truncate or zero extend with non-integer arguments!");
4771 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4772 return V; // No conversion
4773 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4774 return getTruncateExpr(V, Ty, Depth);
4775 return getSignExtendExpr(V, Ty, Depth);
4776}
4777
4778const SCEV *
4779ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4780 Type *SrcTy = V->getType();
4781 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4782 "Cannot noop or zero extend with non-integer arguments!");
4783 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4784 "getNoopOrZeroExtend cannot truncate!");
4785 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4786 return V; // No conversion
4787 return getZeroExtendExpr(V, Ty);
4788}
4789
4790const SCEV *
4791ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4792 Type *SrcTy = V->getType();
4793 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4794 "Cannot noop or sign extend with non-integer arguments!");
4795 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4796 "getNoopOrSignExtend cannot truncate!");
4797 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4798 return V; // No conversion
4799 return getSignExtendExpr(V, Ty);
4800}
4801
4802const SCEV *
4803ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4804 Type *SrcTy = V->getType();
4805 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4806 "Cannot noop or any extend with non-integer arguments!");
4807 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4808 "getNoopOrAnyExtend cannot truncate!");
4809 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4810 return V; // No conversion
4811 return getAnyExtendExpr(V, Ty);
4812}
4813
4814const SCEV *
4815ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4816 Type *SrcTy = V->getType();
4817 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4818 "Cannot truncate or noop with non-integer arguments!");
4819 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4820 "getTruncateOrNoop cannot extend!");
4821 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4822 return V; // No conversion
4823 return getTruncateExpr(V, Ty);
4824}
4825
4826 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4827 const SCEV *RHS) {
4828 const SCEV *PromotedLHS = LHS;
4829 const SCEV *PromotedRHS = RHS;
4830
4831 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4832 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4833 else
4834 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4835
4836 return getUMaxExpr(PromotedLHS, PromotedRHS);
4837}
4838
4839 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4840 const SCEV *RHS,
4841 bool Sequential) {
4842 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4843 return getUMinFromMismatchedTypes(Ops, Sequential);
4844}
4845
4846const SCEV *
4847 ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4848 bool Sequential) {
4849 assert(!Ops.empty() && "At least one operand must be!");
4850 // Trivial case.
4851 if (Ops.size() == 1)
4852 return Ops[0];
4853
4854 // Find the max type first.
4855 Type *MaxType = nullptr;
4856 for (const auto *S : Ops)
4857 if (MaxType)
4858 MaxType = getWiderType(MaxType, S->getType());
4859 else
4860 MaxType = S->getType();
4861 assert(MaxType && "Failed to find maximum type!");
4862
4863 // Extend all ops to max type.
4864 SmallVector<const SCEV *, 2> PromotedOps;
4865 for (const auto *S : Ops)
4866 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4867
4868 // Generate umin.
4869 return getUMinExpr(PromotedOps, Sequential);
4870}
4871
4872 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4873 // A pointer operand may evaluate to a nonpointer expression, such as null.
4874 if (!V->getType()->isPointerTy())
4875 return V;
4876
4877 while (true) {
4878 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4879 V = AddRec->getStart();
4880 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4881 const SCEV *PtrOp = nullptr;
4882 for (const SCEV *AddOp : Add->operands()) {
4883 if (AddOp->getType()->isPointerTy()) {
4884 assert(!PtrOp && "Cannot have multiple pointer ops");
4885 PtrOp = AddOp;
4886 }
4887 }
4888 assert(PtrOp && "Must have pointer op");
4889 V = PtrOp;
4890 } else // Not something we can look further into.
4891 return V;
4892 }
4893}
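// For illustration, getPointerBase peels AddRecs and adds down to the single
// pointer operand: given a SCEV such as ((4 * %i) + %ptr) or {%ptr,+,4}<%L>,
// it returns %ptr; a non-pointer expression is returned unchanged.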
4894
4895/// Push users of the given Instruction onto the given Worklist.
4896 static void PushDefUseChildren(Instruction *I,
4897 SmallVectorImpl<Instruction *> &Worklist,
4898 SmallPtrSetImpl<Instruction *> &Visited) {
4899 // Push the def-use children onto the Worklist stack.
4900 for (User *U : I->users()) {
4901 auto *UserInsn = cast<Instruction>(U);
4902 if (Visited.insert(UserInsn).second)
4903 Worklist.push_back(UserInsn);
4904 }
4905}
4906
4907namespace {
4908
4909/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4910 /// expression if its loop is L. If its loop is not L, then: if
4911 /// IgnoreOtherLoops is true, use the AddRec itself; otherwise the rewrite
4912 /// cannot be done.
4913 /// If the SCEV contains a non-invariant unknown SCEV, the rewrite cannot be done.
4914class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4915public:
4916 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4917 bool IgnoreOtherLoops = true) {
4918 SCEVInitRewriter Rewriter(L, SE);
4919 const SCEV *Result = Rewriter.visit(S);
4920 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4921 return SE.getCouldNotCompute();
4922 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4923 ? SE.getCouldNotCompute()
4924 : Result;
4925 }
4926
4927 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4928 if (!SE.isLoopInvariant(Expr, L))
4929 SeenLoopVariantSCEVUnknown = true;
4930 return Expr;
4931 }
4932
4933 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4934 // Only re-write AddRecExprs for this loop.
4935 if (Expr->getLoop() == L)
4936 return Expr->getStart();
4937 SeenOtherLoops = true;
4938 return Expr;
4939 }
4940
4941 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4942
4943 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4944
4945private:
4946 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4947 : SCEVRewriteVisitor(SE), L(L) {}
4948
4949 const Loop *L;
4950 bool SeenLoopVariantSCEVUnknown = false;
4951 bool SeenOtherLoops = false;
4952};
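// Example (illustrative SCEVs): for S = ({%a,+,%b}<%L> + %c) and loop %L,
// SCEVInitRewriter::rewrite(S, %L, SE) yields (%a + %c); if %c were an unknown
// that varies in %L, the rewrite would instead return SCEVCouldNotCompute.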
4953
4954/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4955 /// increment expression if its loop is L. If its loop is not L, then
4956 /// use the AddRec itself.
4957/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4958class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4959public:
4960 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4961 SCEVPostIncRewriter Rewriter(L, SE);
4962 const SCEV *Result = Rewriter.visit(S);
4963 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4964 ? SE.getCouldNotCompute()
4965 : Result;
4966 }
4967
4968 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4969 if (!SE.isLoopInvariant(Expr, L))
4970 SeenLoopVariantSCEVUnknown = true;
4971 return Expr;
4972 }
4973
4974 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4975 // Only re-write AddRecExprs for this loop.
4976 if (Expr->getLoop() == L)
4977 return Expr->getPostIncExpr(SE);
4978 SeenOtherLoops = true;
4979 return Expr;
4980 }
4981
4982 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4983
4984 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4985
4986private:
4987 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4988 : SCEVRewriteVisitor(SE), L(L) {}
4989
4990 const Loop *L;
4991 bool SeenLoopVariantSCEVUnknown = false;
4992 bool SeenOtherLoops = false;
4993};
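// Example (illustrative SCEVs): for S = {%a,+,%b}<%L> and loop %L,
// SCEVPostIncRewriter::rewrite(S, %L, SE) yields {(%a + %b),+,%b}<%L>, i.e.
// the value of the recurrence one increment later.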
4994
4995/// This class evaluates the compare condition by matching it against the
4996/// condition of loop latch. If there is a match we assume a true value
4997/// for the condition while building SCEV nodes.
4998class SCEVBackedgeConditionFolder
4999 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5000public:
5001 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5002 ScalarEvolution &SE) {
5003 bool IsPosBECond = false;
5004 Value *BECond = nullptr;
5005 if (BasicBlock *Latch = L->getLoopLatch()) {
5006 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5007 if (BI && BI->isConditional()) {
5008 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5009 "Both outgoing branches should not target same header!");
5010 BECond = BI->getCondition();
5011 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5012 } else {
5013 return S;
5014 }
5015 }
5016 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5017 return Rewriter.visit(S);
5018 }
5019
5020 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5021 const SCEV *Result = Expr;
5022 bool InvariantF = SE.isLoopInvariant(Expr, L);
5023
5024 if (!InvariantF) {
5025 Instruction *I = cast<Instruction>(Expr->getValue());
5026 switch (I->getOpcode()) {
5027 case Instruction::Select: {
5028 SelectInst *SI = cast<SelectInst>(I);
5029 std::optional<const SCEV *> Res =
5030 compareWithBackedgeCondition(SI->getCondition());
5031 if (Res) {
5032 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5033 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5034 }
5035 break;
5036 }
5037 default: {
5038 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5039 if (Res)
5040 Result = *Res;
5041 break;
5042 }
5043 }
5044 }
5045 return Result;
5046 }
5047
5048private:
5049 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5050 bool IsPosBECond, ScalarEvolution &SE)
5051 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5052 IsPositiveBECond(IsPosBECond) {}
5053
5054 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5055
5056 const Loop *L;
5057 /// Loop back condition.
5058 Value *BackedgeCond = nullptr;
5059 /// Set to true if loop back is on positive branch condition.
5060 bool IsPositiveBECond;
5061};
5062
5063std::optional<const SCEV *>
5064SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5065
5066 // If value matches the backedge condition for loop latch,
5067 // then return a constant evolution node based on loopback
5068 // branch taken.
5069 if (BackedgeCond == IC)
5070 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5071 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5072 return std::nullopt;
5073}
5074
5075class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5076public:
5077 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5078 ScalarEvolution &SE) {
5079 SCEVShiftRewriter Rewriter(L, SE);
5080 const SCEV *Result = Rewriter.visit(S);
5081 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5082 }
5083
5084 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5085 // Only allow AddRecExprs for this loop.
5086 if (!SE.isLoopInvariant(Expr, L))
5087 Valid = false;
5088 return Expr;
5089 }
5090
5091 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5092 if (Expr->getLoop() == L && Expr->isAffine())
5093 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5094 Valid = false;
5095 return Expr;
5096 }
5097
5098 bool isValid() { return Valid; }
5099
5100private:
5101 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5102 : SCEVRewriteVisitor(SE), L(L) {}
5103
5104 const Loop *L;
5105 bool Valid = true;
5106};
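// Example (illustrative SCEVs): SCEVShiftRewriter maps an affine recurrence of
// loop L back by one iteration, so {%a,+,%b}<%L> becomes {(%a - %b),+,%b}<%L>;
// any loop-variant unknown or non-affine AddRec marks the result invalid.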
5107
5108} // end anonymous namespace
5109
5110 SCEV::NoWrapFlags
5111 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5112 if (!AR->isAffine())
5113 return SCEV::FlagAnyWrap;
5114
5115 using OBO = OverflowingBinaryOperator;
5116
5117 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5118
5119 if (!AR->hasNoSelfWrap()) {
5120 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5121 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5122 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5123 const APInt &BECountAP = BECountMax->getAPInt();
5124 unsigned NoOverflowBitWidth =
5125 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5126 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5127 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5128 }
5129 }
5130
5131 if (!AR->hasNoSignedWrap()) {
5132 ConstantRange AddRecRange = getSignedRange(AR);
5133 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5134
5135 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5136 Instruction::Add, IncRange, OBO::NoSignedWrap);
5137 if (NSWRegion.contains(AddRecRange))
5138 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5139 }
5140
5141 if (!AR->hasNoUnsignedWrap()) {
5142 ConstantRange AddRecRange = getUnsignedRange(AR);
5143 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5144
5145 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5146 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5147 if (NUWRegion.contains(AddRecRange))
5148 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5149 }
5150
5151 return Result;
5152}
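// A rough numeric sketch of the self-wrap test above (illustrative numbers,
// not from a real loop): for an i32 AddRec with a constant max backedge-taken
// count of 255 (8 active bits) and a step whose signed range needs 9 bits,
// 8 + 9 = 17 <= 32, so the total change cannot traverse the whole 32-bit
// space and FlagNW can be set.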
5153
5154 SCEV::NoWrapFlags
5155 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5156 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5157
5158 if (AR->hasNoSignedWrap())
5159 return Result;
5160
5161 if (!AR->isAffine())
5162 return Result;
5163
5164 // This function can be expensive, only try to prove NSW once per AddRec.
5165 if (!SignedWrapViaInductionTried.insert(AR).second)
5166 return Result;
5167
5168 const SCEV *Step = AR->getStepRecurrence(*this);
5169 const Loop *L = AR->getLoop();
5170
5171 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5172 // Note that this serves two purposes: It filters out loops that are
5173 // simply not analyzable, and it covers the case where this code is
5174 // being called from within backedge-taken count analysis, such that
5175 // attempting to ask for the backedge-taken count would likely result
5176 // in infinite recursion. In the latter case, the analysis code will
5177 // cope with a conservative value, and it will take care to purge
5178 // that value once it has finished.
5179 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5180
5181 // Normally, in the cases we can prove no-overflow via a
5182 // backedge guarding condition, we can also compute a backedge
5183 // taken count for the loop. The exceptions are assumptions and
5184 // guards present in the loop -- SCEV is not great at exploiting
5185 // these to compute max backedge taken counts, but can still use
5186 // these to prove lack of overflow. Use this fact to avoid
5187 // doing extra work that may not pay off.
5188
5189 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5190 AC.assumptions().empty())
5191 return Result;
5192
5193 // If the backedge is guarded by a comparison with the pre-inc value the
5194 // addrec is safe. Also, if the entry is guarded by a comparison with the
5195 // start value and the backedge is guarded by a comparison with the post-inc
5196 // value, the addrec is safe.
5197 ICmpInst::Predicate Pred;
5198 const SCEV *OverflowLimit =
5199 getSignedOverflowLimitForStep(Step, &Pred, this);
5200 if (OverflowLimit &&
5201 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5202 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5203 Result = setFlags(Result, SCEV::FlagNSW);
5204 }
5205 return Result;
5206}
5207 SCEV::NoWrapFlags
5208 ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5209 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5210
5211 if (AR->hasNoUnsignedWrap())
5212 return Result;
5213
5214 if (!AR->isAffine())
5215 return Result;
5216
5217 // This function can be expensive, only try to prove NUW once per AddRec.
5218 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5219 return Result;
5220
5221 const SCEV *Step = AR->getStepRecurrence(*this);
5222 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5223 const Loop *L = AR->getLoop();
5224
5225 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5226 // Note that this serves two purposes: It filters out loops that are
5227 // simply not analyzable, and it covers the case where this code is
5228 // being called from within backedge-taken count analysis, such that
5229 // attempting to ask for the backedge-taken count would likely result
5230 // in infinite recursion. In the latter case, the analysis code will
5231 // cope with a conservative value, and it will take care to purge
5232 // that value once it has finished.
5233 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5234
5235 // Normally, in the cases we can prove no-overflow via a
5236 // backedge guarding condition, we can also compute a backedge
5237 // taken count for the loop. The exceptions are assumptions and
5238 // guards present in the loop -- SCEV is not great at exploiting
5239 // these to compute max backedge taken counts, but can still use
5240 // these to prove lack of overflow. Use this fact to avoid
5241 // doing extra work that may not pay off.
5242
5243 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5244 AC.assumptions().empty())
5245 return Result;
5246
5247 // If the backedge is guarded by a comparison with the pre-inc value the
5248 // addrec is safe. Also, if the entry is guarded by a comparison with the
5249 // start value and the backedge is guarded by a comparison with the post-inc
5250 // value, the addrec is safe.
5251 if (isKnownPositive(Step)) {
5252 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5253 getUnsignedRangeMax(Step));
5254 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5255 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5256 Result = setFlags(Result, SCEV::FlagNUW);
5257 }
5258 }
5259
5260 return Result;
5261}
5262
5263namespace {
5264
5265/// Represents an abstract binary operation. This may exist as a
5266/// normal instruction or constant expression, or may have been
5267/// derived from an expression tree.
5268struct BinaryOp {
5269 unsigned Opcode;
5270 Value *LHS;
5271 Value *RHS;
5272 bool IsNSW = false;
5273 bool IsNUW = false;
5274
5275 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5276 /// constant expression.
5277 Operator *Op = nullptr;
5278
5279 explicit BinaryOp(Operator *Op)
5280 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5281 Op(Op) {
5282 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5283 IsNSW = OBO->hasNoSignedWrap();
5284 IsNUW = OBO->hasNoUnsignedWrap();
5285 }
5286 }
5287
5288 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5289 bool IsNUW = false)
5290 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5291};
5292
5293} // end anonymous namespace
5294
5295/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5296static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5297 AssumptionCache &AC,
5298 const DominatorTree &DT,
5299 const Instruction *CxtI) {
5300 auto *Op = dyn_cast<Operator>(V);
5301 if (!Op)
5302 return std::nullopt;
5303
5304 // Implementation detail: all the cleverness here should happen without
5305 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5306 // SCEV expressions when possible, and we should not break that.
5307
5308 switch (Op->getOpcode()) {
5309 case Instruction::Add:
5310 case Instruction::Sub:
5311 case Instruction::Mul:
5312 case Instruction::UDiv:
5313 case Instruction::URem:
5314 case Instruction::And:
5315 case Instruction::AShr:
5316 case Instruction::Shl:
5317 return BinaryOp(Op);
5318
5319 case Instruction::Or: {
5320 // Convert or disjoint into add nuw nsw.
5321 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5322 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5323 /*IsNSW=*/true, /*IsNUW=*/true);
5324 return BinaryOp(Op);
5325 }
5326
5327 case Instruction::Xor:
5328 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5329 // If the RHS of the xor is a signmask, then this is just an add.
5330 // Instcombine turns add of signmask into xor as a strength reduction step.
5331 if (RHSC->getValue().isSignMask())
5332 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5333 // Binary `xor` is a bit-wise `add`.
5334 if (V->getType()->isIntegerTy(1))
5335 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5336 return BinaryOp(Op);
5337
5338 case Instruction::LShr:
5339 // Turn logical shift right of a constant into an unsigned divide.
5340 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5341 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5342
5343 // If the shift count is not less than the bitwidth, the result of
5344 // the shift is undefined. Don't try to analyze it, because the
5345 // resolution chosen here may differ from the resolution chosen in
5346 // other parts of the compiler.
5347 if (SA->getValue().ult(BitWidth)) {
5348 Constant *X =
5349 ConstantInt::get(SA->getContext(),
5350 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5351 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5352 }
5353 }
5354 return BinaryOp(Op);
5355
5356 case Instruction::ExtractValue: {
5357 auto *EVI = cast<ExtractValueInst>(Op);
5358 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5359 break;
5360
5361 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5362 if (!WO)
5363 break;
5364
5365 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5366 bool Signed = WO->isSigned();
5367 // TODO: Should add nuw/nsw flags for mul as well.
5368 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5369 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5370
5371 // Now that we know that all uses of the arithmetic-result component of
5372 // CI are guarded by the overflow check, we can go ahead and pretend
5373 // that the arithmetic is non-overflowing.
5374 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5375 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5376 }
5377
5378 default:
5379 break;
5380 }
5381
5382 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5383 // semantics as a Sub, return a binary sub expression.
5384 if (auto *II = dyn_cast<IntrinsicInst>(V))
5385 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5386 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5387
5388 return std::nullopt;
5389}
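// Two small examples of what MatchBinaryOp produces (illustrative IR):
// %a = or disjoint i32 %x, %y --> BinaryOp(Add, %x, %y) with IsNSW and IsNUW
// set, and %b = xor i32 %x, -2147483648 --> BinaryOp(Add, %x, -2147483648),
// since xor'ing in the sign mask is the same as adding it.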
5390
5391/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5392/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5393/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5394/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5395/// follows one of the following patterns:
5396/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5397/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5398/// If the SCEV expression of \p Op conforms with one of the expected patterns
5399/// we return the type of the truncation operation, and indicate whether the
5400/// truncated type should be treated as signed/unsigned by setting
5401/// \p Signed to true/false, respectively.
5402static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5403 bool &Signed, ScalarEvolution &SE) {
5404 // The case where Op == SymbolicPHI (that is, with no type conversions on
5405 // the way) is handled by the regular add recurrence creating logic and
5406 // would have already been triggered in createAddRecFromPHI. Reaching it here
5407 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5408 // because one of the other operands of the SCEVAddExpr updating this PHI is
5409 // not invariant).
5410 //
5411 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5412 // this case predicates that allow us to prove that Op == SymbolicPHI will
5413 // be added.
5414 if (Op == SymbolicPHI)
5415 return nullptr;
5416
5417 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5418 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5419 if (SourceBits != NewBits)
5420 return nullptr;
5421
5422 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5423 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5424 if (!SExt && !ZExt)
5425 return nullptr;
5426 const SCEVTruncateExpr *Trunc =
5427 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5428 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5429 if (!Trunc)
5430 return nullptr;
5431 const SCEV *X = Trunc->getOperand();
5432 if (X != SymbolicPHI)
5433 return nullptr;
5434 Signed = SExt != nullptr;
5435 return Trunc->getType();
5436}
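// Example (illustrative): with %X an i64 phi, the operand
// (sext i32 (trunc i64 %X to i32) to i64) matches the pattern above, so
// isSimpleCastedPHI returns the i32 truncation type with Signed == true; the
// zext-of-trunc form returns the same type with Signed == false.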
5437
5438static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5439 if (!PN->getType()->isIntegerTy())
5440 return nullptr;
5441 const Loop *L = LI.getLoopFor(PN->getParent());
5442 if (!L || L->getHeader() != PN->getParent())
5443 return nullptr;
5444 return L;
5445}
5446
5447// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5448// computation that updates the phi follows the following pattern:
5449// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5450// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5451// If so, try to see if it can be rewritten as an AddRecExpr under some
5452// Predicates. If successful, return them as a pair. Also cache the results
5453// of the analysis.
5454//
5455// Example usage scenario:
5456// Say the Rewriter is called for the following SCEV:
5457// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5458// where:
5459// %X = phi i64 (%Start, %BEValue)
5460// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5461// and call this function with %SymbolicPHI = %X.
5462//
5463// The analysis will find that the value coming around the backedge has
5464// the following SCEV:
5465// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5466// Upon concluding that this matches the desired pattern, the function
5467// will return the pair {NewAddRec, SmallPredsVec} where:
5468// NewAddRec = {%Start,+,%Step}
5469// SmallPredsVec = {P1, P2, P3} as follows:
5470// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5471// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5472// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5473// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5474// under the predicates {P1,P2,P3}.
5475// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5476// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5477//
5478// TODO's:
5479//
5480// 1) Extend the Induction descriptor to also support inductions that involve
5481// casts: When needed (namely, when we are called in the context of the
5482// vectorizer induction analysis), a Set of cast instructions will be
5483// populated by this method, and provided back to isInductionPHI. This is
5484// needed to allow the vectorizer to properly record them to be ignored by
5485// the cost model and to avoid vectorizing them (otherwise these casts,
5486// which are redundant under the runtime overflow checks, will be
5487// vectorized, which can be costly).
5488//
5489// 2) Support additional induction/PHISCEV patterns: We also want to support
5490// inductions where the sext-trunc / zext-trunc operations (partly) occur
5491// after the induction update operation (the induction increment):
5492//
5493// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5494// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5495//
5496// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5497// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5498//
5499// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5500std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5501ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5502 SmallVector<const SCEVPredicate *, 3> Predicates;
5503
5504 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5505 // return an AddRec expression under some predicate.
5506
5507 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5508 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5509 assert(L && "Expecting an integer loop header phi");
5510
5511 // The loop may have multiple entrances or multiple exits; we can analyze
5512 // this phi as an addrec if it has a unique entry value and a unique
5513 // backedge value.
5514 Value *BEValueV = nullptr, *StartValueV = nullptr;
5515 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5516 Value *V = PN->getIncomingValue(i);
5517 if (L->contains(PN->getIncomingBlock(i))) {
5518 if (!BEValueV) {
5519 BEValueV = V;
5520 } else if (BEValueV != V) {
5521 BEValueV = nullptr;
5522 break;
5523 }
5524 } else if (!StartValueV) {
5525 StartValueV = V;
5526 } else if (StartValueV != V) {
5527 StartValueV = nullptr;
5528 break;
5529 }
5530 }
5531 if (!BEValueV || !StartValueV)
5532 return std::nullopt;
5533
5534 const SCEV *BEValue = getSCEV(BEValueV);
5535
5536 // If the value coming around the backedge is an add with the symbolic
5537 // value we just inserted, possibly with casts that we can ignore under
5538 // an appropriate runtime guard, then we found a simple induction variable!
5539 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5540 if (!Add)
5541 return std::nullopt;
5542
5543 // If there is a single occurrence of the symbolic value, possibly
5544 // casted, replace it with a recurrence.
5545 unsigned FoundIndex = Add->getNumOperands();
5546 Type *TruncTy = nullptr;
5547 bool Signed;
5548 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5549 if ((TruncTy =
5550 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5551 if (FoundIndex == e) {
5552 FoundIndex = i;
5553 break;
5554 }
5555
5556 if (FoundIndex == Add->getNumOperands())
5557 return std::nullopt;
5558
5559 // Create an add with everything but the specified operand.
5560 SmallVector<const SCEV *, 8> Ops;
5561 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5562 if (i != FoundIndex)
5563 Ops.push_back(Add->getOperand(i));
5564 const SCEV *Accum = getAddExpr(Ops);
5565
5566 // The runtime checks will not be valid if the step amount is
5567 // varying inside the loop.
5568 if (!isLoopInvariant(Accum, L))
5569 return std::nullopt;
5570
5571 // *** Part2: Create the predicates
5572
5573 // Analysis was successful: we have a phi-with-cast pattern for which we
5574 // can return an AddRec expression under the following predicates:
5575 //
5576 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5577 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5578 // P2: An Equal predicate that guarantees that
5579 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5580 // P3: An Equal predicate that guarantees that
5581 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5582 //
5583 // As we next prove, the above predicates guarantee that:
5584 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5585 //
5586 //
5587 // More formally, we want to prove that:
5588 // Expr(i+1) = Start + (i+1) * Accum
5589 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5590 //
5591 // Given that:
5592 // 1) Expr(0) = Start
5593 // 2) Expr(1) = Start + Accum
5594 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5595 // 3) Induction hypothesis (step i):
5596 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5597 //
5598 // Proof:
5599 // Expr(i+1) =
5600 // = Start + (i+1)*Accum
5601 // = (Start + i*Accum) + Accum
5602 // = Expr(i) + Accum
5603 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5604 // :: from step i
5605 //
5606 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5607 //
5608 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5609 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5610 // + Accum :: from P3
5611 //
5612 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5613 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5614 //
5615 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5616 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5617 //
5618 // By induction, the same applies to all iterations 1<=i<n:
5619 //
5620
5621 // Create a truncated addrec for which we will add a no overflow check (P1).
5622 const SCEV *StartVal = getSCEV(StartValueV);
5623 const SCEV *PHISCEV =
5624 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5625 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5626
5627 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5628 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5629 // will be constant.
5630 //
5631 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5632 // add P1.
5633 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5634 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5635 Signed ? SCEVWrapPredicate::IncrementNSSW
5636 : SCEVWrapPredicate::IncrementNUSW;
5637 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5638 Predicates.push_back(AddRecPred);
5639 }
5640
5641 // Create the Equal Predicates P2,P3:
5642
5643 // It is possible that the predicates P2 and/or P3 are computable at
5644 // compile time due to StartVal and/or Accum being constants.
5645 // If either one is, then we can check that now and escape if either P2
5646 // or P3 is false.
5647
5648 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5649 // for each of StartVal and Accum
5650 auto getExtendedExpr = [&](const SCEV *Expr,
5651 bool CreateSignExtend) -> const SCEV * {
5652 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5653 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5654 const SCEV *ExtendedExpr =
5655 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5656 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5657 return ExtendedExpr;
5658 };
5659
5660 // Given:
5661 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5662 // = getExtendedExpr(Expr)
5663 // Determine whether the predicate P: Expr == ExtendedExpr
5664 // is known to be false at compile time
5665 auto PredIsKnownFalse = [&](const SCEV *Expr,
5666 const SCEV *ExtendedExpr) -> bool {
5667 return Expr != ExtendedExpr &&
5668 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5669 };
5670
5671 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5672 if (PredIsKnownFalse(StartVal, StartExtended)) {
5673 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5674 return std::nullopt;
5675 }
5676
5677 // The Step is always Signed (because the overflow checks are either
5678 // NSSW or NUSW)
5679 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5680 if (PredIsKnownFalse(Accum, AccumExtended)) {
5681 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5682 return std::nullopt;
5683 }
5684
5685 auto AppendPredicate = [&](const SCEV *Expr,
5686 const SCEV *ExtendedExpr) -> void {
5687 if (Expr != ExtendedExpr &&
5688 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5689 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5690 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5691 Predicates.push_back(Pred);
5692 }
5693 };
5694
5695 AppendPredicate(StartVal, StartExtended);
5696 AppendPredicate(Accum, AccumExtended);
5697
5698 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5699 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5700 // into NewAR if it will also add the runtime overflow checks specified in
5701 // Predicates.
5702 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5703
5704 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5705 std::make_pair(NewAR, Predicates);
5706 // Remember the result of the analysis for this SCEV at this location.
5707 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5708 return PredRewrite;
5709}
5710
5711std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5712 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5713 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5714 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5715 if (!L)
5716 return std::nullopt;
5717
5718 // Check to see if we already analyzed this PHI.
5719 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5720 if (I != PredicatedSCEVRewrites.end()) {
5721 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5722 I->second;
5723 // Analysis was done before and failed to create an AddRec:
5724 if (Rewrite.first == SymbolicPHI)
5725 return std::nullopt;
5726 // Analysis was done before and succeeded to create an AddRec under
5727 // a predicate:
5728 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5729 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5730 return Rewrite;
5731 }
5732
5733 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5734 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5735
5736 // Record in the cache that the analysis failed
5737 if (!Rewrite) {
5738 SmallVector<const SCEVPredicate *, 3> Predicates;
5739 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5740 return std::nullopt;
5741 }
5742
5743 return Rewrite;
5744}
5745
5746// FIXME: This utility is currently required because the Rewriter currently
5747// does not rewrite this expression:
5748// {0, +, (sext ix (trunc iy to ix) to iy)}
5749// into {0, +, %step},
5750// even when the following Equal predicate exists:
5751// "%step == (sext ix (trunc iy to ix) to iy)".
5752 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5753 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5754 if (AR1 == AR2)
5755 return true;
5756
5757 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5758 if (Expr1 != Expr2 &&
5759 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5760 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5761 return false;
5762 return true;
5763 };
5764
5765 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5766 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5767 return false;
5768 return true;
5769}
5770
5771/// A helper function for createAddRecFromPHI to handle simple cases.
5772///
5773/// This function tries to find an AddRec expression for the simplest (yet most
5774/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5775/// If it fails, createAddRecFromPHI will use a more general, but slow,
5776/// technique for finding the AddRec expression.
5777const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5778 Value *BEValueV,
5779 Value *StartValueV) {
5780 const Loop *L = LI.getLoopFor(PN->getParent());
5781 assert(L && L->getHeader() == PN->getParent());
5782 assert(BEValueV && StartValueV);
5783
5784 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5785 if (!BO)
5786 return nullptr;
5787
5788 if (BO->Opcode != Instruction::Add)
5789 return nullptr;
5790
5791 const SCEV *Accum = nullptr;
5792 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5793 Accum = getSCEV(BO->RHS);
5794 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5795 Accum = getSCEV(BO->LHS);
5796
5797 if (!Accum)
5798 return nullptr;
5799
5800 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5801 if (BO->IsNUW)
5802 Flags = setFlags(Flags, SCEV::FlagNUW);
5803 if (BO->IsNSW)
5804 Flags = setFlags(Flags, SCEV::FlagNSW);
5805
5806 const SCEV *StartVal = getSCEV(StartValueV);
5807 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5808 insertValueToMap(PN, PHISCEV);
5809
5810 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5811 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5812 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5813 proveNoWrapViaConstantRanges(AR)));
5814 }
5815
5816 // We can add Flags to the post-inc expression only if we
5817 // know that it is *undefined behavior* for BEValueV to
5818 // overflow.
5819 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5820 assert(isLoopInvariant(Accum, L) &&
5821 "Accum is defined outside L, but is not invariant?");
5822 if (isAddRecNeverPoison(BEInst, L))
5823 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5824 }
5825
5826 return PHISCEV;
5827}
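// The simple case handled above, as illustrative IR (hypothetical names):
// loop:
// %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
// %iv.next = add nuw nsw i64 %iv, %step ; %step is loop-invariant
// produces PHISCEV = {0,+,%step}<nuw><nsw><%loop>.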
5828
5829const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5830 const Loop *L = LI.getLoopFor(PN->getParent());
5831 if (!L || L->getHeader() != PN->getParent())
5832 return nullptr;
5833
5834 // The loop may have multiple entrances or multiple exits; we can analyze
5835 // this phi as an addrec if it has a unique entry value and a unique
5836 // backedge value.
5837 Value *BEValueV = nullptr, *StartValueV = nullptr;
5838 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5839 Value *V = PN->getIncomingValue(i);
5840 if (L->contains(PN->getIncomingBlock(i))) {
5841 if (!BEValueV) {
5842 BEValueV = V;
5843 } else if (BEValueV != V) {
5844 BEValueV = nullptr;
5845 break;
5846 }
5847 } else if (!StartValueV) {
5848 StartValueV = V;
5849 } else if (StartValueV != V) {
5850 StartValueV = nullptr;
5851 break;
5852 }
5853 }
5854 if (!BEValueV || !StartValueV)
5855 return nullptr;
5856
5857 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5858 "PHI node already processed?");
5859
5860 // First, try to find AddRec expression without creating a fictitious symbolic
5861 // value for PN.
5862 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5863 return S;
5864
5865 // Handle PHI node value symbolically.
5866 const SCEV *SymbolicName = getUnknown(PN);
5867 insertValueToMap(PN, SymbolicName);
5868
5869 // Using this symbolic name for the PHI, analyze the value coming around
5870 // the back-edge.
5871 const SCEV *BEValue = getSCEV(BEValueV);
5872
5873 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5874 // has a special value for the first iteration of the loop.
5875
5876 // If the value coming around the backedge is an add with the symbolic
5877 // value we just inserted, then we found a simple induction variable!
5878 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5879 // If there is a single occurrence of the symbolic value, replace it
5880 // with a recurrence.
5881 unsigned FoundIndex = Add->getNumOperands();
5882 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5883 if (Add->getOperand(i) == SymbolicName)
5884 if (FoundIndex == e) {
5885 FoundIndex = i;
5886 break;
5887 }
5888
5889 if (FoundIndex != Add->getNumOperands()) {
5890 // Create an add with everything but the specified operand.
5891 SmallVector<const SCEV *, 8> Ops;
5892 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5893 if (i != FoundIndex)
5894 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5895 L, *this));
5896 const SCEV *Accum = getAddExpr(Ops);
5897
5898 // This is not a valid addrec if the step amount is varying each
5899 // loop iteration, but is not itself an addrec in this loop.
5900 if (isLoopInvariant(Accum, L) ||
5901 (isa<SCEVAddRecExpr>(Accum) &&
5902 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5903 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5904
5905 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5906 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5907 if (BO->IsNUW)
5908 Flags = setFlags(Flags, SCEV::FlagNUW);
5909 if (BO->IsNSW)
5910 Flags = setFlags(Flags, SCEV::FlagNSW);
5911 }
5912 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5913 if (GEP->getOperand(0) == PN) {
5914 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5915 // If the increment has any nowrap flags, then we know the address
5916 // space cannot be wrapped around.
5917 if (NW != GEPNoWrapFlags::none())
5918 Flags = setFlags(Flags, SCEV::FlagNW);
5919 // If the GEP is nuw or nusw with non-negative offset, we know that
5920 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5921 // offset is treated as signed, while the base is unsigned.
5922 if (NW.hasNoUnsignedWrap() ||
5923 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5924 Flags = setFlags(Flags, SCEV::FlagNUW);
5925 }
5926
5927 // We cannot transfer nuw and nsw flags from subtraction
5928 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5929 // for instance.
5930 }
5931
5932 const SCEV *StartVal = getSCEV(StartValueV);
5933 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5934
5935 // Okay, for the entire analysis of this edge we assumed the PHI
5936 // to be symbolic. We now need to go back and purge all of the
5937 // entries for the scalars that use the symbolic expression.
5938 forgetMemoizedResults(SymbolicName);
5939 insertValueToMap(PN, PHISCEV);
5940
5941 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5942 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5943 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5944 proveNoWrapViaConstantRanges(AR)));
5945 }
5946
5947 // We can add Flags to the post-inc expression only if we
5948 // know that it is *undefined behavior* for BEValueV to
5949 // overflow.
5950 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5951 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5952 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5953
5954 return PHISCEV;
5955 }
5956 }
5957 } else {
5958 // Otherwise, this could be a loop like this:
5959 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5960 // In this case, j = {1,+,1} and BEValue is j.
5961 // Because the other in-value of i (0) fits the evolution of BEValue
5962 // i really is an addrec evolution.
5963 //
5964 // We can generalize this saying that i is the shifted value of BEValue
5965 // by one iteration:
5966 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5967
5968 // Do not allow refinement in rewriting of BEValue.
5969 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5970 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5971 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5972 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5973 const SCEV *StartVal = getSCEV(StartValueV);
5974 if (Start == StartVal) {
5975 // Okay, for the entire analysis of this edge we assumed the PHI
5976 // to be symbolic. We now need to go back and purge all of the
5977 // entries for the scalars that use the symbolic expression.
5978 forgetMemoizedResults(SymbolicName);
5979 insertValueToMap(PN, Shifted);
5980 return Shifted;
5981 }
5982 }
5983 }
5984
5985 // Remove the temporary PHI node SCEV that has been inserted while intending
5986 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5987 // as it would prevent later (possibly simpler) SCEV expressions from being added
5988 // to the ValueExprMap.
5989 eraseValueFromMap(PN);
5990
5991 return nullptr;
5992}
5993
5994// Try to match a control flow sequence that branches out at BI and merges back
5995// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5996// match.
5997 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5998 Value *&C, Value *&LHS, Value *&RHS) {
5999 C = BI->getCondition();
6000
6001 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6002 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6003
6004 if (!LeftEdge.isSingleEdge())
6005 return false;
6006
6007 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
6008
6009 Use &LeftUse = Merge->getOperandUse(0);
6010 Use &RightUse = Merge->getOperandUse(1);
6011
6012 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6013 LHS = LeftUse;
6014 RHS = RightUse;
6015 return true;
6016 }
6017
6018 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6019 LHS = RightUse;
6020 RHS = LeftUse;
6021 return true;
6022 }
6023
6024 return false;
6025}
6026
6027const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6028 auto IsReachable =
6029 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6030 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6031 // Try to match
6032 //
6033 // br %cond, label %left, label %right
6034 // left:
6035 // br label %merge
6036 // right:
6037 // br label %merge
6038 // merge:
6039 // V = phi [ %x, %left ], [ %y, %right ]
6040 //
6041 // as "select %cond, %x, %y"
6042
6043 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6044 assert(IDom && "At least the entry block should dominate PN");
6045
6046 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6047 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6048
6049 if (BI && BI->isConditional() &&
6050 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6051 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6052 properlyDominates(getSCEV(RHS), PN->getParent()))
6053 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6054 }
6055
6056 return nullptr;
6057}
6058
6059/// Returns SCEV for the first operand of a phi if all phi operands have
6060/// identical opcodes and operands
6061/// eg.
6062/// a: %add = %a + %b
6063/// br %c
6064/// b: %add1 = %a + %b
6065/// br %c
6066/// c: %phi = phi [%add, a], [%add1, b]
6067/// scev(%phi) => scev(%add)
6068const SCEV *
6069ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6070 BinaryOperator *CommonInst = nullptr;
6071 // Check if instructions are identical.
6072 for (Value *Incoming : PN->incoming_values()) {
6073 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6074 if (!IncomingInst)
6075 return nullptr;
6076 if (CommonInst) {
6077 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6078 return nullptr; // Not identical, give up
6079 } else {
6080 // Remember binary operator
6081 CommonInst = IncomingInst;
6082 }
6083 }
6084 if (!CommonInst)
6085 return nullptr;
6086
6087 // Check if SCEV exprs for instructions are identical.
6088 const SCEV *CommonSCEV = getSCEV(CommonInst);
6089 bool SCEVExprsIdentical =
6090 all_of(drop_begin(PN->incoming_values()),
6091 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6092 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6093}
6094
6095const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6096 if (const SCEV *S = createAddRecFromPHI(PN))
6097 return S;
6098
6099 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6100 // phi node for X.
6101 if (Value *V = simplifyInstruction(
6102 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6103 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6104 return getSCEV(V);
6105
6106 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6107 return S;
6108
6109 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6110 return S;
6111
6112 // If it's not a loop phi, we can't handle it yet.
6113 return getUnknown(PN);
6114}
6115
6116bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6117 SCEVTypes RootKind) {
6118 struct FindClosure {
6119 const SCEV *OperandToFind;
6120 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6121 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6122
6123 bool Found = false;
6124
6125 bool canRecurseInto(SCEVTypes Kind) const {
6126 // We can only recurse into the SCEV expression of the same effective type
6127 // as the type of our root SCEV expression, and into zero-extensions.
6128 return RootKind == Kind || NonSequentialRootKind == Kind ||
6129 scZeroExtend == Kind;
6130 };
6131
6132 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6133 : OperandToFind(OperandToFind), RootKind(RootKind),
6134 NonSequentialRootKind(
6135 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6136 RootKind)) {}
6137
6138 bool follow(const SCEV *S) {
6139 Found = S == OperandToFind;
6140
6141 return !isDone() && canRecurseInto(S->getSCEVType());
6142 }
6143
6144 bool isDone() const { return Found; }
6145 };
6146
6147 FindClosure FC(OperandToFind, RootKind);
6148 visitAll(Root, FC);
6149 return FC.Found;
6150}
6151
6152std::optional<const SCEV *>
6153ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6154 ICmpInst *Cond,
6155 Value *TrueVal,
6156 Value *FalseVal) {
6157 // Try to match some simple smax or umax patterns.
6158 auto *ICI = Cond;
6159
6160 Value *LHS = ICI->getOperand(0);
6161 Value *RHS = ICI->getOperand(1);
6162
6163 switch (ICI->getPredicate()) {
6164 case ICmpInst::ICMP_SLT:
6165 case ICmpInst::ICMP_SLE:
6166 case ICmpInst::ICMP_ULT:
6167 case ICmpInst::ICMP_ULE:
6168 std::swap(LHS, RHS);
6169 [[fallthrough]];
6170 case ICmpInst::ICMP_SGT:
6171 case ICmpInst::ICMP_SGE:
6172 case ICmpInst::ICMP_UGT:
6173 case ICmpInst::ICMP_UGE:
6174 // a > b ? a+x : b+x -> max(a, b)+x
6175 // a > b ? b+x : a+x -> min(a, b)+x
6176 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6177 bool Signed = ICI->isSigned();
6178 const SCEV *LA = getSCEV(TrueVal);
6179 const SCEV *RA = getSCEV(FalseVal);
6180 const SCEV *LS = getSCEV(LHS);
6181 const SCEV *RS = getSCEV(RHS);
6182 if (LA->getType()->isPointerTy()) {
6183 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6184 // Need to make sure we can't produce weird expressions involving
6185 // negated pointers.
6186 if (LA == LS && RA == RS)
6187 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6188 if (LA == RS && RA == LS)
6189 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6190 }
6191 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6192 if (Op->getType()->isPointerTy()) {
6193 Op = getLosslessPtrToIntExpr(Op);
6194 if (isa<SCEVCouldNotCompute>(Op))
6195 return Op;
6196 }
6197 if (Signed)
6198 Op = getNoopOrSignExtend(Op, Ty);
6199 else
6200 Op = getNoopOrZeroExtend(Op, Ty);
6201 return Op;
6202 };
6203 LS = CoerceOperand(LS);
6204 RS = CoerceOperand(RS);
6205 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6206 break;
6207 const SCEV *LDiff = getMinusSCEV(LA, LS);
6208 const SCEV *RDiff = getMinusSCEV(RA, RS);
6209 if (LDiff == RDiff)
6210 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6211 LDiff);
6212 LDiff = getMinusSCEV(LA, RS);
6213 RDiff = getMinusSCEV(RA, LS);
6214 if (LDiff == RDiff)
6215 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6216 LDiff);
6217 }
6218 break;
6219 case ICmpInst::ICMP_NE:
6220 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6221 std::swap(TrueVal, FalseVal);
6222 [[fallthrough]];
6223 case ICmpInst::ICMP_EQ:
6224 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6225 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6226 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6227 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6228 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6229 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6230 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6231 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6232 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6233 return getAddExpr(getUMaxExpr(X, C), Y);
6234 }
6235 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6236 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6237 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6238 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6239 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6240 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6241 const SCEV *X = getSCEV(LHS);
6242 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6243 X = ZExt->getOperand();
6244 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6245 const SCEV *FalseValExpr = getSCEV(FalseVal);
6246 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6247 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6248 /*Sequential=*/true);
6249 }
6250 }
6251 break;
6252 default:
6253 break;
6254 }
6255
6256 return std::nullopt;
6257}
6258
6259static std::optional<const SCEV *>
6260 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6261 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6262 assert(CondExpr->getType()->isIntegerTy(1) &&
6263 TrueExpr->getType() == FalseExpr->getType() &&
6264 TrueExpr->getType()->isIntegerTy(1) &&
6265 "Unexpected operands of a select.");
6266
6267 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6268 // --> C + (umin_seq cond, x - C)
6269 //
6270 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6271 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6272 // --> C + (umin_seq ~cond, x - C)
6273
6274 // FIXME: while we can't legally model the case where both of the hands
6275 // are fully variable, we only require that the *difference* is constant.
6276 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6277 return std::nullopt;
6278
6279 const SCEV *X, *C;
6280 if (isa<SCEVConstant>(TrueExpr)) {
6281 CondExpr = SE->getNotSCEV(CondExpr);
6282 X = FalseExpr;
6283 C = TrueExpr;
6284 } else {
6285 X = TrueExpr;
6286 C = FalseExpr;
6287 }
6288 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6289 /*Sequential=*/true));
6290}
6291
6292static std::optional<const SCEV *>
6293 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6294 Value *FalseVal) {
6295 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6296 return std::nullopt;
6297
6298 const auto *SECond = SE->getSCEV(Cond);
6299 const auto *SETrue = SE->getSCEV(TrueVal);
6300 const auto *SEFalse = SE->getSCEV(FalseVal);
6301 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6302}
6303
6304const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6305 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6306 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6307 assert(TrueVal->getType() == FalseVal->getType() &&
6308 V->getType() == TrueVal->getType() &&
6309 "Types of select hands and of the result must match.");
6310
6311 // For now, only deal with i1-typed `select`s.
6312 if (!V->getType()->isIntegerTy(1))
6313 return getUnknown(V);
6314
6315 if (std::optional<const SCEV *> S =
6316 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6317 return *S;
6318
6319 return getUnknown(V);
6320}
6321
6322const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6323 Value *TrueVal,
6324 Value *FalseVal) {
6325 // Handle "constant" branch or select. This can occur for instance when a
6326 // loop pass transforms an inner loop and moves on to process the outer loop.
6327 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6328 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6329
6330 if (auto *I = dyn_cast<Instruction>(V)) {
6331 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6332 if (std::optional<const SCEV *> S =
6333 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6334 TrueVal, FalseVal))
6335 return *S;
6336 }
6337 }
6338
6339 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6340}
6341
6342/// Expand GEP instructions into add and multiply operations. This allows them
6343/// to be analyzed by regular SCEV code.
6344const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6345 assert(GEP->getSourceElementType()->isSized() &&
6346 "GEP source element type must be sized");
6347
6348 SmallVector<const SCEV *, 4> IndexExprs;
6349 for (Value *Index : GEP->indices())
6350 IndexExprs.push_back(getSCEV(Index));
6351 return getGEPExpr(GEP, IndexExprs);
6352}
6353
6354APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6355 const Instruction *CtxI) {
6356 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6357 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6358 return TrailingZeros >= BitWidth
6359 ? APInt::getZero(BitWidth)
6360 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6361 };
6362 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6363 // The result is GCD of all operands results.
6364 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6365 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6366 Res = APIntOps::GreatestCommonDivisor(
6367 Res, getConstantMultiple(N->getOperand(I), CtxI));
6368 return Res;
6369 };
6370
6371 switch (S->getSCEVType()) {
6372 case scConstant:
6373 return cast<SCEVConstant>(S)->getAPInt();
6374 case scPtrToInt:
6375 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
6376 case scUDivExpr:
6377 case scVScale:
6378 return APInt(BitWidth, 1);
6379 case scTruncate: {
6380 // Only multiples that are a power of 2 will hold after truncation.
6381 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6382 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6383 return GetShiftedByZeros(TZ);
6384 }
6385 case scZeroExtend: {
6386 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6387 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6388 }
6389 case scSignExtend: {
6390 // Only multiples that are a power of 2 will hold after sext.
6391 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6392 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6393 return GetShiftedByZeros(TZ);
6394 }
6395 case scMulExpr: {
6396 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6397 if (M->hasNoUnsignedWrap()) {
6398 // The result is the product of all operand results.
6399 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6400 for (const SCEV *Operand : M->operands().drop_front())
6401 Res = Res * getConstantMultiple(Operand, CtxI);
6402 return Res;
6403 }
6404
6405 // If there are no wrap guarantees, find the trailing zeros, which is the
6406 // sum of trailing zeros for all its operands.
6407 uint32_t TZ = 0;
6408 for (const SCEV *Operand : M->operands())
6409 TZ += getMinTrailingZeros(Operand, CtxI);
6410 return GetShiftedByZeros(TZ);
6411 }
6412 case scAddExpr:
6413 case scAddRecExpr: {
6414 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6415 if (N->hasNoUnsignedWrap())
6416 return GetGCDMultiple(N);
6417 // Find the trailing zeros, which is the minimum over all operands.
6418 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6419 for (const SCEV *Operand : N->operands().drop_front())
6420 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6421 return GetShiftedByZeros(TZ);
6422 }
6423 case scUMaxExpr:
6424 case scSMaxExpr:
6425 case scUMinExpr:
6426 case scSMinExpr:
6427 case scSequentialUMinExpr:
6428 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6429 case scUnknown: {
6430 // ask ValueTracking for known bits
6431 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6432 unsigned Known =
6433 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6434 .countMinTrailingZeros();
6435 return GetShiftedByZeros(Known);
6436 }
6437 case scCouldNotCompute:
6438 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6439 }
6440 llvm_unreachable("Unknown SCEV kind!");
6441}
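// Editorial example of the logic above: for S = (4 * %x)<nuw> the result is
// 4 * getConstantMultiple(%x); for a mul without no-wrap guarantees only the
// power-of-two factor is kept, so (6 * %x) is reported as a multiple of 2,
// not of 6. The value %x stands for a hypothetical SCEVUnknown.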
6442
6443 APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6444 const Instruction *CtxI) {
6445 // Skip looking up and updating the cache if there is a context instruction,
6446 // as the result will only be valid in the specified context.
6447 if (CtxI)
6448 return getConstantMultipleImpl(S, CtxI);
6449
6450 auto I = ConstantMultipleCache.find(S);
6451 if (I != ConstantMultipleCache.end())
6452 return I->second;
6453
6454 APInt Result = getConstantMultipleImpl(S, CtxI);
6455 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6456 assert(InsertPair.second && "Should insert a new key");
6457 return InsertPair.first->second;
6458}
6459
6460 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6461 APInt Multiple = getConstantMultiple(S);
6462 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6463}
6464
6465 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6466 const Instruction *CtxI) {
6467 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6468 (unsigned)getTypeSizeInBits(S->getType()));
6469}
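// Editorial example: if getConstantMultiple(S) returns 8 (0b1000), S is a
// multiple of 8, so it has at least countTrailingZeros(8) == 3 trailing zero
// bits; the std::min above merely clamps the answer to the type width.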
6470
6471/// Helper method to assign a range to V from metadata present in the IR.
6472static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6473 if (Instruction *I = dyn_cast<Instruction>(V)) {
6474 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6475 return getConstantRangeFromMetadata(*MD);
6476 if (const auto *CB = dyn_cast<CallBase>(V))
6477 if (std::optional<ConstantRange> Range = CB->getRange())
6478 return Range;
6479 }
6480 if (auto *A = dyn_cast<Argument>(V))
6481 if (std::optional<ConstantRange> Range = A->getRange())
6482 return Range;
6483
6484 return std::nullopt;
6485}
6486
6487 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6488 SCEV::NoWrapFlags Flags) {
6489 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6490 AddRec->setNoWrapFlags(Flags);
6491 UnsignedRanges.erase(AddRec);
6492 SignedRanges.erase(AddRec);
6493 ConstantMultipleCache.erase(AddRec);
6494 }
6495}
6496
6497ConstantRange ScalarEvolution::
6498getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6499 const DataLayout &DL = getDataLayout();
6500
6501 unsigned BitWidth = getTypeSizeInBits(U->getType());
6502 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6503
6504 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6505 // use information about the trip count to improve our available range. Note
6506 // that the trip count independent cases are already handled by known bits.
6507 // WARNING: The definition of recurrence used here is subtly different than
6508 // the one used by AddRec (and thus most of this file). Step is allowed to
6509 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6510 // and other addrecs in the same loop (for non-affine addrecs). The code
6511 // below intentionally handles the case where step is not loop invariant.
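// Editorial sketch (hypothetical IR names) of the kind of recurrence matched
// below:
//   loop:
//     %iv = phi i32 [ %start, %preheader ], [ %iv.next, %loop ]
//     %iv.next = lshr i32 %iv, %amt   ; %amt may change from iteration to iteration
// Such a phi is modelled as a SCEVUnknown rather than an AddRec, so its range
// is sharpened here rather than by the AddRec-based logic.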
6512 auto *P = dyn_cast<PHINode>(U->getValue());
6513 if (!P)
6514 return FullSet;
6515
6516 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6517 // even the values that are not available in these blocks may come from them,
6518 // and this leads to a false-positive recurrence test.
6519 for (auto *Pred : predecessors(P->getParent()))
6520 if (!DT.isReachableFromEntry(Pred))
6521 return FullSet;
6522
6523 BinaryOperator *BO;
6524 Value *Start, *Step;
6525 if (!matchSimpleRecurrence(P, BO, Start, Step))
6526 return FullSet;
6527
6528 // If we found a recurrence in reachable code, we must be in a loop. Note
6529 // that BO might be in some subloop of L, and that's completely okay.
6530 auto *L = LI.getLoopFor(P->getParent());
6531 assert(L && L->getHeader() == P->getParent());
6532 if (!L->contains(BO->getParent()))
6533 // NOTE: This bailout should be an assert instead. However, asserting
6534 // the condition here exposes a case where LoopFusion is querying SCEV
6535 // with malformed loop information in the midst of the transform.
6536 // There doesn't appear to be an obvious fix, so for the moment bail out
6537 // until the caller issue can be fixed. PR49566 tracks the bug.
6538 return FullSet;
6539
6540 // TODO: Extend to other opcodes such as mul, and div
6541 switch (BO->getOpcode()) {
6542 default:
6543 return FullSet;
6544 case Instruction::AShr:
6545 case Instruction::LShr:
6546 case Instruction::Shl:
6547 break;
6548 };
6549
6550 if (BO->getOperand(0) != P)
6551 // TODO: Handle the power function forms some day.
6552 return FullSet;
6553
6554 unsigned TC = getSmallConstantMaxTripCount(L);
6555 if (!TC || TC >= BitWidth)
6556 return FullSet;
6557
6558 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6559 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6560 assert(KnownStart.getBitWidth() == BitWidth &&
6561 KnownStep.getBitWidth() == BitWidth);
6562
6563 // Compute total shift amount, being careful of overflow and bitwidths.
6564 auto MaxShiftAmt = KnownStep.getMaxValue();
6565 APInt TCAP(BitWidth, TC-1);
6566 bool Overflow = false;
6567 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6568 if (Overflow)
6569 return FullSet;
6570
6571 switch (BO->getOpcode()) {
6572 default:
6573 llvm_unreachable("filtered out above");
6574 case Instruction::AShr: {
6575 // For each ashr, three cases:
6576 // shift = 0 => unchanged value
6577 // saturation => 0 or -1
6578 // other => a value closer to zero (of the same sign)
6579 // Thus, the end value is closer to zero than the start.
6580 auto KnownEnd = KnownBits::ashr(KnownStart,
6581 KnownBits::makeConstant(TotalShift));
6582 if (KnownStart.isNonNegative())
6583 // Analogous to lshr (simply not yet canonicalized)
6584 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6585 KnownStart.getMaxValue() + 1);
6586 if (KnownStart.isNegative())
6587 // End >=u Start && End <=s Start
6588 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6589 KnownEnd.getMaxValue() + 1);
6590 break;
6591 }
6592 case Instruction::LShr: {
6593 // For each lshr, three cases:
6594 // shift = 0 => unchanged value
6595 // saturation => 0
6596 // other => a smaller positive number
6597 // Thus, the low end of the unsigned range is the last value produced.
6598 auto KnownEnd = KnownBits::lshr(KnownStart,
6599 KnownBits::makeConstant(TotalShift));
6600 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6601 KnownStart.getMaxValue() + 1);
6602 }
6603 case Instruction::Shl: {
6604 // Iff no bits are shifted out, value increases on every shift.
6605 auto KnownEnd = KnownBits::shl(KnownStart,
6606 KnownBits::makeConstant(TotalShift));
6607 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6608 return ConstantRange(KnownStart.getMinValue(),
6609 KnownEnd.getMaxValue() + 1);
6610 break;
6611 }
6612 };
6613 return FullSet;
6614}
6615
6616const ConstantRange &
6617ScalarEvolution::getRangeRefIter(const SCEV *S,
6618 ScalarEvolution::RangeSignHint SignHint) {
6619 DenseMap<const SCEV *, ConstantRange> &Cache =
6620 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6621 : SignedRanges;
6622 SmallVector<const SCEV *> WorkList;
6623 SmallPtrSet<const SCEV *, 8> Seen;
6624
6625 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6626 // SCEVUnknown PHI node.
6627 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6628 if (!Seen.insert(Expr).second)
6629 return;
6630 if (Cache.contains(Expr))
6631 return;
6632 switch (Expr->getSCEVType()) {
6633 case scUnknown:
6634 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6635 break;
6636 [[fallthrough]];
6637 case scConstant:
6638 case scVScale:
6639 case scTruncate:
6640 case scZeroExtend:
6641 case scSignExtend:
6642 case scPtrToInt:
6643 case scAddExpr:
6644 case scMulExpr:
6645 case scUDivExpr:
6646 case scAddRecExpr:
6647 case scUMaxExpr:
6648 case scSMaxExpr:
6649 case scUMinExpr:
6650 case scSMinExpr:
6651 case scSequentialUMinExpr:
6652 WorkList.push_back(Expr);
6653 break;
6654 case scCouldNotCompute:
6655 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6656 }
6657 };
6658 AddToWorklist(S);
6659
6660 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6661 for (unsigned I = 0; I != WorkList.size(); ++I) {
6662 const SCEV *P = WorkList[I];
6663 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6664 // If it is not a `SCEVUnknown`, just recurse into operands.
6665 if (!UnknownS) {
6666 for (const SCEV *Op : P->operands())
6667 AddToWorklist(Op);
6668 continue;
6669 }
6670 // `SCEVUnknown`'s require special treatment.
6671 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6672 if (!PendingPhiRangesIter.insert(P).second)
6673 continue;
6674 for (auto &Op : reverse(P->operands()))
6675 AddToWorklist(getSCEV(Op));
6676 }
6677 }
6678
6679 if (!WorkList.empty()) {
6680 // Use getRangeRef to compute ranges for items in the worklist in reverse
6681 // order. This will force ranges for earlier operands to be computed before
6682 // their users in most cases.
6683 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6684 getRangeRef(P, SignHint);
6685
6686 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6687 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6688 PendingPhiRangesIter.erase(P);
6689 }
6690 }
6691
6692 return getRangeRef(S, SignHint, 0);
6693}
6694
6695/// Determine the range for a particular SCEV. If SignHint is
6696/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6697/// with a "cleaner" unsigned (resp. signed) representation.
6698const ConstantRange &ScalarEvolution::getRangeRef(
6699 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6700 DenseMap<const SCEV *, ConstantRange> &Cache =
6701 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6702 : SignedRanges;
6703 ConstantRange::PreferredRangeType RangeType =
6704 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6705 : ConstantRange::Signed;
6706
6707 // See if we've computed this range already.
6708 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6709 if (I != Cache.end())
6710 return I->second;
6711
6712 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6713 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6714
6715 // Switch to iteratively computing the range for S, if it is part of a deeply
6716 // nested expression.
6717 if (Depth > RangeIterThreshold)
6718 return getRangeRefIter(S, SignHint);
6719
6720 unsigned BitWidth = getTypeSizeInBits(S->getType());
6721 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6722 using OBO = OverflowingBinaryOperator;
6723
6724 // If the value has known zeros, the maximum value will have those known zeros
6725 // as well.
6726 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6727 APInt Multiple = getNonZeroConstantMultiple(S);
6728 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6729 if (!Remainder.isZero())
6730 ConservativeResult =
6731 ConstantRange(APInt::getMinValue(BitWidth),
6732 APInt::getMaxValue(BitWidth) - Remainder + 1);
6733 }
6734 else {
6735 uint32_t TZ = getMinTrailingZeros(S);
6736 if (TZ != 0) {
6737 ConservativeResult = ConstantRange(
6738 APInt::getSignedMinValue(BitWidth),
6739 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6740 }
6741 }
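// Editorial worked example for the unsigned case above: with BitWidth == 8
// and a known constant multiple of 8, Remainder = 255 urem 8 = 7, so the
// range is narrowed to [0, 249), i.e. the largest representable multiple of 8
// is 248.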
6742
6743 switch (S->getSCEVType()) {
6744 case scConstant:
6745 llvm_unreachable("Already handled above.");
6746 case scVScale:
6747 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6748 case scTruncate: {
6749 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6750 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6751 return setRange(
6752 Trunc, SignHint,
6753 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6754 }
6755 case scZeroExtend: {
6756 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6757 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6758 return setRange(
6759 ZExt, SignHint,
6760 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6761 }
6762 case scSignExtend: {
6763 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6764 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6765 return setRange(
6766 SExt, SignHint,
6767 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6768 }
6769 case scPtrToInt: {
6770 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6771 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6772 return setRange(PtrToInt, SignHint, X);
6773 }
6774 case scAddExpr: {
6775 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6776 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6777 unsigned WrapType = OBO::AnyWrap;
6778 if (Add->hasNoSignedWrap())
6779 WrapType |= OBO::NoSignedWrap;
6780 if (Add->hasNoUnsignedWrap())
6781 WrapType |= OBO::NoUnsignedWrap;
6782 for (const SCEV *Op : drop_begin(Add->operands()))
6783 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6784 RangeType);
6785 return setRange(Add, SignHint,
6786 ConservativeResult.intersectWith(X, RangeType));
6787 }
6788 case scMulExpr: {
6789 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6790 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6791 for (const SCEV *Op : drop_begin(Mul->operands()))
6792 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6793 return setRange(Mul, SignHint,
6794 ConservativeResult.intersectWith(X, RangeType));
6795 }
6796 case scUDivExpr: {
6797 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6798 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6799 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6800 return setRange(UDiv, SignHint,
6801 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6802 }
6803 case scAddRecExpr: {
6804 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6805 // If there's no unsigned wrap, the value will never be less than its
6806 // initial value.
6807 if (AddRec->hasNoUnsignedWrap()) {
6808 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6809 if (!UnsignedMinValue.isZero())
6810 ConservativeResult = ConservativeResult.intersectWith(
6811 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6812 }
6813
6814 // If there's no signed wrap, and all the operands except initial value have
6815 // the same sign or zero, the value won't ever be:
6816 // 1: smaller than the initial value if the operands are non-negative,
6817 // 2: bigger than the initial value if the operands are non-positive.
6818 // In both cases, the value cannot cross the signed min/max boundary.
6819 if (AddRec->hasNoSignedWrap()) {
6820 bool AllNonNeg = true;
6821 bool AllNonPos = true;
6822 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6823 if (!isKnownNonNegative(AddRec->getOperand(i)))
6824 AllNonNeg = false;
6825 if (!isKnownNonPositive(AddRec->getOperand(i)))
6826 AllNonPos = false;
6827 }
6828 if (AllNonNeg)
6829 ConservativeResult = ConservativeResult.intersectWith(
6830 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6831 APInt::getSignedMinValue(BitWidth)),
6832 RangeType);
6833 else if (AllNonPos)
6834 ConservativeResult = ConservativeResult.intersectWith(
6835 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6836 getSignedRangeMax(AddRec->getStart()) +
6837 1),
6838 RangeType);
6839 }
6840
6841 // TODO: non-affine addrec
6842 if (AddRec->isAffine()) {
6843 const SCEV *MaxBEScev =
6844 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6845 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6846 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6847
6848 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6849 // MaxBECount's active bits are all <= AddRec's bit width.
6850 if (MaxBECount.getBitWidth() > BitWidth &&
6851 MaxBECount.getActiveBits() <= BitWidth)
6852 MaxBECount = MaxBECount.trunc(BitWidth);
6853 else if (MaxBECount.getBitWidth() < BitWidth)
6854 MaxBECount = MaxBECount.zext(BitWidth);
6855
6856 if (MaxBECount.getBitWidth() == BitWidth) {
6857 auto RangeFromAffine = getRangeForAffineAR(
6858 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6859 ConservativeResult =
6860 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6861
6862 auto RangeFromFactoring = getRangeViaFactoring(
6863 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6864 ConservativeResult =
6865 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6866 }
6867 }
6868
6869 // Now try symbolic BE count and more powerful methods.
6871 const SCEV *SymbolicMaxBECount =
6872 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6873 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6874 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6875 AddRec->hasNoSelfWrap()) {
6876 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6877 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6878 ConservativeResult =
6879 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6880 }
6881 }
6882 }
6883
6884 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6885 }
6886 case scUMaxExpr:
6887 case scSMaxExpr:
6888 case scUMinExpr:
6889 case scSMinExpr:
6890 case scSequentialUMinExpr: {
6891 Intrinsic::ID ID;
6892 switch (S->getSCEVType()) {
6893 case scUMaxExpr:
6894 ID = Intrinsic::umax;
6895 break;
6896 case scSMaxExpr:
6897 ID = Intrinsic::smax;
6898 break;
6899 case scUMinExpr:
6900 case scSequentialUMinExpr:
6901 ID = Intrinsic::umin;
6902 break;
6903 case scSMinExpr:
6904 ID = Intrinsic::smin;
6905 break;
6906 default:
6907 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6908 }
6909
6910 const auto *NAry = cast<SCEVNAryExpr>(S);
6911 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6912 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6913 X = X.intrinsic(
6914 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6915 return setRange(S, SignHint,
6916 ConservativeResult.intersectWith(X, RangeType));
6917 }
6918 case scUnknown: {
6919 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6920 Value *V = U->getValue();
6921
6922 // Check if the IR explicitly contains !range metadata.
6923 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6924 if (MDRange)
6925 ConservativeResult =
6926 ConservativeResult.intersectWith(*MDRange, RangeType);
6927
6928 // Use facts about recurrences in the underlying IR. Note that add
6929 // recurrences are AddRecExprs and thus don't hit this path. This
6930 // primarily handles shift recurrences.
6931 auto CR = getRangeForUnknownRecurrence(U);
6932 ConservativeResult = ConservativeResult.intersectWith(CR);
6933
6934 // See if ValueTracking can give us a useful range.
6935 const DataLayout &DL = getDataLayout();
6936 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6937 if (Known.getBitWidth() != BitWidth)
6938 Known = Known.zextOrTrunc(BitWidth);
6939
6940 // ValueTracking may be able to compute a tighter result for the number of
6941 // sign bits than for the value of those sign bits.
6942 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6943 if (U->getType()->isPointerTy()) {
6944 // If the pointer size is larger than the index size type, this can cause
6945 // NS to be larger than BitWidth. So compensate for this.
6946 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6947 int ptrIdxDiff = ptrSize - BitWidth;
6948 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6949 NS -= ptrIdxDiff;
6950 }
6951
6952 if (NS > 1) {
6953 // If we know any of the sign bits, we know all of the sign bits.
6954 if (!Known.Zero.getHiBits(NS).isZero())
6955 Known.Zero.setHighBits(NS);
6956 if (!Known.One.getHiBits(NS).isZero())
6957 Known.One.setHighBits(NS);
6958 }
6959
6960 if (Known.getMinValue() != Known.getMaxValue() + 1)
6961 ConservativeResult = ConservativeResult.intersectWith(
6962 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6963 RangeType);
6964 if (NS > 1)
6965 ConservativeResult = ConservativeResult.intersectWith(
6966 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6967 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6968 RangeType);
6969
6970 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6971 // Strengthen the range if the underlying IR value is a
6972 // global/alloca/heap allocation using the size of the object.
6973 bool CanBeNull, CanBeFreed;
6974 uint64_t DerefBytes =
6975 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6976 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6977 // The highest address at which the object can start is DerefBytes bytes before
6978 // the end (unsigned max value). If this value is not a multiple of the
6979 // alignment, the last possible start value is the next lowest multiple
6980 // of the alignment. Note: The computations below cannot overflow,
6981 // because if they did, there would be no possible start address for the
6982 // object.
6983 APInt MaxVal =
6984 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6985 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6986 uint64_t Rem = MaxVal.urem(Align);
6987 MaxVal -= APInt(BitWidth, Rem);
6988 APInt MinVal = APInt::getZero(BitWidth);
6989 if (llvm::isKnownNonZero(V, DL))
6990 MinVal = Align;
6991 ConservativeResult = ConservativeResult.intersectWith(
6992 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6993 }
6994 }
6995
6996 // The range of a Phi is a subset of the union of the ranges of its inputs.
6997 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6998 // Make sure that we do not run over cycled Phis.
6999 if (PendingPhiRanges.insert(Phi).second) {
7000 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7001
7002 for (const auto &Op : Phi->operands()) {
7003 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7004 RangeFromOps = RangeFromOps.unionWith(OpRange);
7005 // No point to continue if we already have a full set.
7006 if (RangeFromOps.isFullSet())
7007 break;
7008 }
7009 ConservativeResult =
7010 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7011 bool Erased = PendingPhiRanges.erase(Phi);
7012 assert(Erased && "Failed to erase Phi properly?");
7013 (void)Erased;
7014 }
7015 }
7016
7017 // vscale can't be equal to zero
7018 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7019 if (II->getIntrinsicID() == Intrinsic::vscale) {
7020 ConstantRange Disallowed = APInt::getZero(BitWidth);
7021 ConservativeResult = ConservativeResult.difference(Disallowed);
7022 }
7023
7024 return setRange(U, SignHint, std::move(ConservativeResult));
7025 }
7026 case scCouldNotCompute:
7027 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7028 }
7029
7030 return setRange(S, SignHint, std::move(ConservativeResult));
7031}
7032
7033// Given a StartRange, Step and MaxBECount for an expression compute a range of
7034// values that the expression can take. Initially, the expression has a value
7035 // from StartRange and then is changed by Step up to MaxBECount times. The
7036 // Signed argument specifies whether we treat Step as signed or unsigned.
7037 static ConstantRange getRangeForAffineARHelper(APInt Step,
7038 const ConstantRange &StartRange,
7039 const APInt &MaxBECount,
7040 bool Signed) {
7041 unsigned BitWidth = Step.getBitWidth();
7042 assert(BitWidth == StartRange.getBitWidth() &&
7043 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7044 // If either Step or MaxBECount is 0, then the expression won't change, and we
7045 // just need to return the initial range.
7046 if (Step == 0 || MaxBECount == 0)
7047 return StartRange;
7048
7049 // If we don't know anything about the initial value (i.e. StartRange is
7050 // FullRange), then we don't know anything about the final range either.
7051 // Return FullRange.
7052 if (StartRange.isFullSet())
7053 return ConstantRange::getFull(BitWidth);
7054
7055 // If Step is signed and negative, then we use its absolute value, but we also
7056 // note that we're moving in the opposite direction.
7057 bool Descending = Signed && Step.isNegative();
7058
7059 if (Signed)
7060 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7061 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7062 // These equations hold due to the well-defined wrap-around behavior of
7063 // APInt.
7064 Step = Step.abs();
7065
7066 // Check if Step * MaxBECount exceeds the full span of BitWidth. If it does,
7067 // the expression is guaranteed to overflow.
7068 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7069 return ConstantRange::getFull(BitWidth);
7070
7071 // Offset is by how much the expression can change. Checks above guarantee no
7072 // overflow here.
7073 APInt Offset = Step * MaxBECount;
7074
7075 // Minimum value of the final range will match the minimal value of StartRange
7076 // if the expression is increasing and will be decreased by Offset otherwise.
7077 // Maximum value of the final range will match the maximal value of StartRange
7078 // if the expression is decreasing and will be increased by Offset otherwise.
7079 APInt StartLower = StartRange.getLower();
7080 APInt StartUpper = StartRange.getUpper() - 1;
7081 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7082 : (StartUpper + std::move(Offset));
7083
7084 // It's possible that the new minimum/maximum value will fall into the initial
7085 // range (due to wrap around). This means that the expression can take any
7086 // value in this bitwidth, and we have to return full range.
7087 if (StartRange.contains(MovedBoundary))
7088 return ConstantRange::getFull(BitWidth);
7089
7090 APInt NewLower =
7091 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7092 APInt NewUpper =
7093 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7094 NewUpper += 1;
7095
7096 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7097 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7098}
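// Editorial worked example, assuming hypothetical inputs: Step = 3,
// StartRange = [0, 10), MaxBECount = 5, Signed = false. The value can grow by
// at most Offset = 3 * 5 = 15, so MovedBoundary = 9 + 15 = 24, which is not
// contained in [0, 10), and the function returns [0, 25).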
7099
7100ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7101 const SCEV *Step,
7102 const APInt &MaxBECount) {
7103 assert(getTypeSizeInBits(Start->getType()) ==
7104 getTypeSizeInBits(Step->getType()) &&
7105 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7106 "mismatched bit widths");
7107
7108 // First, consider step signed.
7109 ConstantRange StartSRange = getSignedRange(Start);
7110 ConstantRange StepSRange = getSignedRange(Step);
7111
7112 // If Step can be both positive and negative, we need to find ranges for the
7113 // maximum absolute step values in both directions and union them.
7114 ConstantRange SR = getRangeForAffineARHelper(
7115 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7116 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7117 StartSRange, MaxBECount,
7118 /* Signed = */ true));
7119
7120 // Next, consider step unsigned.
7121 ConstantRange UR = getRangeForAffineARHelper(
7122 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7123 /* Signed = */ false);
7124
7125 // Finally, intersect signed and unsigned ranges.
7126 return SR.intersectWith(UR, ConstantRange::Smallest);
7127}
7128
7129ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7130 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7131 ScalarEvolution::RangeSignHint SignHint) {
7132 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7133 assert(AddRec->hasNoSelfWrap() &&
7134 "This only works for non-self-wrapping AddRecs!");
7135 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7136 const SCEV *Step = AddRec->getStepRecurrence(*this);
7137 // Only deal with constant step to save compile time.
7138 if (!isa<SCEVConstant>(Step))
7139 return ConstantRange::getFull(BitWidth);
7140 // Let's make sure that we can prove that we do not self-wrap during
7141 // MaxBECount iterations. We need this because MaxBECount is a maximum
7142 // iteration count estimate, and we might infer nw from some exit for which we
7143 // do not know the max exit count (or from any other side reasoning).
7144 // TODO: Turn into assert at some point.
7145 if (getTypeSizeInBits(MaxBECount->getType()) >
7146 getTypeSizeInBits(AddRec->getType()))
7147 return ConstantRange::getFull(BitWidth);
7148 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7149 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7150 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7151 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7152 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7153 MaxItersWithoutWrap))
7154 return ConstantRange::getFull(BitWidth);
7155
7156 ICmpInst::Predicate LEPred =
7157 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7158 ICmpInst::Predicate GEPred =
7159 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7160 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7161
7162 // We know that there is no self-wrap. Let's take Start and End values and
7163 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7164 // the iteration. They either lie inside the range [Min(Start, End),
7165 // Max(Start, End)] or outside it:
7166 //
7167 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7168 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7169 //
7170 // The no-self-wrap flag guarantees that the intermediate values cannot be BOTH
7171 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7172 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7173 // Start <= End and step is positive, or Start >= End and step is negative.
7174 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7175 ConstantRange StartRange = getRangeRef(Start, SignHint);
7176 ConstantRange EndRange = getRangeRef(End, SignHint);
7177 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7178 // If they already cover full iteration space, we will know nothing useful
7179 // even if we prove what we want to prove.
7180 if (RangeBetween.isFullSet())
7181 return RangeBetween;
7182 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7183 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7184 : RangeBetween.isWrappedSet();
7185 if (IsWrappedSet)
7186 return ConstantRange::getFull(BitWidth);
7187
7188 if (isKnownPositive(Step) &&
7189 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7190 return RangeBetween;
7191 if (isKnownNegative(Step) &&
7192 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7193 return RangeBetween;
7194 return ConstantRange::getFull(BitWidth);
7195}
7196
7197ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7198 const SCEV *Step,
7199 const APInt &MaxBECount) {
7200 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7201 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
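// Editorial example of the identity above, with hypothetical values: for
// Start = (%c ? 0 : 10) and Step = (%c ? 1 : 2), the range is computed as
// RangeOf({0,+,1}) union RangeOf({10,+,2}), where %c is an arbitrary i1
// condition shared by both selects.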
7202
7203 unsigned BitWidth = MaxBECount.getBitWidth();
7204 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7205 getTypeSizeInBits(Step->getType()) == BitWidth &&
7206 "mismatched bit widths");
7207
7208 struct SelectPattern {
7209 Value *Condition = nullptr;
7210 APInt TrueValue;
7211 APInt FalseValue;
7212
7213 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7214 const SCEV *S) {
7215 std::optional<unsigned> CastOp;
7216 APInt Offset(BitWidth, 0);
7217
7218 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7219 "Should be!");
7220
7221 // Peel off a constant offset. In the future we could consider being
7222 // smarter here and handle {Start+Step,+,Step} too.
7223 const APInt *Off;
7224 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7225 Offset = *Off;
7226
7227 // Peel off a cast operation
7228 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7229 CastOp = SCast->getSCEVType();
7230 S = SCast->getOperand();
7231 }
7232
7233 using namespace llvm::PatternMatch;
7234
7235 auto *SU = dyn_cast<SCEVUnknown>(S);
7236 const APInt *TrueVal, *FalseVal;
7237 if (!SU ||
7238 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7239 m_APInt(FalseVal)))) {
7240 Condition = nullptr;
7241 return;
7242 }
7243
7244 TrueValue = *TrueVal;
7245 FalseValue = *FalseVal;
7246
7247 // Re-apply the cast we peeled off earlier
7248 if (CastOp)
7249 switch (*CastOp) {
7250 default:
7251 llvm_unreachable("Unknown SCEV cast type!");
7252
7253 case scTruncate:
7254 TrueValue = TrueValue.trunc(BitWidth);
7255 FalseValue = FalseValue.trunc(BitWidth);
7256 break;
7257 case scZeroExtend:
7258 TrueValue = TrueValue.zext(BitWidth);
7259 FalseValue = FalseValue.zext(BitWidth);
7260 break;
7261 case scSignExtend:
7262 TrueValue = TrueValue.sext(BitWidth);
7263 FalseValue = FalseValue.sext(BitWidth);
7264 break;
7265 }
7266
7267 // Re-apply the constant offset we peeled off earlier
7268 TrueValue += Offset;
7269 FalseValue += Offset;
7270 }
7271
7272 bool isRecognized() { return Condition != nullptr; }
7273 };
7274
7275 SelectPattern StartPattern(*this, BitWidth, Start);
7276 if (!StartPattern.isRecognized())
7277 return ConstantRange::getFull(BitWidth);
7278
7279 SelectPattern StepPattern(*this, BitWidth, Step);
7280 if (!StepPattern.isRecognized())
7281 return ConstantRange::getFull(BitWidth);
7282
7283 if (StartPattern.Condition != StepPattern.Condition) {
7284 // We don't handle this case today; but we could, by considering four
7285 // possibilities below instead of two. I'm not sure if there are cases where
7286 // that will help over what getRange already does, though.
7287 return ConstantRange::getFull(BitWidth);
7288 }
7289
7290 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7291 // construct arbitrary general SCEV expressions here. This function is called
7292 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7293 // say) can end up caching a suboptimal value.
7294
7295 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7296 // C2352 and C2512 (otherwise it isn't needed).
7297
7298 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7299 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7300 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7301 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7302
7303 ConstantRange TrueRange =
7304 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7305 ConstantRange FalseRange =
7306 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7307
7308 return TrueRange.unionWith(FalseRange);
7309}
7310
7311SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7312 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7313 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7314
7315 // Return early if there are no flags to propagate to the SCEV.
7316 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7317 if (BinOp->hasNoUnsignedWrap())
7318 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7319 if (BinOp->hasNoSignedWrap())
7320 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7321 if (Flags == SCEV::FlagAnyWrap)
7322 return SCEV::FlagAnyWrap;
7323
7324 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7325}
7326
7327const Instruction *
7328ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7329 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7330 return &*AddRec->getLoop()->getHeader()->begin();
7331 if (auto *U = dyn_cast<SCEVUnknown>(S))
7332 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7333 return I;
7334 return nullptr;
7335}
7336
7337const Instruction *
7338ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7339 bool &Precise) {
7340 Precise = true;
7341 // Do a bounded search of the def relation of the requested SCEVs.
7342 SmallPtrSet<const SCEV *, 16> Visited;
7343 SmallVector<const SCEV *> Worklist;
7344 auto pushOp = [&](const SCEV *S) {
7345 if (!Visited.insert(S).second)
7346 return;
7347 // Threshold of 30 here is arbitrary.
7348 if (Visited.size() > 30) {
7349 Precise = false;
7350 return;
7351 }
7352 Worklist.push_back(S);
7353 };
7354
7355 for (const auto *S : Ops)
7356 pushOp(S);
7357
7358 const Instruction *Bound = nullptr;
7359 while (!Worklist.empty()) {
7360 auto *S = Worklist.pop_back_val();
7361 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7362 if (!Bound || DT.dominates(Bound, DefI))
7363 Bound = DefI;
7364 } else {
7365 for (const auto *Op : S->operands())
7366 pushOp(Op);
7367 }
7368 }
7369 return Bound ? Bound : &*F.getEntryBlock().begin();
7370}
7371
7372const Instruction *
7373ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7374 bool Discard;
7375 return getDefiningScopeBound(Ops, Discard);
7376}
7377
7378bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7379 const Instruction *B) {
7380 if (A->getParent() == B->getParent() &&
7381 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7382 B->getIterator()))
7383 return true;
7384
7385 auto *BLoop = LI.getLoopFor(B->getParent());
7386 if (BLoop && BLoop->getHeader() == B->getParent() &&
7387 BLoop->getLoopPreheader() == A->getParent() &&
7388 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7389 A->getParent()->end()) &&
7390 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7391 B->getIterator()))
7392 return true;
7393 return false;
7394}
7395
7396bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7397 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7398 visitAll(Op, PC);
7399 return PC.MaybePoison.empty();
7400}
7401
7402bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7403 return !SCEVExprContains(Op, [this](const SCEV *S) {
7404 const SCEV *Op1;
7405 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7406 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7407 // is a non-zero constant, we have to assume the UDiv may be UB.
7408 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7409 });
7410}
7411
7412bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7413 // Only proceed if we can prove that I does not yield poison.
7414 if (!programUndefinedIfPoison(I))
7415 return false;
7416
7417 // At this point we know that if I is executed, then it does not wrap
7418 // according to at least one of NSW or NUW. If I is not executed, then we do
7419 // not know if the calculation that I represents would wrap. Multiple
7420 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7421 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7422 // derived from other instructions that map to the same SCEV. We cannot make
7423 // that guarantee for cases where I is not executed. So we need to find an
7424 // upper bound on the defining scope for the SCEV, and prove that I is
7425 // executed every time we enter that scope. When the bounding scope is a
7426 // loop (the common case), this is equivalent to proving I executes on every
7427 // iteration of that loop.
7428 SmallVector<const SCEV *> SCEVOps;
7429 for (const Use &Op : I->operands()) {
7430 // I could be an extractvalue from a call to an overflow intrinsic.
7431 // TODO: We can do better here in some cases.
7432 if (isSCEVable(Op->getType()))
7433 SCEVOps.push_back(getSCEV(Op));
7434 }
7435 auto *DefI = getDefiningScopeBound(SCEVOps);
7436 return isGuaranteedToTransferExecutionTo(DefI, I);
7437}
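// Editorial sketch of the hazard handled above, using hypothetical IR: if
//   %a = add nsw i32 %x, %y   ; only executed on a cold path
//   %b = add i32 %x, %y       ; executed unconditionally
// both map to the SCEV (%x + %y), then transferring <nsw> from %a to that
// SCEV is only safe if %a executes whenever its defining scope is entered;
// otherwise %b would inherit a guarantee that never held for it.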
7438
7439bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7440 // If we know that \c I can never be poison period, then that's enough.
7441 if (isSCEVExprNeverPoison(I))
7442 return true;
7443
7444 // If the loop only has one exit, then we know that, if the loop is entered,
7445 // any instruction dominating that exit will be executed. If any such
7446 // instruction would result in UB, the addrec cannot be poison.
7447 //
7448 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7449 // also handles uses outside the loop header (they just need to dominate the
7450 // single exit).
7451
7452 auto *ExitingBB = L->getExitingBlock();
7453 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7454 return false;
7455
7456 SmallPtrSet<const Value *, 16> KnownPoison;
7457 SmallVector<const Instruction *, 8> Worklist;
7458
7459 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7460 // things that are known to be poison under that assumption go on the
7461 // Worklist.
7462 KnownPoison.insert(I);
7463 Worklist.push_back(I);
7464
7465 while (!Worklist.empty()) {
7466 const Instruction *Poison = Worklist.pop_back_val();
7467
7468 for (const Use &U : Poison->uses()) {
7469 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7470 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7471 DT.dominates(PoisonUser->getParent(), ExitingBB))
7472 return true;
7473
7474 if (propagatesPoison(U) && L->contains(PoisonUser))
7475 if (KnownPoison.insert(PoisonUser).second)
7476 Worklist.push_back(PoisonUser);
7477 }
7478 }
7479
7480 return false;
7481}
7482
7483ScalarEvolution::LoopProperties
7484ScalarEvolution::getLoopProperties(const Loop *L) {
7485 using LoopProperties = ScalarEvolution::LoopProperties;
7486
7487 auto Itr = LoopPropertiesCache.find(L);
7488 if (Itr == LoopPropertiesCache.end()) {
7489 auto HasSideEffects = [](Instruction *I) {
7490 if (auto *SI = dyn_cast<StoreInst>(I))
7491 return !SI->isSimple();
7492
7493 if (I->mayThrow())
7494 return true;
7495
7496 // Non-volatile memset / memcpy do not count as side-effect for forward
7497 // progress.
7498 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7499 return false;
7500
7501 return I->mayWriteToMemory();
7502 };
7503
7504 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7505 /*HasNoSideEffects*/ true};
7506
7507 for (auto *BB : L->getBlocks())
7508 for (auto &I : *BB) {
7509 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7510 LP.HasNoAbnormalExits = false;
7511 if (HasSideEffects(&I))
7512 LP.HasNoSideEffects = false;
7513 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7514 break; // We're already as pessimistic as we can get.
7515 }
7516
7517 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7518 assert(InsertPair.second && "We just checked!");
7519 Itr = InsertPair.first;
7520 }
7521
7522 return Itr->second;
7523}
7524
7525 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7526 // A mustprogress loop without side effects must be finite.
7527 // TODO: The check used here is very conservative. Only *specific* side
7528 // effects are well defined in infinite loops.
7529 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7530}
7531
7532const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7533 // Worklist item with a Value and a bool indicating whether all operands have
7534 // been visited already.
7535 using PointerTy = PointerIntPair<Value *, 1, bool>;
7536 SmallVector<PointerTy> Stack;
7537
7538 Stack.emplace_back(V, true);
7539 Stack.emplace_back(V, false);
7540 while (!Stack.empty()) {
7541 auto E = Stack.pop_back_val();
7542 Value *CurV = E.getPointer();
7543
7544 if (getExistingSCEV(CurV))
7545 continue;
7546
7547 SmallVector<Value *> Ops;
7548 const SCEV *CreatedSCEV = nullptr;
7549 // If all operands have been visited already, create the SCEV.
7550 if (E.getInt()) {
7551 CreatedSCEV = createSCEV(CurV);
7552 } else {
7553 // Otherwise get the operands we need to create SCEVs for before creating
7554 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7555 // just use it.
7556 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7557 }
7558
7559 if (CreatedSCEV) {
7560 insertValueToMap(CurV, CreatedSCEV);
7561 } else {
7562 // Queue CurV for SCEV creation, followed by its operands, which need to
7563 // be constructed first.
7564 Stack.emplace_back(CurV, true);
7565 for (Value *Op : Ops)
7566 Stack.emplace_back(Op, false);
7567 }
7568 }
7569
7570 return getExistingSCEV(V);
7571}
7572
7573const SCEV *
7574ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7575 if (!isSCEVable(V->getType()))
7576 return getUnknown(V);
7577
7578 if (Instruction *I = dyn_cast<Instruction>(V)) {
7579 // Don't attempt to analyze instructions in blocks that aren't
7580 // reachable. Such instructions don't matter, and they aren't required
7581 // to obey basic rules for definitions dominating uses which this
7582 // analysis depends on.
7583 if (!DT.isReachableFromEntry(I->getParent()))
7584 return getUnknown(PoisonValue::get(V->getType()));
7585 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7586 return getConstant(CI);
7587 else if (isa<GlobalAlias>(V))
7588 return getUnknown(V);
7589 else if (!isa<ConstantExpr>(V))
7590 return getUnknown(V);
7591
7592 Operator *U = cast<Operator>(V);
7593 if (auto BO =
7594 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7595 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7596 switch (BO->Opcode) {
7597 case Instruction::Add:
7598 case Instruction::Mul: {
7599 // For additions and multiplications, traverse add/mul chains for which we
7600 // can potentially create a single SCEV, to reduce the number of
7601 // get{Add,Mul}Expr calls.
7602 do {
7603 if (BO->Op) {
7604 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7605 Ops.push_back(BO->Op);
7606 break;
7607 }
7608 }
7609 Ops.push_back(BO->RHS);
7610 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7611 dyn_cast<Instruction>(V));
7612 if (!NewBO ||
7613 (BO->Opcode == Instruction::Add &&
7614 (NewBO->Opcode != Instruction::Add &&
7615 NewBO->Opcode != Instruction::Sub)) ||
7616 (BO->Opcode == Instruction::Mul &&
7617 NewBO->Opcode != Instruction::Mul)) {
7618 Ops.push_back(BO->LHS);
7619 break;
7620 }
7621 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7622 // requires a SCEV for the LHS.
7623 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7624 auto *I = dyn_cast<Instruction>(BO->Op);
7625 if (I && programUndefinedIfPoison(I)) {
7626 Ops.push_back(BO->LHS);
7627 break;
7628 }
7629 }
7630 BO = NewBO;
7631 } while (true);
7632 return nullptr;
7633 }
7634 case Instruction::Sub:
7635 case Instruction::UDiv:
7636 case Instruction::URem:
7637 break;
7638 case Instruction::AShr:
7639 case Instruction::Shl:
7640 case Instruction::Xor:
7641 if (!IsConstArg)
7642 return nullptr;
7643 break;
7644 case Instruction::And:
7645 case Instruction::Or:
7646 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7647 return nullptr;
7648 break;
7649 case Instruction::LShr:
7650 return getUnknown(V);
7651 default:
7652 llvm_unreachable("Unhandled binop");
7653 break;
7654 }
7655
7656 Ops.push_back(BO->LHS);
7657 Ops.push_back(BO->RHS);
7658 return nullptr;
7659 }
7660
7661 switch (U->getOpcode()) {
7662 case Instruction::Trunc:
7663 case Instruction::ZExt:
7664 case Instruction::SExt:
7665 case Instruction::PtrToInt:
7666 Ops.push_back(U->getOperand(0));
7667 return nullptr;
7668
7669 case Instruction::BitCast:
7670 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7671 Ops.push_back(U->getOperand(0));
7672 return nullptr;
7673 }
7674 return getUnknown(V);
7675
7676 case Instruction::SDiv:
7677 case Instruction::SRem:
7678 Ops.push_back(U->getOperand(0));
7679 Ops.push_back(U->getOperand(1));
7680 return nullptr;
7681
7682 case Instruction::GetElementPtr:
7683 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7684 "GEP source element type must be sized");
7685 llvm::append_range(Ops, U->operands());
7686 return nullptr;
7687
7688 case Instruction::IntToPtr:
7689 return getUnknown(V);
7690
7691 case Instruction::PHI:
7692 // Keep constructing SCEVs for phis recursively for now.
7693 return nullptr;
7694
7695 case Instruction::Select: {
7696 // Check if U is a select that can be simplified to a SCEVUnknown.
7697 auto CanSimplifyToUnknown = [this, U]() {
7698 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7699 return false;
7700
7701 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7702 if (!ICI)
7703 return false;
7704 Value *LHS = ICI->getOperand(0);
7705 Value *RHS = ICI->getOperand(1);
7706 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7707 ICI->getPredicate() == CmpInst::ICMP_NE) {
7709 return true;
7710 } else if (getTypeSizeInBits(LHS->getType()) >
7711 getTypeSizeInBits(U->getType()))
7712 return true;
7713 return false;
7714 };
7715 if (CanSimplifyToUnknown())
7716 return getUnknown(U);
7717
7718 llvm::append_range(Ops, U->operands());
7719 return nullptr;
7720 break;
7721 }
7722 case Instruction::Call:
7723 case Instruction::Invoke:
7724 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7725 Ops.push_back(RV);
7726 return nullptr;
7727 }
7728
7729 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7730 switch (II->getIntrinsicID()) {
7731 case Intrinsic::abs:
7732 Ops.push_back(II->getArgOperand(0));
7733 return nullptr;
7734 case Intrinsic::umax:
7735 case Intrinsic::umin:
7736 case Intrinsic::smax:
7737 case Intrinsic::smin:
7738 case Intrinsic::usub_sat:
7739 case Intrinsic::uadd_sat:
7740 Ops.push_back(II->getArgOperand(0));
7741 Ops.push_back(II->getArgOperand(1));
7742 return nullptr;
7743 case Intrinsic::start_loop_iterations:
7744 case Intrinsic::annotation:
7745 case Intrinsic::ptr_annotation:
7746 Ops.push_back(II->getArgOperand(0));
7747 return nullptr;
7748 default:
7749 break;
7750 }
7751 }
7752 break;
7753 }
7754
7755 return nullptr;
7756}
7757
7758const SCEV *ScalarEvolution::createSCEV(Value *V) {
7759 if (!isSCEVable(V->getType()))
7760 return getUnknown(V);
7761
7762 if (Instruction *I = dyn_cast<Instruction>(V)) {
7763 // Don't attempt to analyze instructions in blocks that aren't
7764 // reachable. Such instructions don't matter, and they aren't required
7765 // to obey basic rules for definitions dominating uses which this
7766 // analysis depends on.
7767 if (!DT.isReachableFromEntry(I->getParent()))
7768 return getUnknown(PoisonValue::get(V->getType()));
7769 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7770 return getConstant(CI);
7771 else if (isa<GlobalAlias>(V))
7772 return getUnknown(V);
7773 else if (!isa<ConstantExpr>(V))
7774 return getUnknown(V);
7775
7776 const SCEV *LHS;
7777 const SCEV *RHS;
7778
7779 Operator *U = cast<Operator>(V);
7780 if (auto BO =
7781 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7782 switch (BO->Opcode) {
7783 case Instruction::Add: {
7784 // The simple thing to do would be to just call getSCEV on both operands
7785 // and call getAddExpr with the result. However if we're looking at a
7786 // bunch of things all added together, this can be quite inefficient,
7787 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7788 // Instead, gather up all the operands and make a single getAddExpr call.
7789 // LLVM IR canonical form means we need only traverse the left operands.
7790 SmallVector<const SCEV *, 4> AddOps;
7791 do {
7792 if (BO->Op) {
7793 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7794 AddOps.push_back(OpSCEV);
7795 break;
7796 }
7797
7798 // If a NUW or NSW flag can be applied to the SCEV for this
7799 // addition, then compute the SCEV for this addition by itself
7800 // with a separate call to getAddExpr. We need to do that
7801 // instead of pushing the operands of the addition onto AddOps,
7802 // since the flags are only known to apply to this particular
7803 // addition - they may not apply to other additions that can be
7804 // formed with operands from AddOps.
7805 const SCEV *RHS = getSCEV(BO->RHS);
7806 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7807 if (Flags != SCEV::FlagAnyWrap) {
7808 const SCEV *LHS = getSCEV(BO->LHS);
7809 if (BO->Opcode == Instruction::Sub)
7810 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7811 else
7812 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7813 break;
7814 }
7815 }
7816
7817 if (BO->Opcode == Instruction::Sub)
7818 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7819 else
7820 AddOps.push_back(getSCEV(BO->RHS));
7821
7822 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7823 dyn_cast<Instruction>(V));
7824 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7825 NewBO->Opcode != Instruction::Sub)) {
7826 AddOps.push_back(getSCEV(BO->LHS));
7827 break;
7828 }
7829 BO = NewBO;
7830 } while (true);
7831
7832 return getAddExpr(AddOps);
7833 }
7834
7835 case Instruction::Mul: {
7836 SmallVector<const SCEV *, 4> MulOps;
7837 do {
7838 if (BO->Op) {
7839 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7840 MulOps.push_back(OpSCEV);
7841 break;
7842 }
7843
7844 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7845 if (Flags != SCEV::FlagAnyWrap) {
7846 LHS = getSCEV(BO->LHS);
7847 RHS = getSCEV(BO->RHS);
7848 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7849 break;
7850 }
7851 }
7852
7853 MulOps.push_back(getSCEV(BO->RHS));
7854 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7855 dyn_cast<Instruction>(V));
7856 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7857 MulOps.push_back(getSCEV(BO->LHS));
7858 break;
7859 }
7860 BO = NewBO;
7861 } while (true);
7862
7863 return getMulExpr(MulOps);
7864 }
7865 case Instruction::UDiv:
7866 LHS = getSCEV(BO->LHS);
7867 RHS = getSCEV(BO->RHS);
7868 return getUDivExpr(LHS, RHS);
7869 case Instruction::URem:
7870 LHS = getSCEV(BO->LHS);
7871 RHS = getSCEV(BO->RHS);
7872 return getURemExpr(LHS, RHS);
7873 case Instruction::Sub: {
7874 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7875 if (BO->Op)
7876 Flags = getNoWrapFlagsFromUB(BO->Op);
7877 LHS = getSCEV(BO->LHS);
7878 RHS = getSCEV(BO->RHS);
7879 return getMinusSCEV(LHS, RHS, Flags);
7880 }
7881 case Instruction::And:
7882 // For an expression like x&255 that merely masks off the high bits,
7883 // use zext(trunc(x)) as the SCEV expression.
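// Editorial example: for an i32 value %x, (%x & 255) becomes
// (zext (trunc %x to i8) to i32). When the mask also clears low bits, e.g.
// (%x & 0xFF0), the code below additionally divides by and re-multiplies with
// MulCount (here 16) so that only the masked middle bits survive.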
7884 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7885 if (CI->isZero())
7886 return getSCEV(BO->RHS);
7887 if (CI->isMinusOne())
7888 return getSCEV(BO->LHS);
7889 const APInt &A = CI->getValue();
7890
7891 // Instcombine's ShrinkDemandedConstant may strip bits out of
7892 // constants, obscuring what would otherwise be a low-bits mask.
7893 // Use computeKnownBits to compute what ShrinkDemandedConstant
7894 // knew about to reconstruct a low-bits mask value.
7895 unsigned LZ = A.countl_zero();
7896 unsigned TZ = A.countr_zero();
7897 unsigned BitWidth = A.getBitWidth();
7898 KnownBits Known(BitWidth);
7899 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7900
7901 APInt EffectiveMask =
7902 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7903 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7904 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7905 const SCEV *LHS = getSCEV(BO->LHS);
7906 const SCEV *ShiftedLHS = nullptr;
7907 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7908 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7909 // For an expression like (x * 8) & 8, simplify the multiply.
7910 unsigned MulZeros = OpC->getAPInt().countr_zero();
7911 unsigned GCD = std::min(MulZeros, TZ);
7912 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7913 SmallVector<const SCEV *, 4> MulOps;
7914 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7915 append_range(MulOps, LHSMul->operands().drop_front());
7916 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7917 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7918 }
7919 }
7920 if (!ShiftedLHS)
7921 ShiftedLHS = getUDivExpr(LHS, MulCount);
7922 return getMulExpr(
7923 getZeroExtendExpr(
7924 getTruncateExpr(ShiftedLHS,
7925 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7926 BO->LHS->getType()),
7927 MulCount);
7928 }
7929 }
7930 // Binary `and` is a bit-wise `umin`.
7931 if (BO->LHS->getType()->isIntegerTy(1)) {
7932 LHS = getSCEV(BO->LHS);
7933 RHS = getSCEV(BO->RHS);
7934 return getUMinExpr(LHS, RHS);
7935 }
7936 break;
7937
7938 case Instruction::Or:
7939 // Binary `or` is a bit-wise `umax`.
7940 if (BO->LHS->getType()->isIntegerTy(1)) {
7941 LHS = getSCEV(BO->LHS);
7942 RHS = getSCEV(BO->RHS);
7943 return getUMaxExpr(LHS, RHS);
7944 }
7945 break;
7946
7947 case Instruction::Xor:
7948 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7949 // If the RHS of xor is -1, then this is a not operation.
7950 if (CI->isMinusOne())
7951 return getNotSCEV(getSCEV(BO->LHS));
7952
7953 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7954 // This is a variant of the check for xor with -1, and it handles
7955 // the case where instcombine has trimmed non-demanded bits out
7956 // of an xor with -1.
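// Editorial example with hypothetical values: for i32 %x and C = 1,
//   xor (and %x, 1), 1
// flips the low bit, i.e. it is and(~%x, 1); the code below recognizes this
// when the SCEV of the left operand is a zext of a narrower value and C is a
// low-bits mask of that narrower width.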
7957 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7958 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7959 if (LBO->getOpcode() == Instruction::And &&
7960 LCI->getValue() == CI->getValue())
7961 if (const SCEVZeroExtendExpr *Z =
7962 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7963 Type *UTy = BO->LHS->getType();
7964 const SCEV *Z0 = Z->getOperand();
7965 Type *Z0Ty = Z0->getType();
7966 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7967
7968 // If C is a low-bits mask, the zero extend is serving to
7969 // mask off the high bits. Complement the operand and
7970 // re-apply the zext.
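              // E.g. if getSCEV(BO->LHS) is (zext i8 %a to i32) and C is 255, the
              // xor flips every low bit, so the result is modeled as
              // (zext i8 (not %a) to i32).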
7971 if (CI->getValue().isMask(Z0TySize))
7972 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7973
7974 // If C is a single bit, it may be in the sign-bit position
7975 // before the zero-extend. In this case, represent the xor
7976 // using an add, which is equivalent, and re-apply the zext.
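              // E.g. with Z0 of type i8 and C == 128, the xor flips only the i8
              // sign bit, which is the same as adding 128 modulo 256, so this is
              // modeled as (zext i8 (%a + 128) to i32).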
7977 APInt Trunc = CI->getValue().trunc(Z0TySize);
7978 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7979 Trunc.isSignMask())
7980 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7981 UTy);
7982 }
7983 }
7984 break;
7985
7986 case Instruction::Shl:
7987 // Turn shift left of a constant amount into a multiply.
7988 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7989 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7990
7991 // If the shift count is not less than the bitwidth, the result of
7992 // the shift is undefined. Don't try to analyze it, because the
7993 // resolution chosen here may differ from the resolution chosen in
7994 // other parts of the compiler.
7995 if (SA->getValue().uge(BitWidth))
7996 break;
7997
7998 // We can safely preserve the nuw flag in all cases. It's also safe to
7999 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8000 // requires special handling. It can be preserved as long as we're not
8001 // left shifting by bitwidth - 1.
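      // E.g. "shl nuw nsw i32 %x, 3" is modeled as (%x * 8)<nuw><nsw>, while a
      // plain "shl nsw i32 %x, 31" keeps no wrap flags, since 31 == bitwidth - 1.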
8002 auto Flags = SCEV::FlagAnyWrap;
8003 if (BO->Op) {
8004 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8005 if ((MulFlags & SCEV::FlagNSW) &&
8006            ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8007          Flags = setFlags(Flags, SCEV::FlagNSW);
8008        if (MulFlags & SCEV::FlagNUW)
8009          Flags = setFlags(Flags, SCEV::FlagNUW);
8010      }
8011
8012 ConstantInt *X = ConstantInt::get(
8013 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8014 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8015 }
8016 break;
8017
8018 case Instruction::AShr:
8019 // AShr X, C, where C is a constant.
8020 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8021 if (!CI)
8022 break;
8023
8024 Type *OuterTy = BO->LHS->getType();
8025 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8026 // If the shift count is not less than the bitwidth, the result of
8027 // the shift is undefined. Don't try to analyze it, because the
8028 // resolution chosen here may differ from the resolution chosen in
8029 // other parts of the compiler.
8030 if (CI->getValue().uge(BitWidth))
8031 break;
8032
8033 if (CI->isZero())
8034 return getSCEV(BO->LHS); // shift by zero --> noop
8035
8036 uint64_t AShrAmt = CI->getZExtValue();
8037 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8038
8039 Operator *L = dyn_cast<Operator>(BO->LHS);
8040 const SCEV *AddTruncateExpr = nullptr;
8041 ConstantInt *ShlAmtCI = nullptr;
8042 const SCEV *AddConstant = nullptr;
8043
8044 if (L && L->getOpcode() == Instruction::Add) {
8045 // X = Shl A, n
8046 // Y = Add X, c
8047 // Z = AShr Y, m
8048 // n, c and m are constants.
8049
8050 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8051 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8052 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8053 if (AddOperandCI) {
8054 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8055 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8056          // Since we truncate to TruncTy, the AddConstant should be of the
8057          // same type, so create a new Constant with the same type as TruncTy.
8058 // Also, the Add constant should be shifted right by AShr amount.
8059 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8060 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8061 // we model the expression as sext(add(trunc(A), c << n)), since the
8062          // sext(trunc) part is already handled below, we create an
8063 // AddExpr(TruncExp) which will be used later.
8064 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8065 }
8066 }
8067 } else if (L && L->getOpcode() == Instruction::Shl) {
8068 // X = Shl A, n
8069 // Y = AShr X, m
8070 // Both n and m are constant.
8071
8072 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8073 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8074 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8075 }
8076
8077 if (AddTruncateExpr && ShlAmtCI) {
8078 // We can merge the two given cases into a single SCEV statement,
8079      // in case n = m, the mul expression will be 2^0, so it gets resolved to
8080 // a simpler case. The following code handles the two cases:
8081 //
8082 // 1) For a two-shift sext-inreg, i.e. n = m,
8083 // use sext(trunc(x)) as the SCEV expression.
8084 //
8085 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8086 // expression. We already checked that ShlAmt < BitWidth, so
8087 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8088 // ShlAmt - AShrAmt < Amt.
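      // E.g. on i32, "%y = shl %a, 5; %z = ashr %y, 3" (n = 5, m = 3) is modeled
      // as (sext i29 (4 * (trunc i32 %a to i29)) to i32), since 4 == 1 << (5 - 3).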
8089 const APInt &ShlAmt = ShlAmtCI->getValue();
8090 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8091 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8092 ShlAmtCI->getZExtValue() - AShrAmt);
8093 const SCEV *CompositeExpr =
8094 getMulExpr(AddTruncateExpr, getConstant(Mul));
8095 if (L->getOpcode() != Instruction::Shl)
8096 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8097
8098 return getSignExtendExpr(CompositeExpr, OuterTy);
8099 }
8100 }
8101 break;
8102 }
8103 }
8104
8105 switch (U->getOpcode()) {
8106 case Instruction::Trunc:
8107 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8108
8109 case Instruction::ZExt:
8110 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8111
8112 case Instruction::SExt:
8113    if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8114                                dyn_cast<Instruction>(V))) {
8115 // The NSW flag of a subtract does not always survive the conversion to
8116 // A + (-1)*B. By pushing sign extension onto its operands we are much
8117 // more likely to preserve NSW and allow later AddRec optimisations.
8118 //
8119 // NOTE: This is effectively duplicating this logic from getSignExtend:
8120 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8121 // but by that point the NSW information has potentially been lost.
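    // E.g. "sext i8 (sub nsw %a, %b) to i32" becomes
    // (sext i8 %a to i32) - (sext i8 %b to i32), with the subtraction kept NSW.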
8122 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8123 Type *Ty = U->getType();
8124 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8125 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8126 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8127 }
8128 }
8129 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8130
8131 case Instruction::BitCast:
8132 // BitCasts are no-op casts so we just eliminate the cast.
8133 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8134 return getSCEV(U->getOperand(0));
8135 break;
8136
8137 case Instruction::PtrToInt: {
8138    // Pointer to integer cast is straightforward, so do model it.
8139 const SCEV *Op = getSCEV(U->getOperand(0));
8140 Type *DstIntTy = U->getType();
8141 // But only if effective SCEV (integer) type is wide enough to represent
8142 // all possible pointer values.
8143 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8144 if (isa<SCEVCouldNotCompute>(IntOp))
8145 return getUnknown(V);
8146 return IntOp;
8147 }
8148 case Instruction::IntToPtr:
8149 // Just don't deal with inttoptr casts.
8150 return getUnknown(V);
8151
8152 case Instruction::SDiv:
8153 // If both operands are non-negative, this is just an udiv.
8154 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8155 isKnownNonNegative(getSCEV(U->getOperand(1))))
8156 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8157 break;
8158
8159 case Instruction::SRem:
8160 // If both operands are non-negative, this is just an urem.
8161 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8162 isKnownNonNegative(getSCEV(U->getOperand(1))))
8163 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8164 break;
8165
8166 case Instruction::GetElementPtr:
8167 return createNodeForGEP(cast<GEPOperator>(U));
8168
8169 case Instruction::PHI:
8170 return createNodeForPHI(cast<PHINode>(U));
8171
8172 case Instruction::Select:
8173 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8174 U->getOperand(2));
8175
8176 case Instruction::Call:
8177 case Instruction::Invoke:
8178 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8179 return getSCEV(RV);
8180
8181 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8182 switch (II->getIntrinsicID()) {
8183 case Intrinsic::abs:
8184 return getAbsExpr(
8185 getSCEV(II->getArgOperand(0)),
8186 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8187 case Intrinsic::umax:
8188 LHS = getSCEV(II->getArgOperand(0));
8189 RHS = getSCEV(II->getArgOperand(1));
8190 return getUMaxExpr(LHS, RHS);
8191 case Intrinsic::umin:
8192 LHS = getSCEV(II->getArgOperand(0));
8193 RHS = getSCEV(II->getArgOperand(1));
8194 return getUMinExpr(LHS, RHS);
8195 case Intrinsic::smax:
8196 LHS = getSCEV(II->getArgOperand(0));
8197 RHS = getSCEV(II->getArgOperand(1));
8198 return getSMaxExpr(LHS, RHS);
8199 case Intrinsic::smin:
8200 LHS = getSCEV(II->getArgOperand(0));
8201 RHS = getSCEV(II->getArgOperand(1));
8202 return getSMinExpr(LHS, RHS);
8203 case Intrinsic::usub_sat: {
8204 const SCEV *X = getSCEV(II->getArgOperand(0));
8205 const SCEV *Y = getSCEV(II->getArgOperand(1));
8206 const SCEV *ClampedY = getUMinExpr(X, Y);
8207 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8208 }
8209 case Intrinsic::uadd_sat: {
8210 const SCEV *X = getSCEV(II->getArgOperand(0));
8211 const SCEV *Y = getSCEV(II->getArgOperand(1));
8212 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8213 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8214 }
8215 case Intrinsic::start_loop_iterations:
8216 case Intrinsic::annotation:
8217 case Intrinsic::ptr_annotation:
8218      // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8219      // just equivalent to the first operand for SCEV purposes.
8220 return getSCEV(II->getArgOperand(0));
8221 case Intrinsic::vscale:
8222 return getVScale(II->getType());
8223 default:
8224 break;
8225 }
8226 }
8227 break;
8228 }
8229
8230 return getUnknown(V);
8231}
8232
8233//===----------------------------------------------------------------------===//
8234// Iteration Count Computation Code
8235//
8236
8237const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8238 if (isa<SCEVCouldNotCompute>(ExitCount))
8239 return getCouldNotCompute();
8240
8241 auto *ExitCountType = ExitCount->getType();
8242 assert(ExitCountType->isIntegerTy());
8243 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8244 1 + ExitCountType->getScalarSizeInBits());
8245 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8246}
8247
8248const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8249 Type *EvalTy,
8250 const Loop *L) {
8251 if (isa<SCEVCouldNotCompute>(ExitCount))
8252 return getCouldNotCompute();
8253
8254 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8255 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8256
8257 auto CanAddOneWithoutOverflow = [&]() {
8258 ConstantRange ExitCountRange =
8259 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8260 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8261 return true;
8262
8263 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8264 getMinusOne(ExitCount->getType()));
8265 };
8266
8267 // If we need to zero extend the backedge count, check if we can add one to
8268 // it prior to zero extending without overflow. Provided this is safe, it
8269 // allows better simplification of the +1.
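  // E.g. if ExitCount is an i8 value known not to be 255, the trip count in i9 is
  // modeled as the zero extension of (ExitCount + 1) rather than as
  // (zext ExitCount to i9) + 1.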
8270 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8271 return getZeroExtendExpr(
8272 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8273
8274 // Get the total trip count from the count by adding 1. This may wrap.
8275 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8276}
8277
8278static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8279 if (!ExitCount)
8280 return 0;
8281
8282 ConstantInt *ExitConst = ExitCount->getValue();
8283
8284 // Guard against huge trip counts.
8285 if (ExitConst->getValue().getActiveBits() > 32)
8286 return 0;
8287
8288 // In case of integer overflow, this returns 0, which is correct.
8289 return ((unsigned)ExitConst->getZExtValue()) + 1;
8290}
8291
8292unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8293 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8294 return getConstantTripCount(ExitCount);
8295}
8296
8297unsigned
8298ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8299 const BasicBlock *ExitingBlock) {
8300 assert(ExitingBlock && "Must pass a non-null exiting block!");
8301 assert(L->isLoopExiting(ExitingBlock) &&
8302 "Exiting block must actually branch out of the loop!");
8303 const SCEVConstant *ExitCount =
8304 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8305 return getConstantTripCount(ExitCount);
8306}
8307
8308unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8309 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8310
8311 const auto *MaxExitCount =
8312      Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8313                 : getConstantMaxBackedgeTakenCount(L);
8314 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8315}
8316
8317unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8318 SmallVector<BasicBlock *, 8> ExitingBlocks;
8319 L->getExitingBlocks(ExitingBlocks);
8320
8321 std::optional<unsigned> Res;
8322 for (auto *ExitingBB : ExitingBlocks) {
8323 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8324 if (!Res)
8325 Res = Multiple;
8326 Res = std::gcd(*Res, Multiple);
8327 }
8328 return Res.value_or(1);
8329}
8330
8331unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8332 const SCEV *ExitCount) {
8333 if (isa<SCEVCouldNotCompute>(ExitCount))
8334 return 1;
8335
8336 // Get the trip count
8337 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8338
8339 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8340 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8341 // the greatest power of 2 divisor less than 2^32.
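  // E.g. if the trip count is known to be a multiple of 3 * 2^40, this returns
  // 2^31: the largest power-of-two divisor below 2^32 that still divides it.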
8342 return Multiple.getActiveBits() > 32
8343 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8344 : (unsigned)Multiple.getZExtValue();
8345}
8346
8347/// Returns the largest constant divisor of the trip count of this loop as a
8348/// normal unsigned value, if possible. This means that the actual trip count is
8349/// always a multiple of the returned value (don't forget the trip count could
8350/// very well be zero as well!).
8351///
8352/// Returns 1 if the trip count is unknown or not guaranteed to be the
8353/// multiple of a constant (which is also the case if the trip count is simply
8354/// constant, use getSmallConstantTripCount for that case). It will also return 1
8355/// if the trip count is very large (>= 2^32).
8356///
8357/// As explained in the comments for getSmallConstantTripCount, this assumes
8358/// that control exits the loop via ExitingBlock.
8359unsigned
8360ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8361 const BasicBlock *ExitingBlock) {
8362 assert(ExitingBlock && "Must pass a non-null exiting block!");
8363 assert(L->isLoopExiting(ExitingBlock) &&
8364 "Exiting block must actually branch out of the loop!");
8365 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8366 return getSmallConstantTripMultiple(L, ExitCount);
8367}
8368
8369const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8370 const BasicBlock *ExitingBlock,
8371 ExitCountKind Kind) {
8372 switch (Kind) {
8373 case Exact:
8374 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8375 case SymbolicMaximum:
8376 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8377 case ConstantMaximum:
8378 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8379 };
8380 llvm_unreachable("Invalid ExitCountKind!");
8381}
8382
8383const SCEV *ScalarEvolution::getPredicatedExitCount(
8384    const Loop *L, const BasicBlock *ExitingBlock,
8385    SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8386 switch (Kind) {
8387 case Exact:
8388 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8389 Predicates);
8390 case SymbolicMaximum:
8391 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8392 Predicates);
8393 case ConstantMaximum:
8394 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8395 Predicates);
8396 };
8397 llvm_unreachable("Invalid ExitCountKind!");
8398}
8399
8400const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8401    const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8402 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8403}
8404
8405const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8406 ExitCountKind Kind) {
8407 switch (Kind) {
8408 case Exact:
8409 return getBackedgeTakenInfo(L).getExact(L, this);
8410 case ConstantMaximum:
8411 return getBackedgeTakenInfo(L).getConstantMax(this);
8412 case SymbolicMaximum:
8413 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8414 };
8415 llvm_unreachable("Invalid ExitCountKind!");
8416}
8417
8418const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8419    const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8420 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8421}
8422
8423const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8424    const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8425 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8426}
8427
8428bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8429 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8430}
8431
8432/// Push PHI nodes in the header of the given loop onto the given Worklist.
8433static void PushLoopPHIs(const Loop *L,
8434                         SmallVectorImpl<Instruction *> &Worklist,
8435                         SmallPtrSetImpl<Instruction *> &Visited) {
8436 BasicBlock *Header = L->getHeader();
8437
8438 // Push all Loop-header PHIs onto the Worklist stack.
8439 for (PHINode &PN : Header->phis())
8440 if (Visited.insert(&PN).second)
8441 Worklist.push_back(&PN);
8442}
8443
8444ScalarEvolution::BackedgeTakenInfo &
8445ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8446 auto &BTI = getBackedgeTakenInfo(L);
8447 if (BTI.hasFullInfo())
8448 return BTI;
8449
8450 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8451
8452 if (!Pair.second)
8453 return Pair.first->second;
8454
8455 BackedgeTakenInfo Result =
8456 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8457
8458 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8459}
8460
8461ScalarEvolution::BackedgeTakenInfo &
8462ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8463 // Initially insert an invalid entry for this loop. If the insertion
8464 // succeeds, proceed to actually compute a backedge-taken count and
8465 // update the value. The temporary CouldNotCompute value tells SCEV
8466 // code elsewhere that it shouldn't attempt to request a new
8467 // backedge-taken count, which could result in infinite recursion.
8468 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8469 BackedgeTakenCounts.try_emplace(L);
8470 if (!Pair.second)
8471 return Pair.first->second;
8472
8473 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8474 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8475 // must be cleared in this scope.
8476 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8477
8478 // Now that we know more about the trip count for this loop, forget any
8479 // existing SCEV values for PHI nodes in this loop since they are only
8480 // conservative estimates made without the benefit of trip count
8481 // information. This invalidation is not necessary for correctness, and is
8482 // only done to produce more precise results.
8483 if (Result.hasAnyInfo()) {
8484    // Invalidate any expression using an addrec in this loop.
8485    SmallVector<const SCEV *, 8> ToForget;
8486 auto LoopUsersIt = LoopUsers.find(L);
8487 if (LoopUsersIt != LoopUsers.end())
8488 append_range(ToForget, LoopUsersIt->second);
8489 forgetMemoizedResults(ToForget);
8490
8491 // Invalidate constant-evolved loop header phis.
8492 for (PHINode &PN : L->getHeader()->phis())
8493 ConstantEvolutionLoopExitValue.erase(&PN);
8494 }
8495
8496 // Re-lookup the insert position, since the call to
8497 // computeBackedgeTakenCount above could result in a
8498  // recursive call to getBackedgeTakenInfo (on a different
8499 // loop), which would invalidate the iterator computed
8500 // earlier.
8501 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8502}
8503
8504void ScalarEvolution::forgetAllLoops() {
8505 // This method is intended to forget all info about loops. It should
8506 // invalidate caches as if the following happened:
8507 // - The trip counts of all loops have changed arbitrarily
8508 // - Every llvm::Value has been updated in place to produce a different
8509 // result.
8510 BackedgeTakenCounts.clear();
8511 PredicatedBackedgeTakenCounts.clear();
8512 BECountUsers.clear();
8513 LoopPropertiesCache.clear();
8514 ConstantEvolutionLoopExitValue.clear();
8515 ValueExprMap.clear();
8516 ValuesAtScopes.clear();
8517 ValuesAtScopesUsers.clear();
8518 LoopDispositions.clear();
8519 BlockDispositions.clear();
8520 UnsignedRanges.clear();
8521 SignedRanges.clear();
8522 ExprValueMap.clear();
8523 HasRecMap.clear();
8524 ConstantMultipleCache.clear();
8525 PredicatedSCEVRewrites.clear();
8526 FoldCache.clear();
8527 FoldCacheUser.clear();
8528}
8529void ScalarEvolution::visitAndClearUsers(
8530    SmallVectorImpl<Instruction *> &Worklist,
8531    SmallPtrSetImpl<Instruction *> &Visited,
8532    SmallVectorImpl<const SCEV *> &ToForget) {
8533 while (!Worklist.empty()) {
8534 Instruction *I = Worklist.pop_back_val();
8535 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8536 continue;
8537
8538    ValueExprMapType::iterator It =
8539 ValueExprMap.find_as(static_cast<Value *>(I));
8540 if (It != ValueExprMap.end()) {
8541 eraseValueFromMap(It->first);
8542 ToForget.push_back(It->second);
8543 if (PHINode *PN = dyn_cast<PHINode>(I))
8544 ConstantEvolutionLoopExitValue.erase(PN);
8545 }
8546
8547 PushDefUseChildren(I, Worklist, Visited);
8548 }
8549}
8550
8551void ScalarEvolution::forgetLoop(const Loop *L) {
8552  SmallVector<const Loop *, 16> LoopWorklist(1, L);
8553  SmallVector<Instruction *, 32> Worklist;
8554  SmallPtrSet<Instruction *, 16> Visited;
8555  SmallVector<const SCEV *, 16> ToForget;
8556
8557 // Iterate over all the loops and sub-loops to drop SCEV information.
8558 while (!LoopWorklist.empty()) {
8559 auto *CurrL = LoopWorklist.pop_back_val();
8560
8561 // Drop any stored trip count value.
8562 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8563 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8564
8565 // Drop information about predicated SCEV rewrites for this loop.
8566 for (auto I = PredicatedSCEVRewrites.begin();
8567 I != PredicatedSCEVRewrites.end();) {
8568 std::pair<const SCEV *, const Loop *> Entry = I->first;
8569 if (Entry.second == CurrL)
8570 PredicatedSCEVRewrites.erase(I++);
8571 else
8572 ++I;
8573 }
8574
8575 auto LoopUsersItr = LoopUsers.find(CurrL);
8576 if (LoopUsersItr != LoopUsers.end())
8577 llvm::append_range(ToForget, LoopUsersItr->second);
8578
8579 // Drop information about expressions based on loop-header PHIs.
8580 PushLoopPHIs(CurrL, Worklist, Visited);
8581 visitAndClearUsers(Worklist, Visited, ToForget);
8582
8583 LoopPropertiesCache.erase(CurrL);
8584 // Forget all contained loops too, to avoid dangling entries in the
8585 // ValuesAtScopes map.
8586 LoopWorklist.append(CurrL->begin(), CurrL->end());
8587 }
8588 forgetMemoizedResults(ToForget);
8589}
8590
8591void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8592 forgetLoop(L->getOutermostLoop());
8593}
8594
8595void ScalarEvolution::forgetValue(Value *V) {
8596  Instruction *I = dyn_cast<Instruction>(V);
8597 if (!I) return;
8598
8599  // Drop information about expressions based on loop-header PHIs.
8600  SmallVector<Instruction *, 16> Worklist;
8601  SmallPtrSet<Instruction *, 8> Visited;
8602  SmallVector<const SCEV *, 8> ToForget;
8603 Worklist.push_back(I);
8604 Visited.insert(I);
8605 visitAndClearUsers(Worklist, Visited, ToForget);
8606
8607 forgetMemoizedResults(ToForget);
8608}
8609
8610void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8611 if (!isSCEVable(V->getType()))
8612 return;
8613
8614 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8615 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8616 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8617 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8618 if (const SCEV *S = getExistingSCEV(V)) {
8619 struct InvalidationRootCollector {
8620      Loop *L;
8621      SmallVector<const SCEV *, 8> Roots;
8622
8623 InvalidationRootCollector(Loop *L) : L(L) {}
8624
8625 bool follow(const SCEV *S) {
8626 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8627 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8628 if (L->contains(I))
8629 Roots.push_back(S);
8630 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8631 if (L->contains(AddRec->getLoop()))
8632 Roots.push_back(S);
8633 }
8634 return true;
8635 }
8636 bool isDone() const { return false; }
8637 };
8638
8639 InvalidationRootCollector C(L);
8640 visitAll(S, C);
8641 forgetMemoizedResults(C.Roots);
8642 }
8643
8644 // Also perform the normal invalidation.
8645 forgetValue(V);
8646}
8647
8648void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8649
8650void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8651 // Unless a specific value is passed to invalidation, completely clear both
8652 // caches.
8653 if (!V) {
8654 BlockDispositions.clear();
8655 LoopDispositions.clear();
8656 return;
8657 }
8658
8659 if (!isSCEVable(V->getType()))
8660 return;
8661
8662 const SCEV *S = getExistingSCEV(V);
8663 if (!S)
8664 return;
8665
8666 // Invalidate the block and loop dispositions cached for S. Dispositions of
8667 // S's users may change if S's disposition changes (i.e. a user may change to
8668 // loop-invariant, if S changes to loop invariant), so also invalidate
8669 // dispositions of S's users recursively.
8670  SmallVector<const SCEV *, 8> Worklist = {S};
8671  SmallPtrSet<const SCEV *, 8> Seen = {S};
8672 while (!Worklist.empty()) {
8673 const SCEV *Curr = Worklist.pop_back_val();
8674 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8675 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8676 if (!LoopDispoRemoved && !BlockDispoRemoved)
8677 continue;
8678 auto Users = SCEVUsers.find(Curr);
8679 if (Users != SCEVUsers.end())
8680 for (const auto *User : Users->second)
8681 if (Seen.insert(User).second)
8682 Worklist.push_back(User);
8683 }
8684}
8685
8686/// Get the exact loop backedge taken count considering all loop exits. A
8687/// computable result can only be returned for loops with all exiting blocks
8688/// dominating the latch. howFarToZero assumes that the limit of each loop test
8689/// is never skipped. This is a valid assumption as long as the loop exits via
8690/// that test. For precise results, it is the caller's responsibility to specify
8691/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8692const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8693    const Loop *L, ScalarEvolution *SE,
8694    SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8695 // If any exits were not computable, the loop is not computable.
8696 if (!isComplete() || ExitNotTaken.empty())
8697 return SE->getCouldNotCompute();
8698
8699 const BasicBlock *Latch = L->getLoopLatch();
8700 // All exiting blocks we have collected must dominate the only backedge.
8701 if (!Latch)
8702 return SE->getCouldNotCompute();
8703
8704 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8705  // count is simply a minimum out of all these calculated exit counts.
8706  SmallVector<const SCEV *, 2> Ops;
8707 for (const auto &ENT : ExitNotTaken) {
8708 const SCEV *BECount = ENT.ExactNotTaken;
8709 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8710 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8711 "We should only have known counts for exiting blocks that dominate "
8712 "latch!");
8713
8714 Ops.push_back(BECount);
8715
8716 if (Preds)
8717 append_range(*Preds, ENT.Predicates);
8718
8719 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8720 "Predicate should be always true!");
8721 }
8722
8723 // If an earlier exit exits on the first iteration (exit count zero), then
8724  // a later poison exit count should not propagate into the result. These are
8725 // exactly the semantics provided by umin_seq.
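  // E.g. a umin_seq whose first operand is 0 evaluates to 0 even if a later
  // operand is poison, whereas a plain umin would let the poison operand leak
  // into the backedge-taken count.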
8726 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8727}
8728
8729const ScalarEvolution::ExitNotTakenInfo *
8730ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8731 const BasicBlock *ExitingBlock,
8732 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8733 for (const auto &ENT : ExitNotTaken)
8734 if (ENT.ExitingBlock == ExitingBlock) {
8735 if (ENT.hasAlwaysTruePredicate())
8736 return &ENT;
8737 else if (Predicates) {
8738 append_range(*Predicates, ENT.Predicates);
8739 return &ENT;
8740 }
8741 }
8742
8743 return nullptr;
8744}
8745
8746/// getConstantMax - Get the constant max backedge taken count for the loop.
8747const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8748 ScalarEvolution *SE,
8749 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8750 if (!getConstantMax())
8751 return SE->getCouldNotCompute();
8752
8753 for (const auto &ENT : ExitNotTaken)
8754 if (!ENT.hasAlwaysTruePredicate()) {
8755 if (!Predicates)
8756 return SE->getCouldNotCompute();
8757 append_range(*Predicates, ENT.Predicates);
8758 }
8759
8760 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8761 isa<SCEVConstant>(getConstantMax())) &&
8762 "No point in having a non-constant max backedge taken count!");
8763 return getConstantMax();
8764}
8765
8766const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8767 const Loop *L, ScalarEvolution *SE,
8768 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8769 if (!SymbolicMax) {
8770 // Form an expression for the maximum exit count possible for this loop. We
8771 // merge the max and exact information to approximate a version of
8772 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8773    // constants.
8774    SmallVector<const SCEV *, 4> ExitCounts;
8775
8776 for (const auto &ENT : ExitNotTaken) {
8777 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8778 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8779 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8780 "We should only have known counts for exiting blocks that "
8781 "dominate latch!");
8782 ExitCounts.push_back(ExitCount);
8783 if (Predicates)
8784 append_range(*Predicates, ENT.Predicates);
8785
8786 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8787 "Predicate should be always true!");
8788 }
8789 }
8790 if (ExitCounts.empty())
8791 SymbolicMax = SE->getCouldNotCompute();
8792 else
8793 SymbolicMax =
8794 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8795 }
8796 return SymbolicMax;
8797}
8798
8799bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8800 ScalarEvolution *SE) const {
8801 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8802 return !ENT.hasAlwaysTruePredicate();
8803 };
8804 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8805}
8806
8809
8811 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8812 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8816 // If we prove the max count is zero, so is the symbolic bound. This happens
8817 // in practice due to differences in a) how context sensitive we've chosen
8818 // to be and b) how we reason about bounds implied by UB.
8819 if (ConstantMaxNotTaken->isZero()) {
8820 this->ExactNotTaken = E = ConstantMaxNotTaken;
8821 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8822 }
8823
8826 "Exact is not allowed to be less precise than Constant Max");
8829 "Exact is not allowed to be less precise than Symbolic Max");
8832 "Symbolic Max is not allowed to be less precise than Constant Max");
8835 "No point in having a non-constant max backedge taken count!");
8836  SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8837  for (const auto PredList : PredLists)
8838 for (const auto *P : PredList) {
8839 if (SeenPreds.contains(P))
8840 continue;
8841 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8842 SeenPreds.insert(P);
8843 Predicates.push_back(P);
8844 }
8845 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8846 "Backedge count should be int");
8848 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8849 "Max backedge count should be int");
8850}
8851
8859
8860/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8861/// computable exit into a persistent ExitNotTakenInfo array.
8862ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8864 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8865 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8866 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8867
8868 ExitNotTaken.reserve(ExitCounts.size());
8869 std::transform(ExitCounts.begin(), ExitCounts.end(),
8870 std::back_inserter(ExitNotTaken),
8871 [&](const EdgeExitInfo &EEI) {
8872 BasicBlock *ExitBB = EEI.first;
8873 const ExitLimit &EL = EEI.second;
8874 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8875 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8876 EL.Predicates);
8877 });
8878 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8879 isa<SCEVConstant>(ConstantMax)) &&
8880 "No point in having a non-constant max backedge taken count!");
8881}
8882
8883/// Compute the number of times the backedge of the specified loop will execute.
8884ScalarEvolution::BackedgeTakenInfo
8885ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8886 bool AllowPredicates) {
8887 SmallVector<BasicBlock *, 8> ExitingBlocks;
8888 L->getExitingBlocks(ExitingBlocks);
8889
8890 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8891
8892  SmallVector<EdgeExitInfo, 4> ExitCounts;
8893 bool CouldComputeBECount = true;
8894 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8895 const SCEV *MustExitMaxBECount = nullptr;
8896 const SCEV *MayExitMaxBECount = nullptr;
8897 bool MustExitMaxOrZero = false;
8898 bool IsOnlyExit = ExitingBlocks.size() == 1;
8899
8900 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8901 // and compute maxBECount.
8902 // Do a union of all the predicates here.
8903 for (BasicBlock *ExitBB : ExitingBlocks) {
8904 // We canonicalize untaken exits to br (constant), ignore them so that
8905 // proving an exit untaken doesn't negatively impact our ability to reason
8906    // about the loop as a whole.
8907 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8908 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8909 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8910 if (ExitIfTrue == CI->isZero())
8911 continue;
8912 }
8913
8914 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8915
8916 assert((AllowPredicates || EL.Predicates.empty()) &&
8917 "Predicated exit limit when predicates are not allowed!");
8918
8919 // 1. For each exit that can be computed, add an entry to ExitCounts.
8920 // CouldComputeBECount is true only if all exits can be computed.
8921 if (EL.ExactNotTaken != getCouldNotCompute())
8922 ++NumExitCountsComputed;
8923 else
8924 // We couldn't compute an exact value for this exit, so
8925 // we won't be able to compute an exact value for the loop.
8926 CouldComputeBECount = false;
8927 // Remember exit count if either exact or symbolic is known. Because
8928 // Exact always implies symbolic, only check symbolic.
8929 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8930 ExitCounts.emplace_back(ExitBB, EL);
8931 else {
8932 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8933 "Exact is known but symbolic isn't?");
8934 ++NumExitCountsNotComputed;
8935 }
8936
8937 // 2. Derive the loop's MaxBECount from each exit's max number of
8938 // non-exiting iterations. Partition the loop exits into two kinds:
8939 // LoopMustExits and LoopMayExits.
8940 //
8941 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8942 // is a LoopMayExit. If any computable LoopMustExit is found, then
8943 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8944 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8945 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8946 // any
8947 // computable EL.ConstantMaxNotTaken.
8948 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8949 DT.dominates(ExitBB, Latch)) {
8950 if (!MustExitMaxBECount) {
8951 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8952 MustExitMaxOrZero = EL.MaxOrZero;
8953 } else {
8954 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8955 EL.ConstantMaxNotTaken);
8956 }
8957 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8958 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8959 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8960 else {
8961 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8962 EL.ConstantMaxNotTaken);
8963 }
8964 }
8965 }
8966 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8967 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8968 // The loop backedge will be taken the maximum or zero times if there's
8969 // a single exit that must be taken the maximum or zero times.
8970 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8971
8972 // Remember which SCEVs are used in exit limits for invalidation purposes.
8973 // We only care about non-constant SCEVs here, so we can ignore
8974 // EL.ConstantMaxNotTaken
8975 // and MaxBECount, which must be SCEVConstant.
8976 for (const auto &Pair : ExitCounts) {
8977 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8978 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8979 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8980 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8981 {L, AllowPredicates});
8982 }
8983 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8984 MaxBECount, MaxOrZero);
8985}
8986
8987ScalarEvolution::ExitLimit
8988ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8989 bool IsOnlyExit, bool AllowPredicates) {
8990 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8991 // If our exiting block does not dominate the latch, then its connection with
8992 // loop's exit limit may be far from trivial.
8993 const BasicBlock *Latch = L->getLoopLatch();
8994 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8995 return getCouldNotCompute();
8996
8997 Instruction *Term = ExitingBlock->getTerminator();
8998 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8999 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
9000 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9001 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9002 "It should have one successor in loop and one exit block!");
9003 // Proceed to the next level to examine the exit condition expression.
9004 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9005 /*ControlsOnlyExit=*/IsOnlyExit,
9006 AllowPredicates);
9007 }
9008
9009 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9010 // For switch, make sure that there is a single exit from the loop.
9011 BasicBlock *Exit = nullptr;
9012 for (auto *SBB : successors(ExitingBlock))
9013 if (!L->contains(SBB)) {
9014 if (Exit) // Multiple exit successors.
9015 return getCouldNotCompute();
9016 Exit = SBB;
9017 }
9018 assert(Exit && "Exiting block must have at least one exit");
9019 return computeExitLimitFromSingleExitSwitch(
9020 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9021 }
9022
9023 return getCouldNotCompute();
9024}
9025
9027 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9028 bool AllowPredicates) {
9029 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9030 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9031 ControlsOnlyExit, AllowPredicates);
9032}
9033
9034std::optional<ScalarEvolution::ExitLimit>
9035ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9036 bool ExitIfTrue, bool ControlsOnlyExit,
9037 bool AllowPredicates) {
9038 (void)this->L;
9039 (void)this->ExitIfTrue;
9040 (void)this->AllowPredicates;
9041
9042 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9043 this->AllowPredicates == AllowPredicates &&
9044 "Variance in assumed invariant key components!");
9045 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9046 if (Itr == TripCountMap.end())
9047 return std::nullopt;
9048 return Itr->second;
9049}
9050
9051void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9052 bool ExitIfTrue,
9053 bool ControlsOnlyExit,
9054 bool AllowPredicates,
9055 const ExitLimit &EL) {
9056 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9057 this->AllowPredicates == AllowPredicates &&
9058 "Variance in assumed invariant key components!");
9059
9060 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9061 assert(InsertResult.second && "Expected successful insertion!");
9062 (void)InsertResult;
9063 (void)ExitIfTrue;
9064}
9065
9066ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9067 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9068 bool ControlsOnlyExit, bool AllowPredicates) {
9069
9070 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9071 AllowPredicates))
9072 return *MaybeEL;
9073
9074 ExitLimit EL = computeExitLimitFromCondImpl(
9075 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9076 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9077 return EL;
9078}
9079
9080ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9081 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9082 bool ControlsOnlyExit, bool AllowPredicates) {
9083 // Handle BinOp conditions (And, Or).
9084 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9085 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9086 return *LimitFromBinOp;
9087
9088 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9089 // Proceed to the next level to examine the icmp.
9090 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9091 ExitLimit EL =
9092 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9093 if (EL.hasFullInfo() || !AllowPredicates)
9094 return EL;
9095
9096 // Try again, but use SCEV predicates this time.
9097 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9098 ControlsOnlyExit,
9099 /*AllowPredicates=*/true);
9100 }
9101
9102 // Check for a constant condition. These are normally stripped out by
9103 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9104 // preserve the CFG and is temporarily leaving constant conditions
9105 // in place.
9106 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9107 if (ExitIfTrue == !CI->getZExtValue())
9108 // The backedge is always taken.
9109 return getCouldNotCompute();
9110 // The backedge is never taken.
9111 return getZero(CI->getType());
9112 }
9113
9114 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9115 // with a constant step, we can form an equivalent icmp predicate and figure
9116 // out how many iterations will be taken before we exit.
9117 const WithOverflowInst *WO;
9118 const APInt *C;
9119 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9120 match(WO->getRHS(), m_APInt(C))) {
9121    ConstantRange NWR =
9122        ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9123 WO->getNoWrapKind());
9124 CmpInst::Predicate Pred;
9125 APInt NewRHSC, Offset;
9126 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9127 if (!ExitIfTrue)
9128 Pred = ICmpInst::getInversePredicate(Pred);
9129 auto *LHS = getSCEV(WO->getLHS());
9130    if (Offset != 0)
9131      LHS = getAddExpr(LHS, getConstant(Offset));
9132 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9133 ControlsOnlyExit, AllowPredicates);
9134 if (EL.hasAnyInfo())
9135 return EL;
9136 }
9137
9138 // If it's not an integer or pointer comparison then compute it the hard way.
9139 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9140}
9141
9142std::optional<ScalarEvolution::ExitLimit>
9143ScalarEvolution::computeExitLimitFromCondFromBinOp(
9144 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9145 bool ControlsOnlyExit, bool AllowPredicates) {
9146 // Check if the controlling expression for this loop is an And or Or.
9147 Value *Op0, *Op1;
9148 bool IsAnd = false;
9149 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9150 IsAnd = true;
9151 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9152 IsAnd = false;
9153 else
9154 return std::nullopt;
9155
9156 // EitherMayExit is true in these two cases:
9157 // br (and Op0 Op1), loop, exit
9158 // br (or Op0 Op1), exit, loop
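  // E.g. for "br i1 (and %c0, %c1), label %loop, label %exit" the loop keeps
  // running only while both conditions hold, so the overall exit count is the
  // minimum of the two sub-exit counts computed below.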
9159 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9160 ExitLimit EL0 = computeExitLimitFromCondCached(
9161 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9162 AllowPredicates);
9163 ExitLimit EL1 = computeExitLimitFromCondCached(
9164 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9165 AllowPredicates);
9166
9167 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9168 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9169 if (isa<ConstantInt>(Op1))
9170 return Op1 == NeutralElement ? EL0 : EL1;
9171 if (isa<ConstantInt>(Op0))
9172 return Op0 == NeutralElement ? EL1 : EL0;
9173
9174 const SCEV *BECount = getCouldNotCompute();
9175 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9176 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9177 if (EitherMayExit) {
9178 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9179    // Both conditions must be the same for the loop to continue executing.
9180 // Choose the less conservative count.
9181 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9182 EL1.ExactNotTaken != getCouldNotCompute()) {
9183 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9184 UseSequentialUMin);
9185 }
9186 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9187 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9188 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9189 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9190 else
9191 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9192 EL1.ConstantMaxNotTaken);
9193 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9194 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9195 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9196 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9197 else
9198 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9199 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9200 } else {
9201    // Both conditions must be the same at the same time for the loop to exit.
9202 // For now, be conservative.
9203 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9204 BECount = EL0.ExactNotTaken;
9205 }
9206
9207 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9208 // to be more aggressive when computing BECount than when computing
9209 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9210 // and
9211 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9212 // EL1.ConstantMaxNotTaken to not.
9213 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9214 !isa<SCEVCouldNotCompute>(BECount))
9215 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9216 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9217 SymbolicMaxBECount =
9218 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9219 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9220 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9221}
9222
9223ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9224 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9225 bool AllowPredicates) {
9226 // If the condition was exit on true, convert the condition to exit on false
9227 CmpPredicate Pred;
9228 if (!ExitIfTrue)
9229 Pred = ExitCond->getCmpPredicate();
9230 else
9231 Pred = ExitCond->getInverseCmpPredicate();
9232 const ICmpInst::Predicate OriginalPred = Pred;
9233
9234 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9235 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9236
9237 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9238 AllowPredicates);
9239 if (EL.hasAnyInfo())
9240 return EL;
9241
9242 auto *ExhaustiveCount =
9243 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9244
9245 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9246 return ExhaustiveCount;
9247
9248 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9249 ExitCond->getOperand(1), L, OriginalPred);
9250}
9251ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9252 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9253 bool ControlsOnlyExit, bool AllowPredicates) {
9254
9255 // Try to evaluate any dependencies out of the loop.
9256 LHS = getSCEVAtScope(LHS, L);
9257 RHS = getSCEVAtScope(RHS, L);
9258
9259 // At this point, we would like to compute how many iterations of the
9260 // loop the predicate will return true for these inputs.
9261 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9262 // If there is a loop-invariant, force it into the RHS.
9263    std::swap(LHS, RHS);
9264    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9265 }
9266
9267  bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9268                               loopIsFiniteByAssumption(L);
9269 // Simplify the operands before analyzing them.
9270 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9271
9272 // If we have a comparison of a chrec against a constant, try to use value
9273 // ranges to answer this query.
9274 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9275 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9276 if (AddRec->getLoop() == L) {
9277 // Form the constant range.
9278 ConstantRange CompRange =
9279 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9280
9281 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9282 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9283 }
9284
9285 // If this loop must exit based on this condition (or execute undefined
9286 // behaviour), see if we can improve wrap flags. This is essentially
9287 // a must execute style proof.
9288 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9289 // If we can prove the test sequence produced must repeat the same values
9290 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9291 // because if it did, we'd have an infinite (undefined) loop.
9292 // TODO: We can peel off any functions which are invertible *in L*. Loop
9293 // invariant terms are effectively constants for our purposes here.
9294 auto *InnerLHS = LHS;
9295 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9296 InnerLHS = ZExt->getOperand();
9297 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9298 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9299 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9300 /*OrNegative=*/true)) {
9301 auto Flags = AR->getNoWrapFlags();
9302      Flags = setFlags(Flags, SCEV::FlagNW);
9303      SmallVector<const SCEV *> Operands{AR->operands()};
9304      Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9305 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9306 }
9307
9308 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9309 // From no-self-wrap, this follows trivially from the fact that every
9310    // (un)signed-wrapped, but not self-wrapped value must be less than the
9311 // last value before (un)signed wrap. Since we know that last value
9312 // didn't exit, nor will any smaller one.
9313 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9314 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9315 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9316 AR && AR->getLoop() == L && AR->isAffine() &&
9317 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9318 isKnownPositive(AR->getStepRecurrence(*this))) {
9319 auto Flags = AR->getNoWrapFlags();
9320        Flags = setFlags(Flags, WrapType);
9321        SmallVector<const SCEV *> Operands{AR->operands()};
9322        Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9323 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9324 }
9325 }
9326 }
9327
9328 switch (Pred) {
9329 case ICmpInst::ICMP_NE: { // while (X != Y)
9330 // Convert to: while (X-Y != 0)
9331    if (LHS->getType()->isPointerTy()) {
9332      LHS = getLosslessPtrToIntExpr(LHS);
9333      if (isa<SCEVCouldNotCompute>(LHS))
9334 return LHS;
9335 }
9336    if (RHS->getType()->isPointerTy()) {
9337      RHS = getLosslessPtrToIntExpr(RHS);
9338      if (isa<SCEVCouldNotCompute>(RHS))
9339 return RHS;
9340 }
9341 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9342 AllowPredicates);
9343 if (EL.hasAnyInfo())
9344 return EL;
9345 break;
9346 }
9347 case ICmpInst::ICMP_EQ: { // while (X == Y)
9348 // Convert to: while (X-Y == 0)
9349    if (LHS->getType()->isPointerTy()) {
9350      LHS = getLosslessPtrToIntExpr(LHS);
9351      if (isa<SCEVCouldNotCompute>(LHS))
9352 return LHS;
9353 }
9354    if (RHS->getType()->isPointerTy()) {
9355      RHS = getLosslessPtrToIntExpr(RHS);
9356      if (isa<SCEVCouldNotCompute>(RHS))
9357 return RHS;
9358 }
9359 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9360 if (EL.hasAnyInfo()) return EL;
9361 break;
9362 }
9363 case ICmpInst::ICMP_SLE:
9364 case ICmpInst::ICMP_ULE:
9365 // Since the loop is finite, an invariant RHS cannot include the boundary
9366 // value, otherwise it would loop forever.
9367 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9368 !isLoopInvariant(RHS, L)) {
9369 // Otherwise, perform the addition in a wider type, to avoid overflow.
9370 // If the LHS is an addrec with the appropriate nowrap flag, the
9371 // extension will be sunk into it and the exit count can be analyzed.
9372 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9373 if (!OldType)
9374 break;
9375 // Prefer doubling the bitwidth over adding a single bit to make it more
9376 // likely that we use a legal type.
9377 auto *NewType =
9378 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9379 if (ICmpInst::isSigned(Pred)) {
9380 LHS = getSignExtendExpr(LHS, NewType);
9381 RHS = getSignExtendExpr(RHS, NewType);
9382 } else {
9383 LHS = getZeroExtendExpr(LHS, NewType);
9384 RHS = getZeroExtendExpr(RHS, NewType);
9385 }
9386    }
9387    RHS = getAddExpr(getOne(RHS->getType()), RHS);
9388 [[fallthrough]];
9389 case ICmpInst::ICMP_SLT:
9390 case ICmpInst::ICMP_ULT: { // while (X < Y)
9391 bool IsSigned = ICmpInst::isSigned(Pred);
9392 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9393 AllowPredicates);
9394 if (EL.hasAnyInfo())
9395 return EL;
9396 break;
9397 }
9398 case ICmpInst::ICMP_SGE:
9399 case ICmpInst::ICMP_UGE:
9400 // Since the loop is finite, an invariant RHS cannot include the boundary
9401 // value, otherwise it would loop forever.
9402 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9403 !isLoopInvariant(RHS, L))
9404      break;
9405    RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9406 [[fallthrough]];
9407 case ICmpInst::ICMP_SGT:
9408 case ICmpInst::ICMP_UGT: { // while (X > Y)
9409 bool IsSigned = ICmpInst::isSigned(Pred);
9410 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9411 AllowPredicates);
9412 if (EL.hasAnyInfo())
9413 return EL;
9414 break;
9415 }
9416 default:
9417 break;
9418 }
9419
9420 return getCouldNotCompute();
9421}
9422
9423ScalarEvolution::ExitLimit
9424ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9425 SwitchInst *Switch,
9426 BasicBlock *ExitingBlock,
9427 bool ControlsOnlyExit) {
9428 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9429
9430 // Give up if the exit is the default dest of a switch.
9431 if (Switch->getDefaultDest() == ExitingBlock)
9432 return getCouldNotCompute();
9433
9434 assert(L->contains(Switch->getDefaultDest()) &&
9435 "Default case must not exit the loop!");
9436 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9437 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9438
9439 // while (X != Y) --> while (X-Y != 0)
9440 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9441 if (EL.hasAnyInfo())
9442 return EL;
9443
9444 return getCouldNotCompute();
9445}
9446
9447static ConstantInt *
9448EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9449 ScalarEvolution &SE) {
9450 const SCEV *InVal = SE.getConstant(C);
9451  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9452  assert(isa<SCEVConstant>(Val) &&
9453 "Evaluation of SCEV at constant didn't fold correctly?");
9454 return cast<SCEVConstant>(Val)->getValue();
9455}
9456
9457ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9458 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9459 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9460 if (!RHS)
9461 return getCouldNotCompute();
9462
9463 const BasicBlock *Latch = L->getLoopLatch();
9464 if (!Latch)
9465 return getCouldNotCompute();
9466
9467 const BasicBlock *Predecessor = L->getLoopPredecessor();
9468 if (!Predecessor)
9469 return getCouldNotCompute();
9470
9471 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9472 // Return LHS in OutLHS and shift_opt in OutOpCode.
9473 auto MatchPositiveShift =
9474 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9475
9476 using namespace PatternMatch;
9477
9478 ConstantInt *ShiftAmt;
9479 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9480 OutOpCode = Instruction::LShr;
9481 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9482 OutOpCode = Instruction::AShr;
9483 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9484 OutOpCode = Instruction::Shl;
9485 else
9486 return false;
9487
9488 return ShiftAmt->getValue().isStrictlyPositive();
9489 };
9490
9491 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9492 //
9493 // loop:
9494 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9495 // %iv.shifted = lshr i32 %iv, <positive constant>
9496 //
9497 // Return true on a successful match. Return the corresponding PHI node (%iv
9498 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9499 auto MatchShiftRecurrence =
9500 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9501 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9502
9503    {
9504      Instruction::BinaryOps OpC;
9505 Value *V;
9506
9507 // If we encounter a shift instruction, "peel off" the shift operation,
9508 // and remember that we did so. Later when we inspect %iv's backedge
9509 // value, we will make sure that the backedge value uses the same
9510 // operation.
9511 //
9512 // Note: the peeled shift operation does not have to be the same
9513 // instruction as the one feeding into the PHI's backedge value. We only
9514 // really care about it being the same *kind* of shift instruction --
9515 // that's all that is required for our later inferences to hold.
9516 if (MatchPositiveShift(LHS, V, OpC)) {
9517 PostShiftOpCode = OpC;
9518 LHS = V;
9519 }
9520 }
9521
9522 PNOut = dyn_cast<PHINode>(LHS);
9523 if (!PNOut || PNOut->getParent() != L->getHeader())
9524 return false;
9525
9526 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9527 Value *OpLHS;
9528
9529 return
9530 // The backedge value for the PHI node must be a shift by a positive
9531 // amount
9532 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9533
9534 // of the PHI node itself
9535 OpLHS == PNOut &&
9536
9537 // and the kind of shift should match the kind of shift we peeled
9538 // off, if any.
9539 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9540 };
9541
9542 PHINode *PN;
9543 Instruction::BinaryOps OpCode;
9544 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9545 return getCouldNotCompute();
9546
9547 const DataLayout &DL = getDataLayout();
9548
9549 // The key rationale for this optimization is that for some kinds of shift
9550 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9551 // within a finite number of iterations. If the condition guarding the
9552 // backedge (in the sense that the backedge is taken if the condition is true)
9553 // is false for the value the shift recurrence stabilizes to, then we know
9554 // that the backedge is taken only a finite number of times.
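// For example, an i8 recurrence {K,ashr,1} goes K, K>>1, K>>2, ... and settles
// at 0 (if K is non-negative) or at -1 (if K is negative) within at most 8
// steps; {K,lshr,1} and {K,shl,1} both settle at 0.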
9555
9556 ConstantInt *StableValue = nullptr;
9557 switch (OpCode) {
9558 default:
9559 llvm_unreachable("Impossible case!");
9560
9561 case Instruction::AShr: {
9562 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9563 // bitwidth(K) iterations.
9564 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9565 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9566 Predecessor->getTerminator(), &DT);
9567 auto *Ty = cast<IntegerType>(RHS->getType());
9568 if (Known.isNonNegative())
9569 StableValue = ConstantInt::get(Ty, 0);
9570 else if (Known.isNegative())
9571 StableValue = ConstantInt::get(Ty, -1, true);
9572 else
9573 return getCouldNotCompute();
9574
9575 break;
9576 }
9577 case Instruction::LShr:
9578 case Instruction::Shl:
9579 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9580 // stabilize to 0 in at most bitwidth(K) iterations.
9581 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9582 break;
9583 }
9584
9585 auto *Result =
9586 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9587 assert(Result->getType()->isIntegerTy(1) &&
9588 "Otherwise cannot be an operand to a branch instruction");
9589
9590 if (Result->isZeroValue()) {
9591 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9592 const SCEV *UpperBound =
9593 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9594 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9595 }
9596
9597 return getCouldNotCompute();
9598}
9599
9600/// Return true if we can constant fold an instruction of the specified type,
9601/// assuming that all operands were constants.
9602static bool CanConstantFold(const Instruction *I) {
9603 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9604 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9605 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9606 return true;
9607
9608 if (const CallInst *CI = dyn_cast<CallInst>(I))
9609 if (const Function *F = CI->getCalledFunction())
9610 return canConstantFoldCallTo(CI, F);
9611 return false;
9612}
9613
9614/// Determine whether this instruction can constant evolve within this loop
9615/// assuming its operands can all constant evolve.
9616static bool canConstantEvolve(Instruction *I, const Loop *L) {
9617 // An instruction outside of the loop can't be derived from a loop PHI.
9618 if (!L->contains(I)) return false;
9619
9620 if (isa<PHINode>(I)) {
9621 // We don't currently keep track of the control flow needed to evaluate
9622 // PHIs, so we cannot handle PHIs inside of loops.
9623 return L->getHeader() == I->getParent();
9624 }
9625
9626 // If we won't be able to constant fold this expression even if the operands
9627 // are constants, bail early.
9628 return CanConstantFold(I);
9629}
9630
9631/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9632/// recursing through each instruction operand until reaching a loop header phi.
9633static PHINode *
9634 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9635 DenseMap<Instruction *, PHINode *> &PHIMap,
9636 unsigned Depth) {
9637 if (Depth > MaxConstantEvolvingDepth)
9638 return nullptr;
9639
9640 // Otherwise, we can evaluate this instruction if all of its operands are
9641 // constant or derived from a PHI node themselves.
9642 PHINode *PHI = nullptr;
9643 for (Value *Op : UseInst->operands()) {
9644 if (isa<Constant>(Op)) continue;
9645
9646 Instruction *OpInst = dyn_cast<Instruction>(Op);
9647 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9648
9649 PHINode *P = dyn_cast<PHINode>(OpInst);
9650 if (!P)
9651 // If this operand is already visited, reuse the prior result.
9652 // We may have P != PHI if this is the deepest point at which the
9653 // inconsistent paths meet.
9654 P = PHIMap.lookup(OpInst);
9655 if (!P) {
9656 // Recurse and memoize the results, whether a phi is found or not.
9657 // This recursive call invalidates pointers into PHIMap.
9658 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9659 PHIMap[OpInst] = P;
9660 }
9661 if (!P)
9662 return nullptr; // Not evolving from PHI
9663 if (PHI && PHI != P)
9664 return nullptr; // Evolving from multiple different PHIs.
9665 PHI = P;
9666 }
9667 // This is an expression evolving from a constant PHI!
9668 return PHI;
9669}
9670
9671/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9672/// in the loop that V is derived from. We allow arbitrary operations along the
9673/// way, but the operands of an operation must either be constants or a value
9674/// derived from a constant PHI. If this expression does not fit with these
9675/// constraints, return null.
9676 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9677 Instruction *I = dyn_cast<Instruction>(V);
9678 if (!I || !canConstantEvolve(I, L)) return nullptr;
9679
9680 if (PHINode *PN = dyn_cast<PHINode>(I))
9681 return PN;
9682
9683 // Record non-constant instructions contained by the loop.
9684 DenseMap<Instruction *, PHINode *> PHIMap;
9685 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9686}
9687
9688/// EvaluateExpression - Given an expression that passes the
9689/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9690/// in the loop has the value PHIVal. If we can't fold this expression for some
9691/// reason, return null.
9692 static Constant *EvaluateExpression(Value *V, const Loop *L,
9693 DenseMap<Instruction *, Constant *> &Vals,
9694 const DataLayout &DL,
9695 const TargetLibraryInfo *TLI) {
9696 // Convenient constant check, but redundant for recursive calls.
9697 if (Constant *C = dyn_cast<Constant>(V)) return C;
9698 Instruction *I = dyn_cast<Instruction>(V);
9699 if (!I) return nullptr;
9700
9701 if (Constant *C = Vals.lookup(I)) return C;
9702
9703 // An instruction inside the loop depends on a value outside the loop that we
9704 // weren't given a mapping for, or a value such as a call inside the loop.
9705 if (!canConstantEvolve(I, L)) return nullptr;
9706
9707 // An unmapped PHI can be due to a branch or another loop inside this loop,
9708 // or due to this not being the initial iteration through a loop where we
9709 // couldn't compute the evolution of this particular PHI last time.
9710 if (isa<PHINode>(I)) return nullptr;
9711
9712 std::vector<Constant*> Operands(I->getNumOperands());
9713
9714 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9715 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9716 if (!Operand) {
9717 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9718 if (!Operands[i]) return nullptr;
9719 continue;
9720 }
9721 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9722 Vals[Operand] = C;
9723 if (!C) return nullptr;
9724 Operands[i] = C;
9725 }
9726
9727 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9728 /*AllowNonDeterministic=*/false);
9729}
9730
9731
9732// If every incoming value to PN except the one for BB is a specific Constant,
9733// return that, else return nullptr.
9734 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9735 Constant *IncomingVal = nullptr;
9736
9737 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9738 if (PN->getIncomingBlock(i) == BB)
9739 continue;
9740
9741 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9742 if (!CurrentVal)
9743 return nullptr;
9744
9745 if (IncomingVal != CurrentVal) {
9746 if (IncomingVal)
9747 return nullptr;
9748 IncomingVal = CurrentVal;
9749 }
9750 }
9751
9752 return IncomingVal;
9753}
9754
9755/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9756/// in the header of its containing loop, we know the loop executes a
9757/// constant number of times, and the PHI node is just a recurrence
9758/// involving constants, fold it.
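/// For example (an illustrative sketch), a header PHI starting at 5 whose
/// backedge value is "add %phi, 1", in a loop with a backedge-taken count of
/// 3, folds to the constant 8 by literally evaluating the three iterations.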
9759Constant *
9760ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9761 const APInt &BEs,
9762 const Loop *L) {
9763 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9764 if (!Inserted)
9765 return I->second;
9766
9767 if (BEs.ugt(MaxBruteForceIterations))
9768 return nullptr; // Not going to evaluate it.
9769
9770 Constant *&RetVal = I->second;
9771
9772 DenseMap<Instruction *, Constant *> CurrentIterVals;
9773 BasicBlock *Header = L->getHeader();
9774 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9775
9776 BasicBlock *Latch = L->getLoopLatch();
9777 if (!Latch)
9778 return nullptr;
9779
9780 for (PHINode &PHI : Header->phis()) {
9781 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9782 CurrentIterVals[&PHI] = StartCST;
9783 }
9784 if (!CurrentIterVals.count(PN))
9785 return RetVal = nullptr;
9786
9787 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9788
9789 // Execute the loop symbolically to determine the exit value.
9790 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9791 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9792
9793 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9794 unsigned IterationNum = 0;
9795 const DataLayout &DL = getDataLayout();
9796 for (; ; ++IterationNum) {
9797 if (IterationNum == NumIterations)
9798 return RetVal = CurrentIterVals[PN]; // Got exit value!
9799
9800 // Compute the value of the PHIs for the next iteration.
9801 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9802 DenseMap<Instruction *, Constant *> NextIterVals;
9803 Constant *NextPHI =
9804 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9805 if (!NextPHI)
9806 return nullptr; // Couldn't evaluate!
9807 NextIterVals[PN] = NextPHI;
9808
9809 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9810
9811 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9812 // cease to be able to evaluate one of them or if they stop evolving,
9813 // because that doesn't necessarily prevent us from computing PN.
9814 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9815 for (const auto &I : CurrentIterVals) {
9816 PHINode *PHI = dyn_cast<PHINode>(I.first);
9817 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9818 PHIsToCompute.emplace_back(PHI, I.second);
9819 }
9820 // We use two distinct loops because EvaluateExpression may invalidate any
9821 // iterators into CurrentIterVals.
9822 for (const auto &I : PHIsToCompute) {
9823 PHINode *PHI = I.first;
9824 Constant *&NextPHI = NextIterVals[PHI];
9825 if (!NextPHI) { // Not already computed.
9826 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9827 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9828 }
9829 if (NextPHI != I.second)
9830 StoppedEvolving = false;
9831 }
9832
9833 // If all entries in CurrentIterVals == NextIterVals then we can stop
9834 // iterating, the loop can't continue to change.
9835 if (StoppedEvolving)
9836 return RetVal = CurrentIterVals[PN];
9837
9838 CurrentIterVals.swap(NextIterVals);
9839 }
9840}
9841
9842const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9843 Value *Cond,
9844 bool ExitWhen) {
9845 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9846 if (!PN) return getCouldNotCompute();
9847
9848 // If the loop is canonicalized, the PHI will have exactly two entries.
9849 // That's the only form we support here.
9850 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9851
9852 DenseMap<Instruction *, Constant *> CurrentIterVals;
9853 BasicBlock *Header = L->getHeader();
9854 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9855
9856 BasicBlock *Latch = L->getLoopLatch();
9857 assert(Latch && "Should follow from NumIncomingValues == 2!");
9858
9859 for (PHINode &PHI : Header->phis()) {
9860 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9861 CurrentIterVals[&PHI] = StartCST;
9862 }
9863 if (!CurrentIterVals.count(PN))
9864 return getCouldNotCompute();
9865
9866 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9867 // the loop symbolically to determine when the condition gets a value of
9868 // "ExitWhen".
9869 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9870 const DataLayout &DL = getDataLayout();
9871 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9872 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9873 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9874
9875 // Couldn't symbolically evaluate.
9876 if (!CondVal) return getCouldNotCompute();
9877
9878 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9879 ++NumBruteForceTripCountsComputed;
9880 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9881 }
9882
9883 // Update all the PHI nodes for the next iteration.
9884 DenseMap<Instruction *, Constant *> NextIterVals;
9885
9886 // Create a list of which PHIs we need to compute. We want to do this before
9887 // calling EvaluateExpression on them because that may invalidate iterators
9888 // into CurrentIterVals.
9889 SmallVector<PHINode *, 8> PHIsToCompute;
9890 for (const auto &I : CurrentIterVals) {
9891 PHINode *PHI = dyn_cast<PHINode>(I.first);
9892 if (!PHI || PHI->getParent() != Header) continue;
9893 PHIsToCompute.push_back(PHI);
9894 }
9895 for (PHINode *PHI : PHIsToCompute) {
9896 Constant *&NextPHI = NextIterVals[PHI];
9897 if (NextPHI) continue; // Already computed!
9898
9899 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9900 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9901 }
9902 CurrentIterVals.swap(NextIterVals);
9903 }
9904
9905 // Too many iterations were needed to evaluate.
9906 return getCouldNotCompute();
9907}
9908
9909const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9910 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9911 ValuesAtScopes[V];
9912 // Check to see if we've folded this expression at this loop before.
9913 for (auto &LS : Values)
9914 if (LS.first == L)
9915 return LS.second ? LS.second : V;
9916
9917 Values.emplace_back(L, nullptr);
9918
9919 // Otherwise compute it.
9920 const SCEV *C = computeSCEVAtScope(V, L);
9921 for (auto &LS : reverse(ValuesAtScopes[V]))
9922 if (LS.first == L) {
9923 LS.second = C;
9924 if (!isa<SCEVConstant>(C))
9925 ValuesAtScopesUsers[C].push_back({L, V});
9926 break;
9927 }
9928 return C;
9929}
9930
9931/// This builds up a Constant using the ConstantExpr interface. That way, we
9932/// will return Constants for objects which aren't represented by a
9933/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9934/// Returns NULL if the SCEV isn't representable as a Constant.
9935 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9936 switch (V->getSCEVType()) {
9937 case scCouldNotCompute:
9938 case scAddRecExpr:
9939 case scVScale:
9940 return nullptr;
9941 case scConstant:
9942 return cast<SCEVConstant>(V)->getValue();
9943 case scUnknown:
9944 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9945 case scPtrToInt: {
9946 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9947 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9948 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9949
9950 return nullptr;
9951 }
9952 case scTruncate: {
9953 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9954 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9955 return ConstantExpr::getTrunc(CastOp, ST->getType());
9956 return nullptr;
9957 }
9958 case scAddExpr: {
9959 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9960 Constant *C = nullptr;
9961 for (const SCEV *Op : SA->operands()) {
9962 Constant *OpC = BuildConstantFromSCEV(Op);
9963 if (!OpC)
9964 return nullptr;
9965 if (!C) {
9966 C = OpC;
9967 continue;
9968 }
9969 assert(!C->getType()->isPointerTy() &&
9970 "Can only have one pointer, and it must be last");
9971 if (OpC->getType()->isPointerTy()) {
9972 // The offsets have been converted to bytes. We can add bytes using
9973 // an i8 GEP.
9974 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9975 OpC, C);
9976 } else {
9977 C = ConstantExpr::getAdd(C, OpC);
9978 }
9979 }
9980 return C;
9981 }
9982 case scMulExpr:
9983 case scSignExtend:
9984 case scZeroExtend:
9985 case scUDivExpr:
9986 case scSMaxExpr:
9987 case scUMaxExpr:
9988 case scSMinExpr:
9989 case scUMinExpr:
9990 case scSequentialUMinExpr:
9991 return nullptr;
9992 }
9993 llvm_unreachable("Unknown SCEV kind!");
9994}
9995
9996const SCEV *
9997ScalarEvolution::getWithOperands(const SCEV *S,
9998 SmallVectorImpl<const SCEV *> &NewOps) {
9999 switch (S->getSCEVType()) {
10000 case scTruncate:
10001 case scZeroExtend:
10002 case scSignExtend:
10003 case scPtrToInt:
10004 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10005 case scAddRecExpr: {
10006 auto *AddRec = cast<SCEVAddRecExpr>(S);
10007 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10008 }
10009 case scAddExpr:
10010 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10011 case scMulExpr:
10012 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10013 case scUDivExpr:
10014 return getUDivExpr(NewOps[0], NewOps[1]);
10015 case scUMaxExpr:
10016 case scSMaxExpr:
10017 case scUMinExpr:
10018 case scSMinExpr:
10019 return getMinMaxExpr(S->getSCEVType(), NewOps);
10020 case scSequentialUMinExpr:
10021 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10022 case scConstant:
10023 case scVScale:
10024 case scUnknown:
10025 return S;
10026 case scCouldNotCompute:
10027 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10028 }
10029 llvm_unreachable("Unknown SCEV kind!");
10030}
10031
10032const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10033 switch (V->getSCEVType()) {
10034 case scConstant:
10035 case scVScale:
10036 return V;
10037 case scAddRecExpr: {
10038 // If this is a loop recurrence for a loop that does not contain L, then we
10039 // are dealing with the final value computed by the loop.
10040 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10041 // First, attempt to evaluate each operand.
10042 // Avoid performing the look-up in the common case where the specified
10043 // expression has no loop-variant portions.
10044 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10045 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10046 if (OpAtScope == AddRec->getOperand(i))
10047 continue;
10048
10049 // Okay, at least one of these operands is loop variant but might be
10050 // foldable. Build a new instance of the folded commutative expression.
10051 SmallVector<const SCEV *, 8> NewOps;
10052 NewOps.reserve(AddRec->getNumOperands());
10053 append_range(NewOps, AddRec->operands().take_front(i));
10054 NewOps.push_back(OpAtScope);
10055 for (++i; i != e; ++i)
10056 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10057
10058 const SCEV *FoldedRec = getAddRecExpr(
10059 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10060 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10061 // The addrec may be folded to a nonrecurrence, for example, if the
10062 // induction variable is multiplied by zero after constant folding. Go
10063 // ahead and return the folded value.
10064 if (!AddRec)
10065 return FoldedRec;
10066 break;
10067 }
10068
10069 // If the scope is outside the addrec's loop, evaluate it by using the
10070 // loop exit value of the addrec.
10071 if (!AddRec->getLoop()->contains(L)) {
10072 // To evaluate this recurrence, we need to know how many times the AddRec
10073 // loop iterates. Compute this now.
10074 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10075 if (BackedgeTakenCount == getCouldNotCompute())
10076 return AddRec;
10077
10078 // Then, evaluate the AddRec.
10079 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10080 }
10081
10082 return AddRec;
10083 }
10084 case scTruncate:
10085 case scZeroExtend:
10086 case scSignExtend:
10087 case scPtrToInt:
10088 case scAddExpr:
10089 case scMulExpr:
10090 case scUDivExpr:
10091 case scUMaxExpr:
10092 case scSMaxExpr:
10093 case scUMinExpr:
10094 case scSMinExpr:
10095 case scSequentialUMinExpr: {
10096 ArrayRef<const SCEV *> Ops = V->operands();
10097 // Avoid performing the look-up in the common case where the specified
10098 // expression has no loop-variant portions.
10099 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10100 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10101 if (OpAtScope != Ops[i]) {
10102 // Okay, at least one of these operands is loop variant but might be
10103 // foldable. Build a new instance of the folded commutative expression.
10104 SmallVector<const SCEV *, 8> NewOps;
10105 NewOps.reserve(Ops.size());
10106 append_range(NewOps, Ops.take_front(i));
10107 NewOps.push_back(OpAtScope);
10108
10109 for (++i; i != e; ++i) {
10110 OpAtScope = getSCEVAtScope(Ops[i], L);
10111 NewOps.push_back(OpAtScope);
10112 }
10113
10114 return getWithOperands(V, NewOps);
10115 }
10116 }
10117 // If we got here, all operands are loop invariant.
10118 return V;
10119 }
10120 case scUnknown: {
10121 // If this instruction is evolved from a constant-evolving PHI, compute the
10122 // exit value from the loop without using SCEVs.
10123 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10124 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10125 if (!I)
10126 return V; // This is some other type of SCEVUnknown, just return it.
10127
10128 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10129 const Loop *CurrLoop = this->LI[I->getParent()];
10130 // Looking for loop exit value.
10131 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10132 PN->getParent() == CurrLoop->getHeader()) {
10133 // Okay, there is no closed form solution for the PHI node. Check
10134 // to see if the loop that contains it has a known backedge-taken
10135 // count. If so, we may be able to force computation of the exit
10136 // value.
10137 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10138 // This trivial case can show up in some degenerate cases where
10139 // the incoming IR has not yet been fully simplified.
10140 if (BackedgeTakenCount->isZero()) {
10141 Value *InitValue = nullptr;
10142 bool MultipleInitValues = false;
10143 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10144 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10145 if (!InitValue)
10146 InitValue = PN->getIncomingValue(i);
10147 else if (InitValue != PN->getIncomingValue(i)) {
10148 MultipleInitValues = true;
10149 break;
10150 }
10151 }
10152 }
10153 if (!MultipleInitValues && InitValue)
10154 return getSCEV(InitValue);
10155 }
10156 // Do we have a loop invariant value flowing around the backedge
10157 // for a loop which must execute the backedge?
10158 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10159 isKnownNonZero(BackedgeTakenCount) &&
10160 PN->getNumIncomingValues() == 2) {
10161
10162 unsigned InLoopPred =
10163 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10164 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10165 if (CurrLoop->isLoopInvariant(BackedgeVal))
10166 return getSCEV(BackedgeVal);
10167 }
10168 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10169 // Okay, we know how many times the containing loop executes. If
10170 // this is a constant evolving PHI node, get the final value at
10171 // the specified iteration number.
10172 Constant *RV =
10173 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10174 if (RV)
10175 return getSCEV(RV);
10176 }
10177 }
10178 }
10179
10180 // Okay, this is an expression that we cannot symbolically evaluate
10181 // into a SCEV. Check to see if it's possible to symbolically evaluate
10182 // the arguments into constants, and if so, try to constant propagate the
10183 // result. This is particularly useful for computing loop exit values.
10184 if (!CanConstantFold(I))
10185 return V; // This is some other type of SCEVUnknown, just return it.
10186
10187 SmallVector<Constant *, 4> Operands;
10188 Operands.reserve(I->getNumOperands());
10189 bool MadeImprovement = false;
10190 for (Value *Op : I->operands()) {
10191 if (Constant *C = dyn_cast<Constant>(Op)) {
10192 Operands.push_back(C);
10193 continue;
10194 }
10195
10196 // If any of the operands is non-constant and if they are
10197 // non-integer and non-pointer, don't even try to analyze them
10198 // with scev techniques.
10199 if (!isSCEVable(Op->getType()))
10200 return V;
10201
10202 const SCEV *OrigV = getSCEV(Op);
10203 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10204 MadeImprovement |= OrigV != OpV;
10205
10206 Constant *C = BuildConstantFromSCEV(OpV);
10207 if (!C)
10208 return V;
10209 assert(C->getType() == Op->getType() && "Type mismatch");
10210 Operands.push_back(C);
10211 }
10212
10213 // Check to see if getSCEVAtScope actually made an improvement.
10214 if (!MadeImprovement)
10215 return V; // This is some other type of SCEVUnknown, just return it.
10216
10217 Constant *C = nullptr;
10218 const DataLayout &DL = getDataLayout();
10219 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10220 /*AllowNonDeterministic=*/false);
10221 if (!C)
10222 return V;
10223 return getSCEV(C);
10224 }
10225 case scCouldNotCompute:
10226 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10227 }
10228 llvm_unreachable("Unknown SCEV type!");
10229}
10230
10231 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10232 return getSCEVAtScope(getSCEV(V), L);
10233}
10234
10235const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10236 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10237 return stripInjectiveFunctions(ZExt->getOperand());
10238 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10239 return stripInjectiveFunctions(SExt->getOperand());
10240 return S;
10241}
10242
10243/// Finds the minimum unsigned root of the following equation:
10244///
10245/// A * X = B (mod N)
10246///
10247/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10248/// A and B isn't important.
10249///
10250/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
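///
/// For example, solving 4 * X == 8 (mod 16): D = gcd(4, 16) = 4, B = 8 is
/// divisible by D, the inverse of A/D = 1 (mod 4) is 1, and the minimum
/// unsigned root is (1 * 8 mod 16) / 4 = 2.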
10251static const SCEV *
10252 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10253 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10254 ScalarEvolution &SE, const Loop *L) {
10255 uint32_t BW = A.getBitWidth();
10256 assert(BW == SE.getTypeSizeInBits(B->getType()));
10257 assert(A != 0 && "A must be non-zero.");
10258
10259 // 1. D = gcd(A, N)
10260 //
10261 // The gcd of A and N may have only one prime factor: 2. The number of
10262 // trailing zeros in A is its multiplicity
10263 uint32_t Mult2 = A.countr_zero();
10264 // D = 2^Mult2
10265
10266 // 2. Check if B is divisible by D.
10267 //
10268 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10269 // is not less than multiplicity of this prime factor for D.
10270 unsigned MinTZ = SE.getMinTrailingZeros(B);
10271 // Try again with the terminator of the loop predecessor for a context-specific
10272 // result, if MinTZ is too small.
10273 if (MinTZ < Mult2 && L->getLoopPredecessor())
10274 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10275 if (MinTZ < Mult2) {
10276 // Check if we can prove there's no remainder using URem.
10277 const SCEV *URem =
10278 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10279 const SCEV *Zero = SE.getZero(B->getType());
10280 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10281 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10282 if (!Predicates)
10283 return SE.getCouldNotCompute();
10284
10285 // Avoid adding a predicate that is known to be false.
10286 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10287 return SE.getCouldNotCompute();
10288 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10289 }
10290 }
10291
10292 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10293 // modulo (N / D).
10294 //
10295 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10296 // (N / D) in general. The inverse itself always fits into BW bits, though,
10297 // so we immediately truncate it.
10298 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10299 APInt I = AD.multiplicativeInverse().zext(BW);
10300
10301 // 4. Compute the minimum unsigned root of the equation:
10302 // I * (B / D) mod (N / D)
10303 // To simplify the computation, we factor out the divide by D:
10304 // (I * B mod N) / D
10305 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10306 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10307}
10308
10309/// For a given quadratic addrec, generate coefficients of the corresponding
10310/// quadratic equation, multiplied by a common value to ensure that they are
10311/// integers.
10312/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10313/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10314/// were multiplied by, and BitWidth is the bit width of the original addrec
10315/// coefficients.
10316/// This function returns std::nullopt if the addrec coefficients are not
10317 /// compile-time constants.
10318static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10319 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10320 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10321 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10322 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10323 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10324 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10325 << *AddRec << '\n');
10326
10327 // We currently can only solve this if the coefficients are constants.
10328 if (!LC || !MC || !NC) {
10329 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10330 return std::nullopt;
10331 }
10332
10333 APInt L = LC->getAPInt();
10334 APInt M = MC->getAPInt();
10335 APInt N = NC->getAPInt();
10336 assert(!N.isZero() && "This is not a quadratic addrec");
10337
10338 unsigned BitWidth = LC->getAPInt().getBitWidth();
10339 unsigned NewWidth = BitWidth + 1;
10340 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10341 << BitWidth << '\n');
10342 // The sign-extension (as opposed to a zero-extension) here matches the
10343 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10344 N = N.sext(NewWidth);
10345 M = M.sext(NewWidth);
10346 L = L.sext(NewWidth);
10347
10348 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10349 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10350 // L+M, L+2M+N, L+3M+3N, ...
10351 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10352 //
10353 // The equation Acc = 0 is then
10354 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10355 // In a quadratic form it becomes:
10356 // N n^2 + (2M-N) n + 2L = 0.
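//
// For instance, {2,+,3,+,4} (L=2, M=3, N=4) gives 4n^2 + 2n + 4 = 0, i.e. the
// accumulated value 2 + 3n + 2n(n-1) scaled by the common factor 2.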
10357
10358 APInt A = N;
10359 APInt B = 2 * M - A;
10360 APInt C = 2 * L;
10361 APInt T = APInt(NewWidth, 2);
10362 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10363 << "x + " << C << ", coeff bw: " << NewWidth
10364 << ", multiplied by " << T << '\n');
10365 return std::make_tuple(A, B, C, T, BitWidth);
10366}
10367
10368/// Helper function to compare optional APInts:
10369/// (a) if X and Y both exist, return min(X, Y),
10370/// (b) if neither X nor Y exist, return std::nullopt,
10371/// (c) if exactly one of X and Y exists, return that value.
10372static std::optional<APInt> MinOptional(std::optional<APInt> X,
10373 std::optional<APInt> Y) {
10374 if (X && Y) {
10375 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10376 APInt XW = X->sext(W);
10377 APInt YW = Y->sext(W);
10378 return XW.slt(YW) ? *X : *Y;
10379 }
10380 if (!X && !Y)
10381 return std::nullopt;
10382 return X ? *X : *Y;
10383}
10384
10385/// Helper function to truncate an optional APInt to a given BitWidth.
10386/// When solving addrec-related equations, it is preferable to return a value
10387/// that has the same bit width as the original addrec's coefficients. If the
10388/// solution fits in the original bit width, truncate it (except for i1).
10389/// Returning a value of a different bit width may inhibit some optimizations.
10390///
10391/// In general, a solution to a quadratic equation generated from an addrec
10392/// may require BW+1 bits, where BW is the bit width of the addrec's
10393/// coefficients. The reason is that the coefficients of the quadratic
10394/// equation are BW+1 bits wide (to avoid truncation when converting from
10395/// the addrec to the equation).
10396static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10397 unsigned BitWidth) {
10398 if (!X)
10399 return std::nullopt;
10400 unsigned W = X->getBitWidth();
10401 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10402 return X->trunc(BitWidth);
10403 return X;
10404}
10405
10406/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10407/// iterations. The values L, M, N are assumed to be signed, and they
10408/// should all have the same bit widths.
10409/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10410/// where BW is the bit width of the addrec's coefficients.
10411/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10412/// returned as such, otherwise the bit width of the returned value may
10413/// be greater than BW.
10414///
10415/// This function returns std::nullopt if
10416/// (a) the addrec coefficients are not constant, or
10417/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10418/// like x^2 = 5, no integer solutions exist, in other cases an integer
10419/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10420static std::optional<APInt>
10421 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10422 APInt A, B, C, M;
10423 unsigned BitWidth;
10424 auto T = GetQuadraticEquation(AddRec);
10425 if (!T)
10426 return std::nullopt;
10427
10428 std::tie(A, B, C, M, BitWidth) = *T;
10429 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10430 std::optional<APInt> X =
10431 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10432 if (!X)
10433 return std::nullopt;
10434
10435 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10436 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10437 if (!V->isZero())
10438 return std::nullopt;
10439
10440 return TruncIfPossible(X, BitWidth);
10441}
10442
10443/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10444/// iterations. The values M, N are assumed to be signed, and they
10445/// should all have the same bit widths.
10446/// Find the least n such that c(n) does not belong to the given range,
10447/// while c(n-1) does.
10448///
10449/// This function returns std::nullopt if
10450/// (a) the addrec coefficients are not constant, or
10451/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10452/// bounds of the range.
10453static std::optional<APInt>
10454 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10455 const ConstantRange &Range, ScalarEvolution &SE) {
10456 assert(AddRec->getOperand(0)->isZero() &&
10457 "Starting value of addrec should be 0");
10458 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10459 << Range << ", addrec " << *AddRec << '\n');
10460 // This case is handled in getNumIterationsInRange. Here we can assume that
10461 // we start in the range.
10462 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10463 "Addrec's initial value should be in range");
10464
10465 APInt A, B, C, M;
10466 unsigned BitWidth;
10467 auto T = GetQuadraticEquation(AddRec);
10468 if (!T)
10469 return std::nullopt;
10470
10471 // Be careful about the return value: there can be two reasons for not
10472 // returning an actual number. First, if no solutions to the equations
10473 // were found, and second, if the solutions don't leave the given range.
10474 // The first case means that the actual solution is "unknown", the second
10475 // means that it's known, but not valid. If the solution is unknown, we
10476 // cannot make any conclusions.
10477 // Return a pair: the optional solution and a flag indicating if the
10478 // solution was found.
10479 auto SolveForBoundary =
10480 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10481 // Solve for signed overflow and unsigned overflow, pick the lower
10482 // solution.
10483 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10484 << Bound << " (before multiplying by " << M << ")\n");
10485 Bound *= M; // The quadratic equation multiplier.
10486
10487 std::optional<APInt> SO;
10488 if (BitWidth > 1) {
10489 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10490 "signed overflow\n");
10491 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10492 }
10493 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10494 "unsigned overflow\n");
10495 std::optional<APInt> UO =
10496 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10497
10498 auto LeavesRange = [&] (const APInt &X) {
10499 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10500 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10501 if (Range.contains(V0->getValue()))
10502 return false;
10503 // X should be at least 1, so X-1 is non-negative.
10504 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10505 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10506 if (Range.contains(V1->getValue()))
10507 return true;
10508 return false;
10509 };
10510
10511 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10512 // can be a solution, but the function failed to find it. We cannot treat it
10513 // as "no solution".
10514 if (!SO || !UO)
10515 return {std::nullopt, false};
10516
10517 // Check the smaller value first to see if it leaves the range.
10518 // At this point, both SO and UO must have values.
10519 std::optional<APInt> Min = MinOptional(SO, UO);
10520 if (LeavesRange(*Min))
10521 return { Min, true };
10522 std::optional<APInt> Max = Min == SO ? UO : SO;
10523 if (LeavesRange(*Max))
10524 return { Max, true };
10525
10526 // Solutions were found, but were eliminated, hence the "true".
10527 return {std::nullopt, true};
10528 };
10529
10530 std::tie(A, B, C, M, BitWidth) = *T;
10531 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10532 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10533 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10534 auto SL = SolveForBoundary(Lower);
10535 auto SU = SolveForBoundary(Upper);
10536 // If any of the solutions was unknown, no meaningful conclusions can
10537 // be made.
10538 if (!SL.second || !SU.second)
10539 return std::nullopt;
10540
10541 // Claim: The correct solution is not some value between Min and Max.
10542 //
10543 // Justification: Assuming that Min and Max are different values, one of
10544 // them is when the first signed overflow happens, the other is when the
10545 // first unsigned overflow happens. Crossing the range boundary is only
10546 // possible via an overflow (treating 0 as a special case of it, modeling
10547 // an overflow as crossing k*2^W for some k).
10548 //
10549 // The interesting case here is when Min was eliminated as an invalid
10550 // solution, but Max was not. The argument is that if there was another
10551 // overflow between Min and Max, it would also have been eliminated if
10552 // it was considered.
10553 //
10554 // For a given boundary, it is possible to have two overflows of the same
10555 // type (signed/unsigned) without having the other type in between: this
10556 // can happen when the vertex of the parabola is between the iterations
10557 // corresponding to the overflows. This is only possible when the two
10558 // overflows cross k*2^W for the same k. In such case, if the second one
10559 // left the range (and was the first one to do so), the first overflow
10560 // would have to enter the range, which would mean that either we had left
10561 // the range before or that we started outside of it. Both of these cases
10562 // are contradictions.
10563 //
10564 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10565 // solution is not some value between the Max for this boundary and the
10566 // Min of the other boundary.
10567 //
10568 // Justification: Assume that we had such Max_A and Min_B corresponding
10569 // to range boundaries A and B and such that Max_A < Min_B. If there was
10570 // a solution between Max_A and Min_B, it would have to be caused by an
10571 // overflow corresponding to either A or B. It cannot correspond to B,
10572 // since Min_B is the first occurrence of such an overflow. If it
10573 // corresponded to A, it would have to be either a signed or an unsigned
10574 // overflow that is larger than both eliminated overflows for A. But
10575 // between the eliminated overflows and this overflow, the values would
10576 // cover the entire value space, thus crossing the other boundary, which
10577 // is a contradiction.
10578
10579 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10580}
10581
10582ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10583 const Loop *L,
10584 bool ControlsOnlyExit,
10585 bool AllowPredicates) {
10586
10587 // This is only used for loops with a "x != y" exit test. The exit condition
10588 // is now expressed as a single expression, V = x-y. So the exit test is
10589 // effectively V != 0. We know and take advantage of the fact that this
10590 // expression is only used in a comparison-by-zero context.
10591
10592 SmallVector<const SCEVPredicate *, 4> Predicates;
10593 // If the value is a constant
10594 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10595 // If the value is already zero, the branch will execute zero times.
10596 if (C->getValue()->isZero()) return C;
10597 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10598 }
10599
10600 const SCEVAddRecExpr *AddRec =
10601 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10602
10603 if (!AddRec && AllowPredicates)
10604 // Try to make this an AddRec using runtime tests, in the first X
10605 // iterations of this loop, where X is the SCEV expression found by the
10606 // algorithm below.
10607 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10608
10609 if (!AddRec || AddRec->getLoop() != L)
10610 return getCouldNotCompute();
10611
10612 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10613 // the quadratic equation to solve it.
10614 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10615 // We can only use this value if the chrec ends up with an exact zero
10616 // value at this index. When solving for "X*X != 5", for example, we
10617 // should not accept a root of 2.
10618 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10619 const auto *R = cast<SCEVConstant>(getConstant(*S));
10620 return ExitLimit(R, R, R, false, Predicates);
10621 }
10622 return getCouldNotCompute();
10623 }
10624
10625 // Otherwise we can only handle this if it is affine.
10626 if (!AddRec->isAffine())
10627 return getCouldNotCompute();
10628
10629 // If this is an affine expression, the execution count of this branch is
10630 // the minimum unsigned root of the following equation:
10631 //
10632 // Start + Step*N = 0 (mod 2^BW)
10633 //
10634 // equivalent to:
10635 //
10636 // Step*N = -Start (mod 2^BW)
10637 //
10638 // where BW is the common bit width of Start and Step.
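//
// For example, with i8 values Start = 6 and Step = -2, solving
// 6 + (-2)*N == 0 (mod 256) gives a backedge-taken count of N = 3.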
10639
10640 // Get the initial value for the loop.
10641 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10642 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10643
10644 if (!isLoopInvariant(Step, L))
10645 return getCouldNotCompute();
10646
10647 LoopGuards Guards = LoopGuards::collect(L, *this);
10648 // Specialize step for this loop so we get context sensitive facts below.
10649 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10650
10651 // For positive steps (counting up until unsigned overflow):
10652 // N = -Start/Step (as unsigned)
10653 // For negative steps (counting down to zero):
10654 // N = Start/-Step
10655 // First compute the unsigned distance from zero in the direction of Step.
10656 bool CountDown = isKnownNegative(StepWLG);
10657 if (!CountDown && !isKnownNonNegative(StepWLG))
10658 return getCouldNotCompute();
10659
10660 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10661 // Handle unitary steps, which cannot wraparound.
10662 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10663 // N = Distance (as unsigned)
10664
10665 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10666 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10667 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10668
10669 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10670 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10671 // case, and see if we can improve the bound.
10672 //
10673 // Explicitly handling this here is necessary because getUnsignedRange
10674 // isn't context-sensitive; it doesn't know that we only care about the
10675 // range inside the loop.
10676 const SCEV *Zero = getZero(Distance->getType());
10677 const SCEV *One = getOne(Distance->getType());
10678 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10679 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10680 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10681 // as "unsigned_max(Distance + 1) - 1".
10682 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10683 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10684 }
10685 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10686 Predicates);
10687 }
10688
10689 // If the condition controls loop exit (the loop exits only if the expression
10690 // is true) and the addition is no-wrap we can use unsigned divide to
10691 // compute the backedge count. In this case, the step may not divide the
10692 // distance, but we don't care because if the condition is "missed" the loop
10693 // will have undefined behavior due to wrapping.
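// For example (sketch), with Distance = 10 and Step = 3 this yields
// 10 /u 3 = 3 even though 3 does not divide 10; overshooting the exit value
// would force the no-self-wrap addrec to wrap, which is undefined here.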
10694 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10695 loopHasNoAbnormalExits(AddRec->getLoop())) {
10696
10697 // If the stride is zero and the start is non-zero, the loop must be
10698 // infinite. In C++, most loops are finite by assumption, in which case the
10699 // step being zero implies UB must execute if the loop is entered.
10700 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10701 !isKnownNonZero(StepWLG))
10702 return getCouldNotCompute();
10703
10704 const SCEV *Exact =
10705 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10706 const SCEV *ConstantMax = getCouldNotCompute();
10707 if (Exact != getCouldNotCompute()) {
10708 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10709 ConstantMax =
10710 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10711 }
10712 const SCEV *SymbolicMax =
10713 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10714 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10715 }
10716
10717 // Solve the general equation.
10718 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10719 if (!StepC || StepC->getValue()->isZero())
10720 return getCouldNotCompute();
10721 const SCEV *E = SolveLinEquationWithOverflow(
10722 StepC->getAPInt(), getNegativeSCEV(Start),
10723 AllowPredicates ? &Predicates : nullptr, *this, L);
10724
10725 const SCEV *M = E;
10726 if (E != getCouldNotCompute()) {
10727 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10728 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10729 }
10730 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10731 return ExitLimit(E, M, S, false, Predicates);
10732}
10733
10734ScalarEvolution::ExitLimit
10735ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10736 // Loops that look like: while (X == 0) are very strange indeed. We don't
10737 // handle them yet except for the trivial case. This could be expanded in the
10738 // future as needed.
10739
10740 // If the value is a constant, check to see if it is known to be non-zero
10741 // already. If so, the backedge will execute zero times.
10742 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10743 if (!C->getValue()->isZero())
10744 return getZero(C->getType());
10745 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10746 }
10747
10748 // We could implement others, but I really doubt anyone writes loops like
10749 // this, and if they did, they would already be constant folded.
10750 return getCouldNotCompute();
10751}
10752
10753std::pair<const BasicBlock *, const BasicBlock *>
10754ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10755 const {
10756 // If the block has a unique predecessor, then there is no path from the
10757 // predecessor to the block that does not go through the direct edge
10758 // from the predecessor to the block.
10759 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10760 return {Pred, BB};
10761
10762 // A loop's header is defined to be a block that dominates the loop.
10763 // If the header has a unique predecessor outside the loop, it must be
10764 // a block that has exactly one successor that can reach the loop.
10765 if (const Loop *L = LI.getLoopFor(BB))
10766 return {L->getLoopPredecessor(), L->getHeader()};
10767
10768 return {nullptr, BB};
10769}
10770
10771/// SCEV structural equivalence is usually sufficient for testing whether two
10772/// expressions are equal, however for the purposes of looking for a condition
10773/// guarding a loop, it can be useful to be a little more general, since a
10774/// front-end may have replicated the controlling expression.
10775static bool HasSameValue(const SCEV *A, const SCEV *B) {
10776 // Quick check to see if they are the same SCEV.
10777 if (A == B) return true;
10778
10779 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10780 // Not all instructions that are "identical" compute the same value. For
10781 // instance, two distinct alloca instructions allocating the same type are
10782 // identical and do not read memory, but compute distinct values.
10783 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10784 };
10785
10786 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10787 // two different instructions with the same value. Check for this case.
10788 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10789 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10790 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10791 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10792 if (ComputesEqualValues(AI, BI))
10793 return true;
10794
10795 // Otherwise assume they may have a different value.
10796 return false;
10797}
10798
10799static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10800 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10801 if (!Add || Add->getNumOperands() != 2)
10802 return false;
10803 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10804 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10805 LHS = Add->getOperand(1);
10806 RHS = ME->getOperand(1);
10807 return true;
10808 }
10809 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10810 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10811 LHS = Add->getOperand(0);
10812 RHS = ME->getOperand(1);
10813 return true;
10814 }
10815 return false;
10816}
10817
10818 bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10819 const SCEV *&RHS, unsigned Depth) {
10820 bool Changed = false;
10821 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10822 // '0 != 0'.
10823 auto TrivialCase = [&](bool TriviallyTrue) {
10824 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10825 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10826 return true;
10827 };
10828 // If we hit the max recursion limit bail out.
10829 if (Depth >= 3)
10830 return false;
10831
10832 const SCEV *NewLHS, *NewRHS;
10833 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10834 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10835 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10836 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10837
10838 // (X * vscale) pred (Y * vscale) ==> X pred Y
10839 // when both multiples are NSW.
10840 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10841 // when both multiples are NUW.
10842 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10843 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10844 !ICmpInst::isSigned(Pred))) {
10845 LHS = NewLHS;
10846 RHS = NewRHS;
10847 Changed = true;
10848 }
10849 }
10850
10851 // Canonicalize a constant to the right side.
10852 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10853 // Check for both operands constant.
10854 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10855 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10856 return TrivialCase(false);
10857 return TrivialCase(true);
10858 }
10859 // Otherwise swap the operands to put the constant on the right.
10860 std::swap(LHS, RHS);
10861 Pred = ICmpInst::getSwappedPredicate(Pred);
10862 Changed = true;
10863 }
10864
10865 // If we're comparing an addrec with a value which is loop-invariant in the
10866 // addrec's loop, put the addrec on the left. Also make a dominance check,
10867 // as both operands could be addrecs loop-invariant in each other's loop.
10868 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10869 const Loop *L = AR->getLoop();
10870 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10871 std::swap(LHS, RHS);
10872 Pred = ICmpInst::getSwappedPredicate(Pred);
10873 Changed = true;
10874 }
10875 }
10876
10877 // If there's a constant operand, canonicalize comparisons with boundary
10878 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10879 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10880 const APInt &RA = RC->getAPInt();
10881
10882 bool SimplifiedByConstantRange = false;
10883
10884 if (!ICmpInst::isEquality(Pred)) {
10885 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10886 if (ExactCR.isFullSet())
10887 return TrivialCase(true);
10888 if (ExactCR.isEmptySet())
10889 return TrivialCase(false);
10890
10891 APInt NewRHS;
10892 CmpInst::Predicate NewPred;
10893 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10894 ICmpInst::isEquality(NewPred)) {
10895 // We were able to convert an inequality to an equality.
10896 Pred = NewPred;
10897 RHS = getConstant(NewRHS);
10898 Changed = SimplifiedByConstantRange = true;
10899 }
10900 }
10901
10902 if (!SimplifiedByConstantRange) {
10903 switch (Pred) {
10904 default:
10905 break;
10906 case ICmpInst::ICMP_EQ:
10907 case ICmpInst::ICMP_NE:
10908 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10909 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10910 Changed = true;
10911 break;
10912
10913 // The "Should have been caught earlier!" messages refer to the fact
10914 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10915 // should have fired on the corresponding cases, and canonicalized the
10916 // check to a trivial case.
10917
10918 case ICmpInst::ICMP_UGE:
10919 assert(!RA.isMinValue() && "Should have been caught earlier!");
10920 Pred = ICmpInst::ICMP_UGT;
10921 RHS = getConstant(RA - 1);
10922 Changed = true;
10923 break;
10924 case ICmpInst::ICMP_ULE:
10925 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10926 Pred = ICmpInst::ICMP_ULT;
10927 RHS = getConstant(RA + 1);
10928 Changed = true;
10929 break;
10930 case ICmpInst::ICMP_SGE:
10931 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10932 Pred = ICmpInst::ICMP_SGT;
10933 RHS = getConstant(RA - 1);
10934 Changed = true;
10935 break;
10936 case ICmpInst::ICMP_SLE:
10937 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10938 Pred = ICmpInst::ICMP_SLT;
10939 RHS = getConstant(RA + 1);
10940 Changed = true;
10941 break;
10942 }
10943 }
10944 }
10945
10946 // Check for obvious equality.
10947 if (HasSameValue(LHS, RHS)) {
10948 if (ICmpInst::isTrueWhenEqual(Pred))
10949 return TrivialCase(true);
10950 if (ICmpInst::isFalseWhenEqual(Pred))
10951 return TrivialCase(false);
10952 }
10953
10954 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10955 // adding or subtracting 1 from one of the operands.
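// For example, "X <=s 7" becomes "X <s 8" whenever the incremented constant is
// still representable in the type.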
10956 switch (Pred) {
10957 case ICmpInst::ICMP_SLE:
10958 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10959 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10960 SCEV::FlagNSW);
10961 Pred = ICmpInst::ICMP_SLT;
10962 Changed = true;
10963 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10964 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10965 SCEV::FlagNSW);
10966 Pred = ICmpInst::ICMP_SLT;
10967 Changed = true;
10968 }
10969 break;
10970 case ICmpInst::ICMP_SGE:
10971 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10972 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10973 SCEV::FlagNSW);
10974 Pred = ICmpInst::ICMP_SGT;
10975 Changed = true;
10976 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10977 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10978 SCEV::FlagNSW);
10979 Pred = ICmpInst::ICMP_SGT;
10980 Changed = true;
10981 }
10982 break;
10983 case ICmpInst::ICMP_ULE:
10984 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10985 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10986 SCEV::FlagNUW);
10987 Pred = ICmpInst::ICMP_ULT;
10988 Changed = true;
10989 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10990 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10991 Pred = ICmpInst::ICMP_ULT;
10992 Changed = true;
10993 }
10994 break;
10995 case ICmpInst::ICMP_UGE:
10996 // If RHS is an op we can fold the -1, try that first.
10997 // Otherwise prefer LHS to preserve the nuw flag.
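// For example, when RHS is (5 + %x)<nuw>, "A u>= (5 + %x)" can become
// "A u> (4 + %x)", folding the -1 into RHS's constant operand instead of
// disturbing A.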
10998 if ((isa<SCEVConstant>(RHS) ||
10999 ((isa<SCEVAddExpr>(RHS) || isa<SCEVAddRecExpr>(RHS)) &&
11000 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11001 !getUnsignedRangeMin(RHS).isMinValue()) {
11002 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11003 Pred = ICmpInst::ICMP_UGT;
11004 Changed = true;
11005 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11006 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11007 SCEV::FlagNUW);
11008 Pred = ICmpInst::ICMP_UGT;
11009 Changed = true;
11010 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11011 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11012 Pred = ICmpInst::ICMP_UGT;
11013 Changed = true;
11014 }
11015 break;
11016 default:
11017 break;
11018 }
11019
11020 // TODO: More simplifications are possible here.
11021
11022 // Recursively simplify until we either hit a recursion limit or nothing
11023 // changes.
11024 if (Changed)
11025 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11026
11027 return Changed;
11028}
11029
11030bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11031 return getSignedRangeMax(S).isNegative();
11032}
11033
11034bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11035 return getSignedRangeMin(S).isStrictlyPositive();
11036}
11037
11038bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11039 return !getSignedRangeMin(S).isNegative();
11040}
11041
11042bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11043 return !getSignedRangeMax(S).isStrictlyPositive();
11044}
11045
11046bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11047 // Query push down for cases where the unsigned range is
11048 // less than sufficient.
11049 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11050 return isKnownNonZero(SExt->getOperand(0));
11051 return getUnsignedRangeMin(S) != 0;
11052}
11053
11054bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11055 bool OrNegative) {
11056 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11057 if (auto *C = dyn_cast<SCEVConstant>(S))
11058 return C->getAPInt().isPowerOf2() ||
11059 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11060
11061 // The vscale_range indicates vscale is a power-of-two.
11062 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11063 };
11064
11065 if (NonRecursive(S))
11066 return true;
11067
11068 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11069 if (!Mul)
11070 return false;
11071 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11072}
11073
11074bool ScalarEvolution::isKnownMultipleOf(
11075 const SCEV *S, uint64_t M,
11076 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11077 if (M == 0)
11078 return false;
11079 if (M == 1)
11080 return true;
11081
11082 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11083 // starts with a multiple of M and at every iteration step S only adds
11084 // multiples of M.
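// For example, {8,+,4} is a multiple of 4: its start (8) and step (4) are
// both multiples of 4, so every value it takes (8, 12, 16, ...) is as well.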
11085 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11086 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11087 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11088
11089 // For a constant, check that "S % M == 0".
11090 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11091 APInt C = Cst->getAPInt();
11092 return C.urem(M) == 0;
11093 }
11094
11095 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11096
11097 // Basic tests have failed.
11098 // Check "S % M == 0" at compile time and record runtime Assumptions.
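// For example, for S = %n and M = 4, if neither "%n urem 4 == 0" nor its
// negation is known, the predicate "(%n urem 4) == 0" is recorded in
// Assumptions below (unless an existing assumption already implies it).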
11099 auto *STy = dyn_cast<IntegerType>(S->getType());
11100 const SCEV *SmodM =
11101 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11102 const SCEV *Zero = getZero(STy);
11103
11104 // Check whether "S % M == 0" is known at compile time.
11105 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11106 return true;
11107
11108 // Check whether "S % M != 0" is known at compile time.
11109 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11110 return false;
11111
11111
11112 const SCEVPredicate *P = getComparePredicate(ICmpInst::ICMP_EQ, SmodM, Zero);
11113
11114 // Detect redundant predicates.
11115 for (auto *A : Assumptions)
11116 if (A->implies(P, *this))
11117 return true;
11118
11119 // Only record non-redundant predicates.
11120 Assumptions.push_back(P);
11121 return true;
11122}
11123
11124std::pair<const SCEV *, const SCEV *>
11125ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11126 // Compute SCEV on entry of loop L.
11127 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11128 if (Start == getCouldNotCompute())
11129 return { Start, Start };
11130 // Compute post increment SCEV for loop L.
11131 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
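// For example, for S = {0,+,1}<L>, Start is 0 and PostInc is {1,+,1}<L>:
// the value on loop entry and the value after each increment, respectively.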
11132 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11133 return { Start, PostInc };
11134}
11135
11136bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11137 const SCEV *RHS) {
11138 // First collect all loops.
11139 SmallPtrSet<const Loop *, 8> LoopsUsed;
11140 getUsedLoops(LHS, LoopsUsed);
11141 getUsedLoops(RHS, LoopsUsed);
11142
11143 if (LoopsUsed.empty())
11144 return false;
11145
11146 // Domination relationship must be a linear order on collected loops.
11147#ifndef NDEBUG
11148 for (const auto *L1 : LoopsUsed)
11149 for (const auto *L2 : LoopsUsed)
11150 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11151 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11152 "Domination relationship is not a linear order");
11153#endif
11154
11155 const Loop *MDL =
11156 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11157 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11158 });
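// MDL is the used loop whose header is dominated by the headers of all the
// other used loops, i.e. the innermost loop in the domination chain checked
// above.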
11159
11160 // Get init and post increment value for LHS.
11161 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11162 // If LHS contains an unknown non-invariant SCEV then bail out.
11163 if (SplitLHS.first == getCouldNotCompute())
11164 return false;
11165 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11166 // Get init and post increment value for RHS.
11167 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11168 // If RHS contains an unknown non-invariant SCEV then bail out.
11169 if (SplitRHS.first == getCouldNotCompute())
11170 return false;
11171 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11172 // It is possible that init SCEV contains an invariant load but it does
11173 // not dominate MDL and is not available at MDL loop entry, so we should
11174 // check it here.
11175 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11176 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11177 return false;
11178
11179 // The backedge guard check seems to be faster than the entry one, so in some
11180 // cases it can short-circuit and speed up the whole estimation.
11181 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11182 SplitRHS.second) &&
11183 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11184}
11185
11186bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11187 const SCEV *RHS) {
11188 // Canonicalize the inputs first.
11189 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11190
11191 if (isKnownViaInduction(Pred, LHS, RHS))
11192 return true;
11193
11194 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11195 return true;
11196
11197 // Otherwise see what can be done with some simple reasoning.
11198 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11199}
11200
11201std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11202 const SCEV *LHS,
11203 const SCEV *RHS) {
11204 if (isKnownPredicate(Pred, LHS, RHS))
11205 return true;
11206 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11207 return false;
11208 return std::nullopt;
11209}
11210
11211bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11212 const SCEV *RHS,
11213 const Instruction *CtxI) {
11214 // TODO: Analyze guards and assumes from Context's block.
11215 return isKnownPredicate(Pred, LHS, RHS) ||
11216 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11217}
11218
11219std::optional<bool>
11220ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11221 const SCEV *RHS, const Instruction *CtxI) {
11222 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11223 if (KnownWithoutContext)
11224 return KnownWithoutContext;
11225
11226 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11227 return true;
11228 if (isBasicBlockEntryGuardedByCond(
11229 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11230 return false;
11231 return std::nullopt;
11232}
11233
11234bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11235 const SCEVAddRecExpr *LHS,
11236 const SCEV *RHS) {
11237 const Loop *L = LHS->getLoop();
11238 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11239 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11240}
11241
11242std::optional<ScalarEvolution::MonotonicPredicateType>
11243ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11244 ICmpInst::Predicate Pred) {
11245 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11246
11247#ifndef NDEBUG
11248 // Verify an invariant: inverting the predicate should turn a monotonically
11249 // increasing change to a monotonically decreasing one, and vice versa.
11250 if (Result) {
11251 auto ResultSwapped =
11252 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11253
11254 assert(*ResultSwapped != *Result &&
11255 "monotonicity should flip as we flip the predicate");
11256 }
11257#endif
11258
11259 return Result;
11260}
11261
11262std::optional<ScalarEvolution::MonotonicPredicateType>
11263ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11264 ICmpInst::Predicate Pred) {
11265 // A zero step value for LHS means the induction variable is essentially a
11266 // loop invariant value. We don't really depend on the predicate actually
11267 // flipping from false to true (for increasing predicates, and the other way
11268 // around for decreasing predicates), all we care about is that *if* the
11269 // predicate changes then it only changes from false to true.
11270 //
11271 // A zero step value in itself is not very useful, but there may be places
11272 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11273 // as general as possible.
11274
11275 // Only handle LE/LT/GE/GT predicates.
11276 if (!ICmpInst::isRelational(Pred))
11277 return std::nullopt;
11278
11279 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11280 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11281 "Should be greater or less!");
11282
11283 // Check that AR does not wrap.
11284 if (ICmpInst::isUnsigned(Pred)) {
11285 if (!LHS->hasNoUnsignedWrap())
11286 return std::nullopt;
11287 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11288 }
11289 assert(ICmpInst::isSigned(Pred) &&
11290 "Relational predicate is either signed or unsigned!");
11291 if (!LHS->hasNoSignedWrap())
11292 return std::nullopt;
11293
11294 const SCEV *Step = LHS->getStepRecurrence(*this);
11295
11296 if (isKnownNonNegative(Step))
11297 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11298
11299 if (isKnownNonPositive(Step))
11300 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11301
11302 return std::nullopt;
11303}
11304
11305std::optional<ScalarEvolution::LoopInvariantPredicate>
11306ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11307 const SCEV *RHS, const Loop *L,
11308 const Instruction *CtxI) {
11309 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11310 if (!isLoopInvariant(RHS, L)) {
11311 if (!isLoopInvariant(LHS, L))
11312 return std::nullopt;
11313
11314 std::swap(LHS, RHS);
11315 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11316 }
11317
11318 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11319 if (!ArLHS || ArLHS->getLoop() != L)
11320 return std::nullopt;
11321
11322 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11323 if (!MonotonicType)
11324 return std::nullopt;
11325 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11326 // true as the loop iterates, and the backedge is control dependent on
11327 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11328 //
11329 // * if the predicate was false in the first iteration then the predicate
11330 // is never evaluated again, since the loop exits without taking the
11331 // backedge.
11332 // * if the predicate was true in the first iteration then it will
11333 // continue to be true for all future iterations since it is
11334 // monotonically increasing.
11335 //
11336 // For both the above possibilities, we can replace the loop varying
11337 // predicate with its value on the first iteration of the loop (which is
11338 // loop invariant).
11339 //
11340 // A similar reasoning applies for a monotonically decreasing predicate, by
11341 // replacing true with false and false with true in the above two bullets.
11342 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11343 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11344
11345 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11346 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11347 RHS);
11348
11349 if (!CtxI)
11350 return std::nullopt;
11351 // Try to prove via context.
11352 // TODO: Support other cases.
11353 switch (Pred) {
11354 default:
11355 break;
11356 case ICmpInst::ICMP_ULE:
11357 case ICmpInst::ICMP_ULT: {
11358 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11359 // Given preconditions
11360 // (1) ArLHS does not cross the border of positive and negative parts of
11361 // range because of:
11362 // - Positive step; (TODO: lift this limitation)
11363 // - nuw - does not cross zero boundary;
11364 // - nsw - does not cross SINT_MAX boundary;
11365 // (2) ArLHS <s RHS
11366 // (3) RHS >=s 0
11367 // we can replace the loop variant ArLHS <u RHS condition with loop
11368 // invariant Start(ArLHS) <u RHS.
11369 //
11370 // Because of (1) there are two options:
11371 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11372 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11373 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11374 // Because of (2) ArLHS <u RHS is trivially true.
11375 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11376 // We can strengthen this to Start(ArLHS) <u RHS.
11377 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11378 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11379 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11380 isKnownNonNegative(RHS) &&
11381 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11382 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11383 RHS);
11384 }
11385 }
11386
11387 return std::nullopt;
11388}
11389
11390std::optional<ScalarEvolution::LoopInvariantPredicate>
11391ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11392 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11393 const Instruction *CtxI, const SCEV *MaxIter) {
11394 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11395 Pred, LHS, RHS, L, CtxI, MaxIter))
11396 return LIP;
11397 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11398 // Number of iterations expressed as UMIN isn't always great for expressing
11399 // the value on the last iteration. If the straightforward approach didn't
11400 // work, try the following trick: if a predicate is invariant for X, it
11401 // is also invariant for umin(X, ...). So try to find something that works
11402 // among subexpressions of MaxIter expressed as umin.
11403 for (auto *Op : UMin->operands())
11404 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11405 Pred, LHS, RHS, L, CtxI, Op))
11406 return LIP;
11407 return std::nullopt;
11408}
11409
11410std::optional<ScalarEvolution::LoopInvariantPredicate>
11411ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11412 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11413 const Instruction *CtxI, const SCEV *MaxIter) {
11414 // Try to prove the following set of facts:
11415 // - The predicate is monotonic in the iteration space.
11416 // - If the check does not fail on the 1st iteration:
11417 // - No overflow will happen during first MaxIter iterations;
11418 // - It will not fail on the MaxIter'th iteration.
11419 // If the check does fail on the 1st iteration, we leave the loop and no
11420 // other checks matter.
11421
11422 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11423 if (!isLoopInvariant(RHS, L)) {
11424 if (!isLoopInvariant(LHS, L))
11425 return std::nullopt;
11426
11427 std::swap(LHS, RHS);
11428 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11429 }
11430
11431 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11432 if (!AR || AR->getLoop() != L)
11433 return std::nullopt;
11434
11435 // The predicate must be relational (i.e. <, <=, >=, >).
11436 if (!ICmpInst::isRelational(Pred))
11437 return std::nullopt;
11438
11439 // TODO: Support steps other than +/- 1.
11440 const SCEV *Step = AR->getStepRecurrence(*this);
11441 auto *One = getOne(Step->getType());
11442 auto *MinusOne = getNegativeSCEV(One);
11443 if (Step != One && Step != MinusOne)
11444 return std::nullopt;
11445
11446 // A type mismatch here means that MaxIter is potentially larger than the max
11447 // unsigned value of the start type, which means we cannot prove no-wrap for
11448 // the indvar.
11449 if (AR->getType() != MaxIter->getType())
11450 return std::nullopt;
11451
11452 // Value of IV on suggested last iteration.
11453 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11454 // Does it still meet the requirement?
11455 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11456 return std::nullopt;
11457 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11458 // not exceed max unsigned value of this type), this effectively proves
11459 // that there is no wrap during the iteration. To prove that there is no
11460 // signed/unsigned wrap, we need to check that
11461 // Start <= Last for step = 1 or Start >= Last for step = -1.
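// For example, with step = 1, Last = Start + MaxIter; since MaxIter fits in
// the type, Start <= Last (in the signedness of Pred) shows the increment
// never wrapped. For step = -1 the swapped check Start >= Last is used below.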
11462 ICmpInst::Predicate NoOverflowPred =
11463 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11464 if (Step == MinusOne)
11465 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11466 const SCEV *Start = AR->getStart();
11467 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11468 return std::nullopt;
11469
11470 // Everything is fine.
11471 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11472}
11473
11474bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11475 const SCEV *LHS,
11476 const SCEV *RHS) {
11477 if (HasSameValue(LHS, RHS))
11478 return ICmpInst::isTrueWhenEqual(Pred);
11479
11480 auto CheckRange = [&](bool IsSigned) {
11481 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11482 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11483 return RangeLHS.icmp(Pred, RangeRHS);
11484 };
11485
11486 // The check at the top of the function catches the case where the values are
11487 // known to be equal.
11488 if (Pred == CmpInst::ICMP_EQ)
11489 return false;
11490
11491 if (Pred == CmpInst::ICMP_NE) {
11492 if (CheckRange(true) || CheckRange(false))
11493 return true;
11494 auto *Diff = getMinusSCEV(LHS, RHS);
11495 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11496 }
11497
11498 return CheckRange(CmpInst::isSigned(Pred));
11499}
11500
11501bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11502 const SCEV *LHS,
11503 const SCEV *RHS) {
11504 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11505 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11506 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11507 // OutC1 and OutC2.
11508 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11509 APInt &OutC1, APInt &OutC2,
11510 SCEV::NoWrapFlags ExpectedFlags) {
11511 const SCEV *XNonConstOp, *XConstOp;
11512 const SCEV *YNonConstOp, *YConstOp;
11513 SCEV::NoWrapFlags XFlagsPresent;
11514 SCEV::NoWrapFlags YFlagsPresent;
11515
11516 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11517 XConstOp = getZero(X->getType());
11518 XNonConstOp = X;
11519 XFlagsPresent = ExpectedFlags;
11520 }
11521 if (!isa<SCEVConstant>(XConstOp))
11522 return false;
11523
11524 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11525 YConstOp = getZero(Y->getType());
11526 YNonConstOp = Y;
11527 YFlagsPresent = ExpectedFlags;
11528 }
11529
11530 if (YNonConstOp != XNonConstOp)
11531 return false;
11532
11533 if (!isa<SCEVConstant>(YConstOp))
11534 return false;
11535
11536 // When matching ADDs with NUW flags (and unsigned predicates), only the
11537 // second ADD (with the larger constant) requires NUW.
11538 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11539 return false;
11540 if (ExpectedFlags != SCEV::FlagNUW &&
11541 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11542 return false;
11543 }
11544
11545 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11546 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11547
11548 return true;
11549 };
11550
11551 APInt C1;
11552 APInt C2;
11553
11554 switch (Pred) {
11555 default:
11556 break;
11557
11558 case ICmpInst::ICMP_SGE:
11559 std::swap(LHS, RHS);
11560 [[fallthrough]];
11561 case ICmpInst::ICMP_SLE:
11562 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11563 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11564 return true;
11565
11566 break;
11567
11568 case ICmpInst::ICMP_SGT:
11569 std::swap(LHS, RHS);
11570 [[fallthrough]];
11571 case ICmpInst::ICMP_SLT:
11572 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11573 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11574 return true;
11575
11576 break;
11577
11578 case ICmpInst::ICMP_UGE:
11579 std::swap(LHS, RHS);
11580 [[fallthrough]];
11581 case ICmpInst::ICMP_ULE:
11582 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11583 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11584 return true;
11585
11586 break;
11587
11588 case ICmpInst::ICMP_UGT:
11589 std::swap(LHS, RHS);
11590 [[fallthrough]];
11591 case ICmpInst::ICMP_ULT:
11592 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11593 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11594 return true;
11595 break;
11596 }
11597
11598 return false;
11599}
11600
11601bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11602 const SCEV *LHS,
11603 const SCEV *RHS) {
11604 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11605 return false;
11606
11607 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11608 // on the stack can result in exponential time complexity.
11609 SaveAndRestore Restore(ProvingSplitPredicate, true);
11610
11611 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11612 //
11613 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11614 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11615 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11616 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11617 // use isKnownPredicate later if needed.
11618 return isKnownNonNegative(RHS) &&
11620 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11621}
11622
11623bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11624 const SCEV *LHS, const SCEV *RHS) {
11625 // No need to even try if we know the module has no guards.
11626 if (!HasGuards)
11627 return false;
11628
11629 return any_of(*BB, [&](const Instruction &I) {
11630 using namespace llvm::PatternMatch;
11631
11632 Value *Condition;
11633 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11634 m_Value(Condition))) &&
11635 isImpliedCond(Pred, LHS, RHS, Condition, false);
11636 });
11637}
11638
11639/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11640/// protected by a conditional between LHS and RHS. This is used to
11641/// eliminate casts.
11642bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11643 CmpPredicate Pred,
11644 const SCEV *LHS,
11645 const SCEV *RHS) {
11646 // Interpret a null as meaning no loop, where there is obviously no guard
11647 // (interprocedural conditions notwithstanding). Do not bother about
11648 // unreachable loops.
11649 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11650 return true;
11651
11652 if (VerifyIR)
11653 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11654 "This cannot be done on broken IR!");
11655
11656
11657 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11658 return true;
11659
11660 BasicBlock *Latch = L->getLoopLatch();
11661 if (!Latch)
11662 return false;
11663
11664 BranchInst *LoopContinuePredicate =
11665 dyn_cast<BranchInst>(Latch->getTerminator());
11666 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11667 isImpliedCond(Pred, LHS, RHS,
11668 LoopContinuePredicate->getCondition(),
11669 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11670 return true;
11671
11672 // We don't want more than one activation of the following loops on the stack
11673 // -- that can lead to O(n!) time complexity.
11674 if (WalkingBEDominatingConds)
11675 return false;
11676
11677 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11678
11679 // See if we can exploit a trip count to prove the predicate.
11680 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11681 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11682 if (LatchBECount != getCouldNotCompute()) {
11683 // We know that Latch branches back to the loop header exactly
11684 // LatchBECount times. This means the backedge condition at Latch is
11685 // equivalent to "{0,+,1} u< LatchBECount".
11686 Type *Ty = LatchBECount->getType();
11687 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11688 const SCEV *LoopCounter =
11689 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11690 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11691 LatchBECount))
11692 return true;
11693 }
11694
11695 // Check conditions due to any @llvm.assume intrinsics.
11696 for (auto &AssumeVH : AC.assumptions()) {
11697 if (!AssumeVH)
11698 continue;
11699 auto *CI = cast<CallInst>(AssumeVH);
11700 if (!DT.dominates(CI, Latch->getTerminator()))
11701 continue;
11702
11703 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11704 return true;
11705 }
11706
11707 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11708 return true;
11709
11710 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11711 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11712 assert(DTN && "should reach the loop header before reaching the root!");
11713
11714 BasicBlock *BB = DTN->getBlock();
11715 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11716 return true;
11717
11718 BasicBlock *PBB = BB->getSinglePredecessor();
11719 if (!PBB)
11720 continue;
11721
11722 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11723 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11724 continue;
11725
11726 Value *Condition = ContinuePredicate->getCondition();
11727
11728 // If we have an edge `E` within the loop body that dominates the only
11729 // latch, the condition guarding `E` also guards the backedge. This
11730 // reasoning works only for loops with a single latch.
11731
11732 BasicBlockEdge DominatingEdge(PBB, BB);
11733 if (DominatingEdge.isSingleEdge()) {
11734 // We're constructively (and conservatively) enumerating edges within the
11735 // loop body that dominate the latch. The dominator tree better agree
11736 // with us on this:
11737 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11738
11739 if (isImpliedCond(Pred, LHS, RHS, Condition,
11740 BB != ContinuePredicate->getSuccessor(0)))
11741 return true;
11742 }
11743 }
11744
11745 return false;
11746}
11747
11748bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11749 CmpPredicate Pred,
11750 const SCEV *LHS,
11751 const SCEV *RHS) {
11752 // Do not bother proving facts for unreachable code.
11753 if (!DT.isReachableFromEntry(BB))
11754 return true;
11755 if (VerifyIR)
11756 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11757 "This cannot be done on broken IR!");
11758
11759 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11760 // the facts (a >= b && a != b) separately. A typical situation is when the
11761 // non-strict comparison is known from ranges and non-equality is known from
11762 // dominating predicates. If we are proving strict comparison, we always try
11763 // to prove non-equality and non-strict comparison separately.
11764 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11765 const bool ProvingStrictComparison =
11766 Pred != NonStrictPredicate.dropSameSign();
11767 bool ProvedNonStrictComparison = false;
11768 bool ProvedNonEquality = false;
11769
11770 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11771 if (!ProvedNonStrictComparison)
11772 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11773 if (!ProvedNonEquality)
11774 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11775 if (ProvedNonStrictComparison && ProvedNonEquality)
11776 return true;
11777 return false;
11778 };
11779
11780 if (ProvingStrictComparison) {
11781 auto ProofFn = [&](CmpPredicate P) {
11782 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11783 };
11784 if (SplitAndProve(ProofFn))
11785 return true;
11786 }
11787
11788 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11789 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11790 const Instruction *CtxI = &BB->front();
11791 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11792 return true;
11793 if (ProvingStrictComparison) {
11794 auto ProofFn = [&](CmpPredicate P) {
11795 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11796 };
11797 if (SplitAndProve(ProofFn))
11798 return true;
11799 }
11800 return false;
11801 };
11802
11803 // Starting at the block's predecessor, climb up the predecessor chain, as long
11804 // as there are predecessors that can be found that have unique successors
11805 // leading to the original block.
11806 const Loop *ContainingLoop = LI.getLoopFor(BB);
11807 const BasicBlock *PredBB;
11808 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11809 PredBB = ContainingLoop->getLoopPredecessor();
11810 else
11811 PredBB = BB->getSinglePredecessor();
11812 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11813 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11814 const BranchInst *BlockEntryPredicate =
11815 dyn_cast<BranchInst>(Pair.first->getTerminator());
11816 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11817 continue;
11818
11819 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11820 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11821 return true;
11822 }
11823
11824 // Check conditions due to any @llvm.assume intrinsics.
11825 for (auto &AssumeVH : AC.assumptions()) {
11826 if (!AssumeVH)
11827 continue;
11828 auto *CI = cast<CallInst>(AssumeVH);
11829 if (!DT.dominates(CI, BB))
11830 continue;
11831
11832 if (ProveViaCond(CI->getArgOperand(0), false))
11833 return true;
11834 }
11835
11836 // Check conditions due to any @llvm.experimental.guard intrinsics.
11837 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11838 F.getParent(), Intrinsic::experimental_guard);
11839 if (GuardDecl)
11840 for (const auto *GU : GuardDecl->users())
11841 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11842 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11843 if (ProveViaCond(Guard->getArgOperand(0), false))
11844 return true;
11845 return false;
11846}
11847
11848bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11849 const SCEV *LHS,
11850 const SCEV *RHS) {
11851 // Interpret a null as meaning no loop, where there is obviously no guard
11852 // (interprocedural conditions notwithstanding).
11853 if (!L)
11854 return false;
11855
11856 // Both LHS and RHS must be available at loop entry.
11858 "LHS is not available at Loop Entry");
11860 "RHS is not available at Loop Entry");
11861
11862 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11863 return true;
11864
11865 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11866}
11867
11868bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11869 const SCEV *RHS,
11870 const Value *FoundCondValue, bool Inverse,
11871 const Instruction *CtxI) {
11872 // False conditions implies anything. Do not bother analyzing it further.
11873 if (FoundCondValue ==
11874 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11875 return true;
11876
11877 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11878 return false;
11879
11880 auto ClearOnExit =
11881 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11882
11883 // Recursively handle And and Or conditions.
11884 const Value *Op0, *Op1;
11885 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11886 if (!Inverse)
11887 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11888 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11889 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11890 if (Inverse)
11891 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11892 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11893 }
11894
11895 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11896 if (!ICI) return false;
11897
11898 // We have now found a conditional branch that dominates the loop or controls
11899 // the loop latch. Check whether it is the comparison we are looking for.
11900 CmpPredicate FoundPred;
11901 if (Inverse)
11902 FoundPred = ICI->getInverseCmpPredicate();
11903 else
11904 FoundPred = ICI->getCmpPredicate();
11905
11906 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11907 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11908
11909 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11910}
11911
11912bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11913 const SCEV *RHS, CmpPredicate FoundPred,
11914 const SCEV *FoundLHS, const SCEV *FoundRHS,
11915 const Instruction *CtxI) {
11916 // Balance the types.
11917 if (getTypeSizeInBits(LHS->getType()) <
11918 getTypeSizeInBits(FoundLHS->getType())) {
11919 // For unsigned and equality predicates, try to prove that both found
11920 // operands fit into narrow unsigned range. If so, try to prove facts in
11921 // narrow types.
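// For example, if LHS/RHS are i32 and FoundLHS/FoundRHS are i64 values both
// known to be u<= UINT32_MAX, the found condition can be truncated to i32 and
// the implication checked entirely in the narrow type.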
11922 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11923 !FoundRHS->getType()->isPointerTy()) {
11924 auto *NarrowType = LHS->getType();
11925 auto *WideType = FoundLHS->getType();
11926 auto BitWidth = getTypeSizeInBits(NarrowType);
11927 const SCEV *MaxValue = getZeroExtendExpr(
11928 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11929 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11930 MaxValue) &&
11931 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11932 MaxValue)) {
11933 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11934 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11935 // We cannot preserve samesign after truncation.
11936 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11937 TruncFoundLHS, TruncFoundRHS, CtxI))
11938 return true;
11939 }
11940 }
11941
11942 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11943 return false;
11944 if (CmpInst::isSigned(Pred)) {
11945 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11946 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11947 } else {
11948 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11949 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11950 }
11951 } else if (getTypeSizeInBits(LHS->getType()) >
11952 getTypeSizeInBits(FoundLHS->getType())) {
11953 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11954 return false;
11955 if (CmpInst::isSigned(FoundPred)) {
11956 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11957 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11958 } else {
11959 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11960 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11961 }
11962 }
11963 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11964 FoundRHS, CtxI);
11965}
11966
11967bool ScalarEvolution::isImpliedCondBalancedTypes(
11968 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11969 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11970 assert(getTypeSizeInBits(LHS->getType()) ==
11971 getTypeSizeInBits(FoundLHS->getType()) &&
11972 "Types should be balanced!");
11973 // Canonicalize the query to match the way instcombine will have
11974 // canonicalized the comparison.
11975 if (SimplifyICmpOperands(Pred, LHS, RHS))
11976 if (LHS == RHS)
11977 return CmpInst::isTrueWhenEqual(Pred);
11978 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11979 if (FoundLHS == FoundRHS)
11980 return CmpInst::isFalseWhenEqual(FoundPred);
11981
11982 // Check to see if we can make the LHS or RHS match.
11983 if (LHS == FoundRHS || RHS == FoundLHS) {
11984 if (isa<SCEVConstant>(RHS)) {
11985 std::swap(FoundLHS, FoundRHS);
11986 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11987 } else {
11988 std::swap(LHS, RHS);
11989 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11990 }
11991 }
11992
11993 // Check whether the found predicate is the same as the desired predicate.
11994 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11995 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11996
11997 // Check whether swapping the found predicate makes it the same as the
11998 // desired predicate.
11999 if (auto P = CmpPredicate::getMatching(
12000 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12001 // We can write the implication
12002 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12003 // using one of the following ways:
12004 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12005 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12006 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12007 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12008 // Forms 1. and 2. require swapping the operands of one condition. Don't
12009 // do this if it would break canonical constant/addrec ordering.
12010 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
12011 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12012 LHS, FoundLHS, FoundRHS, CtxI);
12013 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12014 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12015
12016 // There's no clear preference between forms 3. and 4., try both. Avoid
12017 // forming getNotSCEV of pointer values as the resulting subtract is
12018 // not legal.
12019 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12020 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12021 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12022 FoundRHS, CtxI))
12023 return true;
12024
12025 if (!FoundLHS->getType()->isPointerTy() &&
12026 !FoundRHS->getType()->isPointerTy() &&
12027 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12028 getNotSCEV(FoundRHS), CtxI))
12029 return true;
12030
12031 return false;
12032 }
12033
12034 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12035 CmpInst::Predicate P2) {
12036 assert(P1 != P2 && "Handled earlier!");
12037 return CmpInst::isRelational(P2) &&
12038 ICmpInst::getFlippedSignednessPredicate(P1) == P2;
12039 };
12040 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12041 // Unsigned comparison is the same as signed comparison when both the
12042 // operands are non-negative or negative.
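// For example, 3 u< 7 and 3 s< 7 agree, and for two negative i8 values
// -5 u< -2 (i.e. 251 u< 254) agrees with -5 s< -2.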
12043 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
12044 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
12045 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12046 // Create local copies that we can freely swap and canonicalize our
12047 // conditions to "le/lt".
12048 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12049 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12050 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12051 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12052 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12053 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12054 std::swap(CanonicalLHS, CanonicalRHS);
12055 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12056 }
12057 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12058 "Must be!");
12059 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12060 ICmpInst::isLE(CanonicalFoundPred)) &&
12061 "Must be!");
12062 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12063 // Use implication:
12064 // x <u y && y >=s 0 --> x <s y.
12065 // If we can prove the left part, the right part is also proven.
12066 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12067 CanonicalRHS, CanonicalFoundLHS,
12068 CanonicalFoundRHS);
12069 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12070 // Use implication:
12071 // x <s y && y <s 0 --> x <u y.
12072 // If we can prove the left part, the right part is also proven.
12073 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12074 CanonicalRHS, CanonicalFoundLHS,
12075 CanonicalFoundRHS);
12076 }
12077
12078 // Check if we can make progress by sharpening ranges.
12079 if (FoundPred == ICmpInst::ICMP_NE &&
12080 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12081
12082 const SCEVConstant *C = nullptr;
12083 const SCEV *V = nullptr;
12084
12085 if (isa<SCEVConstant>(FoundLHS)) {
12086 C = cast<SCEVConstant>(FoundLHS);
12087 V = FoundRHS;
12088 } else {
12089 C = cast<SCEVConstant>(FoundRHS);
12090 V = FoundLHS;
12091 }
12092
12093 // The guarding predicate tells us that C != V. If the known range
12094 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12095 // range we consider has to correspond to same signedness as the
12096 // predicate we're interested in folding.
12097
12098 APInt Min = ICmpInst::isSigned(Pred) ?
12099 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12100
12101 if (Min == C->getAPInt()) {
12102 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12103 // This is true even if (Min + 1) wraps around -- in case of
12104 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
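// For example, if V's unsigned minimum is 0 and the guard says V != 0, then
// V u>= 1; proving the goal from "V u>= 1" may succeed where "V u>= 0" alone
// is vacuous.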
12105
12106 APInt SharperMin = Min + 1;
12107
12108 switch (Pred) {
12109 case ICmpInst::ICMP_SGE:
12110 case ICmpInst::ICMP_UGE:
12111 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12112 // RHS, we're done.
12113 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12114 CtxI))
12115 return true;
12116 [[fallthrough]];
12117
12118 case ICmpInst::ICMP_SGT:
12119 case ICmpInst::ICMP_UGT:
12120 // We know from the range information that (V `Pred` Min ||
12121 // V == Min). We know from the guarding condition that !(V
12122 // == Min). This gives us
12123 //
12124 // V `Pred` Min || V == Min && !(V == Min)
12125 // => V `Pred` Min
12126 //
12127 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12128
12129 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12130 return true;
12131 break;
12132
12133 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12134 case ICmpInst::ICMP_SLE:
12135 case ICmpInst::ICMP_ULE:
12136 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12137 LHS, V, getConstant(SharperMin), CtxI))
12138 return true;
12139 [[fallthrough]];
12140
12141 case ICmpInst::ICMP_SLT:
12142 case ICmpInst::ICMP_ULT:
12143 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12144 LHS, V, getConstant(Min), CtxI))
12145 return true;
12146 break;
12147
12148 default:
12149 // No change
12150 break;
12151 }
12152 }
12153 }
12154
12155 // Check whether the actual condition is beyond sufficient.
12156 if (FoundPred == ICmpInst::ICMP_EQ)
12157 if (ICmpInst::isTrueWhenEqual(Pred))
12158 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12159 return true;
12160 if (Pred == ICmpInst::ICMP_NE)
12161 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12162 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12163 return true;
12164
12165 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12166 return true;
12167
12168 // Otherwise assume the worst.
12169 return false;
12170}
12171
12172bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12173 const SCEV *&L, const SCEV *&R,
12174 SCEV::NoWrapFlags &Flags) {
12175 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12176 if (!AE || AE->getNumOperands() != 2)
12177 return false;
12178
12179 L = AE->getOperand(0);
12180 R = AE->getOperand(1);
12181 Flags = AE->getNoWrapFlags();
12182 return true;
12183}
12184
12185std::optional<APInt>
12186ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12187 // We avoid subtracting expressions here because this function is usually
12188 // fairly deep in the call stack (i.e. is called many times).
12189
12190 unsigned BW = getTypeSizeInBits(More->getType());
12191 APInt Diff(BW, 0);
12192 APInt DiffMul(BW, 1);
12193 // Try various simplifications to reduce the difference to a constant. Limit
12194 // the number of allowed simplifications to keep compile-time low.
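// For example, computeConstantDifference(%a + 10, %a + 4) decomposes both
// adds below, cancels %a, and returns 10 - 4 = 6.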
12195 for (unsigned I = 0; I < 8; ++I) {
12196 if (More == Less)
12197 return Diff;
12198
12199 // Reduce addrecs with identical steps to their start value.
12200 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12201 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12202 const auto *MAR = cast<SCEVAddRecExpr>(More);
12203
12204 if (LAR->getLoop() != MAR->getLoop())
12205 return std::nullopt;
12206
12207 // We look at affine expressions only; not for correctness but to keep
12208 // getStepRecurrence cheap.
12209 if (!LAR->isAffine() || !MAR->isAffine())
12210 return std::nullopt;
12211
12212 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12213 return std::nullopt;
12214
12215 Less = LAR->getStart();
12216 More = MAR->getStart();
12217 continue;
12218 }
12219
12220 // Try to match a common constant multiply.
12221 auto MatchConstMul =
12222 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12223 auto *M = dyn_cast<SCEVMulExpr>(S);
12224 if (!M || M->getNumOperands() != 2 ||
12225 !isa<SCEVConstant>(M->getOperand(0)))
12226 return std::nullopt;
12227 return {
12228 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12229 };
12230 if (auto MatchedMore = MatchConstMul(More)) {
12231 if (auto MatchedLess = MatchConstMul(Less)) {
12232 if (MatchedMore->second == MatchedLess->second) {
12233 More = MatchedMore->first;
12234 Less = MatchedLess->first;
12235 DiffMul *= MatchedMore->second;
12236 continue;
12237 }
12238 }
12239 }
12240
12241 // Try to cancel out common factors in two add expressions.
12242 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12243 auto Add = [&](const SCEV *S, int Mul) {
12244 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12245 if (Mul == 1) {
12246 Diff += C->getAPInt() * DiffMul;
12247 } else {
12248 assert(Mul == -1);
12249 Diff -= C->getAPInt() * DiffMul;
12250 }
12251 } else
12252 Multiplicity[S] += Mul;
12253 };
12254 auto Decompose = [&](const SCEV *S, int Mul) {
12255 if (isa<SCEVAddExpr>(S)) {
12256 for (const SCEV *Op : S->operands())
12257 Add(Op, Mul);
12258 } else
12259 Add(S, Mul);
12260 };
12261 Decompose(More, 1);
12262 Decompose(Less, -1);
12263
12264 // Check whether all the non-constants cancel out, or reduce to new
12265 // More/Less values.
12266 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12267 for (const auto &[S, Mul] : Multiplicity) {
12268 if (Mul == 0)
12269 continue;
12270 if (Mul == 1) {
12271 if (NewMore)
12272 return std::nullopt;
12273 NewMore = S;
12274 } else if (Mul == -1) {
12275 if (NewLess)
12276 return std::nullopt;
12277 NewLess = S;
12278 } else
12279 return std::nullopt;
12280 }
12281
12282 // Values stayed the same, no point in trying further.
12283 if (NewMore == More || NewLess == Less)
12284 return std::nullopt;
12285
12286 More = NewMore;
12287 Less = NewLess;
12288
12289 // Reduced to constant.
12290 if (!More && !Less)
12291 return Diff;
12292
12293 // Left with variable on only one side, bail out.
12294 if (!More || !Less)
12295 return std::nullopt;
12296 }
12297
12298 // Did not reduce to constant.
12299 return std::nullopt;
12300}
12301
12302bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12303 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12304 const SCEV *FoundRHS, const Instruction *CtxI) {
12305 // Try to recognize the following pattern:
12306 //
12307 // FoundRHS = ...
12308 // ...
12309 // loop:
12310 // FoundLHS = {Start,+,W}
12311 // context_bb: // Basic block from the same loop
12312 // known(Pred, FoundLHS, FoundRHS)
12313 //
12314 // If some predicate is known in the context of a loop, it is also known on
12315 // each iteration of this loop, including the first iteration. Therefore, in
12316 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12317 // prove the original pred using this fact.
12318 if (!CtxI)
12319 return false;
12320 const BasicBlock *ContextBB = CtxI->getParent();
12321 // Make sure AR varies in the context block.
12322 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12323 const Loop *L = AR->getLoop();
12324 // Make sure that context belongs to the loop and executes on 1st iteration
12325 // (if it ever executes at all).
12326 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12327 return false;
12328 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12329 return false;
12330 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12331 }
12332
12333 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12334 const Loop *L = AR->getLoop();
12335 // Make sure that context belongs to the loop and executes on 1st iteration
12336 // (if it ever executes at all).
12337 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12338 return false;
12339 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12340 return false;
12341 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12342 }
12343
12344 return false;
12345}
12346
12347bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12348 const SCEV *LHS,
12349 const SCEV *RHS,
12350 const SCEV *FoundLHS,
12351 const SCEV *FoundRHS) {
12352 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12353 return false;
12354
12355 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12356 if (!AddRecLHS)
12357 return false;
12358
12359 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12360 if (!AddRecFoundLHS)
12361 return false;
12362
12363 // We'd like to let SCEV reason about control dependencies, so we constrain
12364 // both the inequalities to be about add recurrences on the same loop. This
12365 // way we can use isLoopEntryGuardedByCond later.
12366
12367 const Loop *L = AddRecFoundLHS->getLoop();
12368 if (L != AddRecLHS->getLoop())
12369 return false;
12370
12371 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12372 //
12373 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12374 // ... (2)
12375 //
12376 // Informal proof for (2), assuming (1) [*]:
12377 //
12378 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12379 //
12380 // Then
12381 //
12382 // FoundLHS s< FoundRHS s< INT_MIN - C
12383 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12384 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12385 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12386 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12387 // <=> FoundLHS + C s< FoundRHS + C
12388 //
12389 // [*]: (1) can be proved by ruling out overflow.
12390 //
12391 // [**]: This can be proved by analyzing all the four possibilities:
12392 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12393 // (A s>= 0, B s>= 0).
12394 //
12395 // Note:
12396 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12397 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12398 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12399 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12400 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12401 // C)".
12402
12403 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12404 if (!LDiff)
12405 return false;
12406 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12407 if (!RDiff || *LDiff != *RDiff)
12408 return false;
12409
12410 if (LDiff->isMinValue())
12411 return true;
12412
12413 APInt FoundRHSLimit;
12414
12415 if (Pred == CmpInst::ICMP_ULT) {
12416 FoundRHSLimit = -(*RDiff);
12417 } else {
12418 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12419 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12420 }
12421
12422 // Try to prove (1) or (2), as needed.
12423 return isAvailableAtLoopEntry(FoundRHS, L) &&
12424 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12425 getConstant(FoundRHSLimit));
12426}
12427
12428bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12429 const SCEV *RHS, const SCEV *FoundLHS,
12430 const SCEV *FoundRHS, unsigned Depth) {
12431 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12432
12433 auto ClearOnExit = make_scope_exit([&]() {
12434 if (LPhi) {
12435 bool Erased = PendingMerges.erase(LPhi);
12436 assert(Erased && "Failed to erase LPhi!");
12437 (void)Erased;
12438 }
12439 if (RPhi) {
12440 bool Erased = PendingMerges.erase(RPhi);
12441 assert(Erased && "Failed to erase RPhi!");
12442 (void)Erased;
12443 }
12444 });
12445
12446 // Find the respective Phis and check that they are not already pending.
12447 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12448 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12449 if (!PendingMerges.insert(Phi).second)
12450 return false;
12451 LPhi = Phi;
12452 }
12453 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12454 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12455 // If we detect a loop of Phi nodes being processed by this method, for
12456 // example:
12457 //
12458 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12459 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12460 //
12461 // we don't want to deal with a case that complex, so return conservative
12462 // answer false.
12463 if (!PendingMerges.insert(Phi).second)
12464 return false;
12465 RPhi = Phi;
12466 }
12467
12468 // If none of LHS, RHS is a Phi, nothing to do here.
12469 if (!LPhi && !RPhi)
12470 return false;
12471
12472 // If there is a SCEVUnknown Phi we are interested in, make it left.
12473 if (!LPhi) {
12474 std::swap(LHS, RHS);
12475 std::swap(FoundLHS, FoundRHS);
12476 std::swap(LPhi, RPhi);
12477 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12478 }
12479
12480 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12481 const BasicBlock *LBB = LPhi->getParent();
12482 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12483
12484 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12485 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12486 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12487 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12488 };
12489
12490 if (RPhi && RPhi->getParent() == LBB) {
12491 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12492 // If we compare two Phis from the same block, and for each entry block
12493 // the predicate is true for incoming values from this block, then the
12494 // predicate is also true for the Phis.
12495 for (const BasicBlock *IncBB : predecessors(LBB)) {
12496 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12497 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12498 if (!ProvedEasily(L, R))
12499 return false;
12500 }
12501 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12502 // Case two: RHS is also a Phi from the same basic block, and it is an
12503 // AddRec. It means that there is a loop which has both AddRec and Unknown
12504 // PHIs, for it we can compare incoming values of AddRec from above the loop
12505 // and latch with their respective incoming values of LPhi.
12506 // TODO: Generalize to handle loops with many inputs in a header.
12507 if (LPhi->getNumIncomingValues() != 2) return false;
12508
12509 auto *RLoop = RAR->getLoop();
12510 auto *Predecessor = RLoop->getLoopPredecessor();
12511 assert(Predecessor && "Loop with AddRec with no predecessor?");
12512 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12513 if (!ProvedEasily(L1, RAR->getStart()))
12514 return false;
12515 auto *Latch = RLoop->getLoopLatch();
12516 assert(Latch && "Loop with AddRec with no latch?");
12517 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12518 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12519 return false;
12520 } else {
12521 // In all other cases go over inputs of LHS and compare each of them to RHS,
12522 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12523 // At this point RHS is either a non-Phi, or it is a Phi from some block
12524 // different from LBB.
12525 for (const BasicBlock *IncBB : predecessors(LBB)) {
12526 // Check that RHS is available in this block.
12527 if (!dominates(RHS, IncBB))
12528 return false;
12529 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12530 // Make sure L does not refer to a value from a potentially previous
12531 // iteration of a loop.
12532 if (!properlyDominates(L, LBB))
12533 return false;
12534 // Addrecs are considered to properly dominate their loop, so are missed
12535 // by the previous check. Discard any values that have computable
12536 // evolution in this loop.
12537 if (auto *Loop = LI.getLoopFor(LBB))
12538 if (hasComputableLoopEvolution(L, Loop))
12539 return false;
12540 if (!ProvedEasily(L, RHS))
12541 return false;
12542 }
12543 }
12544 return true;
12545}
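// A small illustration of the fallback case above (hypothetical IR): for
// %p = phi i32 [ %a, %then ], [ %b, %else ] and an RHS %n that dominates both
// predecessors, proving %a <s %n and %b <s %n on the incoming edges proves
// %p <s %n.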
12546
12547bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12548 const SCEV *LHS,
12549 const SCEV *RHS,
12550 const SCEV *FoundLHS,
12551 const SCEV *FoundRHS) {
12552 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12553 // sure that we are dealing with the same LHS.
12554 if (RHS == FoundRHS) {
12555 std::swap(LHS, RHS);
12556 std::swap(FoundLHS, FoundRHS);
12557    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12558  }
12559 if (LHS != FoundLHS)
12560 return false;
12561
12562 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12563 if (!SUFoundRHS)
12564 return false;
12565
12566 Value *Shiftee, *ShiftValue;
12567
12568 using namespace PatternMatch;
12569 if (match(SUFoundRHS->getValue(),
12570 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12571 auto *ShifteeS = getSCEV(Shiftee);
12572 // Prove one of the following:
12573 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12574 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12575 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12576 // ---> LHS <s RHS
12577 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12578 // ---> LHS <=s RHS
12579 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12580 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12581 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12582 if (isKnownNonNegative(ShifteeS))
12583 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12584 }
12585
12586 return false;
12587}
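// Worked illustration of the shift rule above (hypothetical values): suppose
// the known fact is %x <u (%len >> 2), i.e. FoundRHS is the lshr, and the goal
// is %x <u %bound. Since a logical shift right never increases an unsigned
// value, (%len >> 2) <=u %len, so proving %len <=u %bound is enough; that is
// exactly the isKnownPredicate(ICMP_ULE, ShifteeS, RHS) query issued above.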
12588
12589bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12590 const SCEV *RHS,
12591 const SCEV *FoundLHS,
12592 const SCEV *FoundRHS,
12593 const Instruction *CtxI) {
12594 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12595 FoundRHS) ||
12596 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12597 FoundRHS) ||
12598 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12599 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12600 CtxI) ||
12601 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12602}
12603
12604/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12605template <typename MinMaxExprType>
12606static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12607 const SCEV *Candidate) {
12608 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12609 if (!MinMaxExpr)
12610 return false;
12611
12612 return is_contained(MinMaxExpr->operands(), Candidate);
12613}
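// For example (illustrative SCEVs), IsMinMaxConsistingOf<SCEVSMinExpr> applied
// to (smin %a, %b) and %a returns true because %a is one of the smin's
// operands; IsKnownPredicateViaMinOrMax below uses this to conclude facts such
// as (smin %a, %b) <=s %a.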
12614
12616 CmpPredicate Pred, const SCEV *LHS,
12617 const SCEV *RHS) {
12618 // If both sides are affine addrecs for the same loop, with equal
12619 // steps, and we know the recurrences don't wrap, then we only
12620 // need to check the predicate on the starting values.
12621
12622 if (!ICmpInst::isRelational(Pred))
12623 return false;
12624
12625 const SCEV *LStart, *RStart, *Step;
12626 const Loop *L;
12627 if (!match(LHS,
12628 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12629       !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12630                                       m_SpecificLoop(L))))
12631     return false;
12632
12633   auto *LAR = cast<SCEVAddRecExpr>(LHS);
12634   auto *RAR = cast<SCEVAddRecExpr>(RHS);
12635   SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? SCEV::FlagNSW : SCEV::FlagNUW;
12636 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12637 return false;
12638
12639 return SE.isKnownPredicate(Pred, LStart, RStart);
12640}
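// Minimal illustration (hypothetical addrecs): with Pred = ICMP_SLT,
// LHS = {2,+,1}<nsw> and RHS = {4,+,1}<nsw> over the same loop, the steps are
// equal and both recurrences carry the required no-wrap flag, so comparing the
// starts (2 <s 4) proves LHS <s RHS on every iteration.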
12641
12642 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12643/// expression?
12644 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12645    const SCEV *LHS, const SCEV *RHS) {
12646 switch (Pred) {
12647 default:
12648 return false;
12649
12650 case ICmpInst::ICMP_SGE:
12651 std::swap(LHS, RHS);
12652 [[fallthrough]];
12653 case ICmpInst::ICMP_SLE:
12654 return
12655 // min(A, ...) <= A
12656        IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12657        // A <= max(A, ...)
12658        IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12659
12660 case ICmpInst::ICMP_UGE:
12661 std::swap(LHS, RHS);
12662 [[fallthrough]];
12663 case ICmpInst::ICMP_ULE:
12664 return
12665 // min(A, ...) <= A
12666 // FIXME: what about umin_seq?
12667        IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12668        // A <= max(A, ...)
12669        IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12670  }
12671
12672 llvm_unreachable("covered switch fell through?!");
12673}
12674
12675bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12676 const SCEV *RHS,
12677 const SCEV *FoundLHS,
12678 const SCEV *FoundRHS,
12679 unsigned Depth) {
12680  assert(getTypeSizeInBits(LHS->getType()) ==
12681             getTypeSizeInBits(RHS->getType()) &&
12682         "LHS and RHS have different sizes?");
12683 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12684 getTypeSizeInBits(FoundRHS->getType()) &&
12685 "FoundLHS and FoundRHS have different sizes?");
12686 // We want to avoid hurting the compile time with analysis of too big trees.
12687  if (Depth > MaxSCEVOperationsImplicationDepth)
12688    return false;
12689
12690 // We only want to work with GT comparison so far.
12691 if (ICmpInst::isLT(Pred)) {
12693 std::swap(LHS, RHS);
12694 std::swap(FoundLHS, FoundRHS);
12695 }
12696
12698
12699 // For unsigned, try to reduce it to corresponding signed comparison.
12700 if (P == ICmpInst::ICMP_UGT)
12701 // We can replace unsigned predicate with its signed counterpart if all
12702 // involved values are non-negative.
12703 // TODO: We could have better support for unsigned.
12704 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12705 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12706 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12707 // use this fact to prove that LHS and RHS are non-negative.
12708 const SCEV *MinusOne = getMinusOne(LHS->getType());
12709 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12710 FoundRHS) &&
12711 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12712 FoundRHS))
12714 }
12715
12716 if (P != ICmpInst::ICMP_SGT)
12717 return false;
12718
12719 auto GetOpFromSExt = [&](const SCEV *S) {
12720 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12721 return Ext->getOperand();
12722 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12723 // the constant in some cases.
12724 return S;
12725 };
12726
12727 // Acquire values from extensions.
12728 auto *OrigLHS = LHS;
12729 auto *OrigFoundLHS = FoundLHS;
12730 LHS = GetOpFromSExt(LHS);
12731 FoundLHS = GetOpFromSExt(FoundLHS);
12732
12733 // Whether the SGT predicate can be proved trivially or using the found context.
12734 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12735 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12736 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12737 FoundRHS, Depth + 1);
12738 };
12739
12740 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12741 // We want to avoid creation of any new non-constant SCEV. Since we are
12742 // going to compare the operands to RHS, we should be certain that we don't
12743 // need any size extensions for this. So let's decline all cases when the
12744 // sizes of types of LHS and RHS do not match.
12745 // TODO: Maybe try to get RHS from sext to catch more cases?
12746    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12747      return false;
12748
12749 // Should not overflow.
12750 if (!LHSAddExpr->hasNoSignedWrap())
12751 return false;
12752
12753 auto *LL = LHSAddExpr->getOperand(0);
12754 auto *LR = LHSAddExpr->getOperand(1);
12755 auto *MinusOne = getMinusOne(RHS->getType());
12756
12757 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12758 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12759 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12760 };
12761 // Try to prove the following rule:
12762 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12763 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12764 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12765 return true;
12766 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12767 Value *LL, *LR;
12768 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12769
12770 using namespace llvm::PatternMatch;
12771
12772 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12773 // Rules for division.
12774 // We are going to perform some comparisons with Denominator and its
12775 // derivative expressions. In the general case, creating a SCEV for it may
12776 // lead to a complex analysis of the entire graph, and in particular it
12777 // can request trip count recalculation for the same loop, which would be
12778 // cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12779 // this, we only want to create SCEVs that are constants in this section.
12780 // So we bail if Denominator is not a constant.
12781 if (!isa<ConstantInt>(LR))
12782 return false;
12783
12784 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12785
12786 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12787 // then a SCEV for the numerator already exists and matches with FoundLHS.
12788 auto *Numerator = getExistingSCEV(LL);
12789 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12790 return false;
12791
12792 // Make sure that the numerator matches with FoundLHS and the denominator
12793 // is positive.
12794 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12795 return false;
12796
12797 auto *DTy = Denominator->getType();
12798 auto *FRHSTy = FoundRHS->getType();
12799 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12800 // One of the types is a pointer and the other is not. We cannot extend
12801 // them properly to a wider type, so let us just reject this case.
12802 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12803 // to avoid this check.
12804 return false;
12805
12806 // Given that:
12807 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12808 auto *WTy = getWiderType(DTy, FRHSTy);
12809 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12810 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12811
12812 // Try to prove the following rule:
12813 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12814 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12815 // divide it by a Denominator < 4, the result is at least 1.
12816 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12817 if (isKnownNonPositive(RHS) &&
12818 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12819 return true;
12820
12821 // Try to prove the following rule:
12822 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12823 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12824 // If we divide it by Denominator > 2, then:
12825 // 1. If FoundLHS is negative, then the result is 0.
12826 // 2. If FoundLHS is non-negative, then the result is non-negative.
12827 // Either way, the result is non-negative.
12828 auto *MinusOne = getMinusOne(WTy);
12829 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12830 if (isKnownNegative(RHS) &&
12831 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12832 return true;
12833 }
12834 }
12835
12836 // If our expression contained SCEVUnknown Phis, and we split it down and now
12837 // need to prove something for them, try to prove the predicate for every
12838 // possible incoming values of those Phis.
12839 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12840 return true;
12841
12842 return false;
12843}
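// Worked instance of the division rule above (illustrative values): let LHS be
// the SCEVUnknown for (%n /s 3), FoundLHS the SCEV of %n, FoundRHS = 2 and
// RHS = 0. Denominator - 2 = 1 and FoundRHS > 1, so from %n > 2 we get
// %n >= 3 and therefore %n /s 3 >= 1 > 0 = RHS.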
12844
12846 const SCEV *RHS) {
12847 // zext x u<= sext x, sext x s<= zext x
12848 const SCEV *Op;
12849 switch (Pred) {
12850 case ICmpInst::ICMP_SGE:
12851 std::swap(LHS, RHS);
12852 [[fallthrough]];
12853 case ICmpInst::ICMP_SLE: {
12854 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12855 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12857 }
12858 case ICmpInst::ICMP_UGE:
12859 std::swap(LHS, RHS);
12860 [[fallthrough]];
12861 case ICmpInst::ICMP_ULE: {
12862 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12863 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12865 }
12866 default:
12867 return false;
12868 };
12869 llvm_unreachable("unhandled case");
12870}
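// For example, take an i8 value %x = -1 widened to i16 (illustrative):
// zext(%x) = 255 and sext(%x) = 65535 (i.e. -1 as a signed value), so
// zext(%x) <=u sext(%x) and sext(%x) <=s zext(%x), matching the ULE and SLE
// cases handled above.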
12871
12872bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12873 const SCEV *LHS,
12874 const SCEV *RHS) {
12875 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12876 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12877 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12878 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12879 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12880}
12881
12882bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12883 const SCEV *LHS,
12884 const SCEV *RHS,
12885 const SCEV *FoundLHS,
12886 const SCEV *FoundRHS) {
12887 switch (Pred) {
12888 default:
12889 llvm_unreachable("Unexpected CmpPredicate value!");
12890 case ICmpInst::ICMP_EQ:
12891 case ICmpInst::ICMP_NE:
12892 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12893 return true;
12894 break;
12895 case ICmpInst::ICMP_SLT:
12896 case ICmpInst::ICMP_SLE:
12897 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12898 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12899 return true;
12900 break;
12901 case ICmpInst::ICMP_SGT:
12902 case ICmpInst::ICMP_SGE:
12903 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12904 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12905 return true;
12906 break;
12907 case ICmpInst::ICMP_ULT:
12908 case ICmpInst::ICMP_ULE:
12909 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12910 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12911 return true;
12912 break;
12913 case ICmpInst::ICMP_UGT:
12914 case ICmpInst::ICMP_UGE:
12915 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12916 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12917 return true;
12918 break;
12919 }
12920
12921 // Maybe it can be proved via operations?
12922 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12923 return true;
12924
12925 return false;
12926}
12927
12928bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12929 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12930 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12931 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12932 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12933 // reduce the compile time impact of this optimization.
12934 return false;
12935
12936 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12937 if (!Addend)
12938 return false;
12939
12940 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12941
12942 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12943 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12944 ConstantRange FoundLHSRange =
12945 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12946
12947 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12948 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12949
12950 // We can also compute the range of values for `LHS` that satisfy the
12951 // consequent, "`LHS` `Pred` `RHS`":
12952 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12953 // The antecedent implies the consequent if every value of `LHS` that
12954 // satisfies the antecedent also satisfies the consequent.
12955 return LHSRange.icmp(Pred, ConstRHS);
12956}
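// Worked example (illustrative values): suppose the antecedent is
// FoundLHS <u 10 (FoundPred = ICMP_ULT, ConstFoundRHS = 10) and
// LHS = FoundLHS + 5, so Addend = 5. Then FoundLHSRange = [0, 10),
// LHSRange = [5, 15), and a consequent such as "LHS <u 15" is proved because
// every value in [5, 15) satisfies it.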
12957
12958bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12959 bool IsSigned) {
12960 assert(isKnownPositive(Stride) && "Positive stride expected!");
12961
12962 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12963 const SCEV *One = getOne(Stride->getType());
12964
12965 if (IsSigned) {
12966 APInt MaxRHS = getSignedRangeMax(RHS);
12967 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12968 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12969
12970 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12971 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12972 }
12973
12974 APInt MaxRHS = getUnsignedRangeMax(RHS);
12975 APInt MaxValue = APInt::getMaxValue(BitWidth);
12976 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12977
12978 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12979 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12980}
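// For example, for an i8 unsigned IV (illustrative values): if
// getUnsignedRangeMax(RHS) = 250 and Stride = 10, then
// MaxValue - MaxStrideMinusOne = 255 - 9 = 246 < 250, so the IV could step
// from a value just below RHS straight past 255 and wrap.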
12981
12982bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12983 bool IsSigned) {
12984
12985 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12986 const SCEV *One = getOne(Stride->getType());
12987
12988 if (IsSigned) {
12989 APInt MinRHS = getSignedRangeMin(RHS);
12990 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12991 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12992
12993 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12994 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12995 }
12996
12997 APInt MinRHS = getUnsignedRangeMin(RHS);
12998 APInt MinValue = APInt::getMinValue(BitWidth);
12999 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13000
13001 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13002 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13003}
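// Symmetric example for the GT case (illustrative values): for an i8 unsigned
// IV counting down by Stride = 10, if getUnsignedRangeMin(RHS) = 4 then
// MinValue + MaxStrideMinusOne = 0 + 9 = 9 > 4, so the IV could step from a
// value just above RHS to below 0 and wrap.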
13004
13006 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13007 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13008 // expression fixes the case of N=0.
13009 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13010 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13011 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13012}
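// For example, N = 7 and D = 3 (illustrative): umin(7, 1) = 1 and
// floor((7 - 1) / 3) = 2, giving ceil(7 / 3) = 3; for N = 0 both terms are 0.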
13013
13014const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13015 const SCEV *Stride,
13016 const SCEV *End,
13017 unsigned BitWidth,
13018 bool IsSigned) {
13019 // The logic in this function assumes we can represent a positive stride.
13020 // If we can't, the backedge-taken count must be zero.
13021 if (IsSigned && BitWidth == 1)
13022 return getZero(Stride->getType());
13023
13024 // The code below has only been closely audited for negative strides in the
13025 // unsigned comparison case; it may be correct for signed comparison, but
13026 // that needs to be established.
13027 if (IsSigned && isKnownNegative(Stride))
13028 return getCouldNotCompute();
13029
13030 // Calculate the maximum backedge count based on the range of values
13031 // permitted by Start, End, and Stride.
13032 APInt MinStart =
13033 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13034
13035 APInt MinStride =
13036 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13037
13038 // We assume either the stride is positive, or the backedge-taken count
13039 // is zero. So force StrideForMaxBECount to be at least one.
13040 APInt One(BitWidth, 1);
13041 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13042 : APIntOps::umax(One, MinStride);
13043
13044 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13045 : APInt::getMaxValue(BitWidth);
13046 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13047
13048 // Although End can be a MAX expression we estimate MaxEnd considering only
13049 // the case End = RHS of the loop termination condition. This is safe because
13050 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13051 // taken count.
13052 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13053 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13054
13055 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13056 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13057 : APIntOps::umax(MaxEnd, MinStart);
13058
13059 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13060 getConstant(StrideForMaxBECount) /* Step */);
13061}
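// Worked example (illustrative unsigned i8 ranges): MinStart = 0,
// MinStride = 2 and getUnsignedRangeMax(End) = 100 give
// StrideForMaxBECount = 2, Limit = 255 - 1 = 254, MaxEnd = 100, and a result
// of ceil((100 - 0) / 2) = 50 backedges at most.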
13062
13063 ScalarEvolution::ExitLimit
13064 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13065 const Loop *L, bool IsSigned,
13066 bool ControlsOnlyExit, bool AllowPredicates) {
13067  SmallVector<const SCEVPredicate *, 4> Predicates;
13068
13069  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13070 bool PredicatedIV = false;
13071 if (!IV) {
13072 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13073 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13074 if (AR && AR->getLoop() == L && AR->isAffine()) {
13075 auto canProveNUW = [&]() {
13076 // We can use the comparison to infer no-wrap flags only if it fully
13077 // controls the loop exit.
13078 if (!ControlsOnlyExit)
13079 return false;
13080
13081 if (!isLoopInvariant(RHS, L))
13082 return false;
13083
13084 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13085 // We need the sequence defined by AR to strictly increase in the
13086 // unsigned integer domain for the logic below to hold.
13087 return false;
13088
13089 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13090 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13091 // If RHS <=u Limit, then there must exist a value V in the sequence
13092 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13093 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13094 // overflow occurs. This limit also implies that a signed comparison
13095 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13096 // the high bits on both sides must be zero.
13097 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13098 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13099 Limit = Limit.zext(OuterBitWidth);
13100 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13101 };
13102 auto Flags = AR->getNoWrapFlags();
13103 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13104 Flags = setFlags(Flags, SCEV::FlagNUW);
13105
13106 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13107 if (AR->hasNoUnsignedWrap()) {
13108 // Emulate what getZeroExtendExpr would have done during construction
13109 // if we'd been able to infer the fact just above at that time.
13110 const SCEV *Step = AR->getStepRecurrence(*this);
13111 Type *Ty = ZExt->getType();
13112 auto *S = getAddRecExpr(
13114 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13116 }
13117 }
13118 }
13119 }
13120
13121
13122 if (!IV && AllowPredicates) {
13123 // Try to make this an AddRec using runtime tests, in the first X
13124 // iterations of this loop, where X is the SCEV expression found by the
13125 // algorithm below.
13126 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13127 PredicatedIV = true;
13128 }
13129
13130 // Avoid weird loops
13131 if (!IV || IV->getLoop() != L || !IV->isAffine())
13132 return getCouldNotCompute();
13133
13134 // A precondition of this method is that the condition being analyzed
13135 // reaches an exiting branch which dominates the latch. Given that, we can
13136 // assume that an increment which violates the nowrap specification and
13137 // produces poison must cause undefined behavior when the resulting poison
13138 // value is branched upon and thus we can conclude that the backedge is
13139 // taken no more often than would be required to produce that poison value.
13140 // Note that a well defined loop can exit on the iteration which violates
13141 // the nowrap specification if there is another exit (either explicit or
13142 // implicit/exceptional) which causes the loop to execute before the
13143 // exiting instruction we're analyzing would trigger UB.
13144 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13145 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13146  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13147
13148 const SCEV *Stride = IV->getStepRecurrence(*this);
13149
13150 bool PositiveStride = isKnownPositive(Stride);
13151
13152 // Avoid negative or zero stride values.
13153 if (!PositiveStride) {
13154 // We can compute the correct backedge taken count for loops with unknown
13155 // strides if we can prove that the loop is not an infinite loop with side
13156 // effects. Here's the loop structure we are trying to handle -
13157 //
13158 // i = start
13159 // do {
13160 // A[i] = i;
13161 // i += s;
13162 // } while (i < end);
13163 //
13164 // The backedge taken count for such loops is evaluated as -
13165 // (max(end, start + stride) - start - 1) /u stride
13166 //
13167 // The additional preconditions that we need to check to prove correctness
13168 // of the above formula are as follows -
13169 //
13170 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13171 // NoWrap flag).
13172 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13173 // no side effects within the loop)
13174 // c) loop has a single static exit (with no abnormal exits)
13175 //
13176 // Precondition a) implies that if the stride is negative, this is a single
13177 // trip loop. The backedge taken count formula reduces to zero in this case.
13178 //
13179 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13180 // then a zero stride means the backedge can't be taken without executing
13181 // undefined behavior.
13182 //
13183 // The positive stride case is the same as isKnownPositive(Stride) returning
13184 // true (original behavior of the function).
13185 //
13186 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13187       !loopHasNoAbnormalExits(L))
13188      return getCouldNotCompute();
13189
13190 if (!isKnownNonZero(Stride)) {
13191 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13192 // if it might eventually be greater than start and if so, on which
13193 // iteration. We can't even produce a useful upper bound.
13194 if (!isLoopInvariant(RHS, L))
13195 return getCouldNotCompute();
13196
13197 // We allow a potentially zero stride, but we need to divide by stride
13198 // below. Since the loop can't be infinite and this check must control
13199 // the sole exit, we can infer the exit must be taken on the first
13200 // iteration (i.e. backedge count = 0) if the stride is zero. Given that,
13201 // we know the numerator in the divides below must be zero, so we can
13202 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13203 // and produce the right result.
13204 // FIXME: Handle the case where Stride is poison?
13205 auto wouldZeroStrideBeUB = [&]() {
13206 // Proof by contradiction. Suppose the stride were zero. If we can
13207 // prove that the backedge *is* taken on the first iteration, then since
13208 // we know this condition controls the sole exit, we must have an
13209 // infinite loop. We can't have a (well defined) infinite loop per
13210 // check just above.
13211 // Note: The (Start - Stride) term is used to get the start' term from
13212 // (start' + stride,+,stride). Remember that we only care about the
13213 // result of this expression when stride == 0 at runtime.
13214 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13215 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13216 };
13217 if (!wouldZeroStrideBeUB()) {
13218 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13219 }
13220 }
13221 } else if (!NoWrap) {
13222 // Avoid proven overflow cases: this will ensure that the backedge taken
13223 // count will not generate any unsigned overflow.
13224 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13225 return getCouldNotCompute();
13226 }
13227
13228 // On all paths just preceding, we established the following invariant:
13229 // IV can be assumed not to overflow up to and including the exiting
13230 // iteration. We proved this in one of two ways:
13231 // 1) We can show overflow doesn't occur before the exiting iteration
13232 // 1a) canIVOverflowOnLT, and b) step of one
13233 // 2) We can show that if overflow occurs, the loop must execute UB
13234 // before any possible exit.
13235 // Note that we have not yet proved RHS invariant (in general).
13236
13237 const SCEV *Start = IV->getStart();
13238
13239 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13240 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13241 // Use integer-typed versions for actual computation; we can't subtract
13242 // pointers in general.
13243 const SCEV *OrigStart = Start;
13244 const SCEV *OrigRHS = RHS;
13245 if (Start->getType()->isPointerTy()) {
13246    Start = getLosslessPtrToIntExpr(Start);
13247    if (isa<SCEVCouldNotCompute>(Start))
13248 return Start;
13249 }
13250 if (RHS->getType()->isPointerTy()) {
13251    RHS = getLosslessPtrToIntExpr(RHS);
13252    if (isa<SCEVCouldNotCompute>(RHS))
13253      return RHS;
13254 }
13255
13256 const SCEV *End = nullptr, *BECount = nullptr,
13257 *BECountIfBackedgeTaken = nullptr;
13258 if (!isLoopInvariant(RHS, L)) {
13259 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13260 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13261 RHSAddRec->getNoWrapFlags()) {
13262 // The structure of loop we are trying to calculate backedge count of:
13263 //
13264 // left = left_start
13265 // right = right_start
13266 //
13267 // while(left < right){
13268 // ... do something here ...
13269 // left += s1; // stride of left is s1 (s1 > 0)
13270 // right += s2; // stride of right is s2 (s2 < 0)
13271 // }
13272 //
13273
13274 const SCEV *RHSStart = RHSAddRec->getStart();
13275 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13276
13277 // If Stride - RHSStride is positive and does not overflow, we can write
13278 // backedge count as ->
13279 // ceil((End - Start) /u (Stride - RHSStride))
13280 // Where, End = max(RHSStart, Start)
13281
13282 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13283 if (isKnownNegative(RHSStride) &&
13284 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13285 RHSStride)) {
13286
13287 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13288 if (isKnownPositive(Denominator)) {
13289 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13290 : getUMaxExpr(RHSStart, Start);
13291
13292 // We can do this because End >= Start, as End = max(RHSStart, Start)
13293 const SCEV *Delta = getMinusSCEV(End, Start);
13294
13295 BECount = getUDivCeilSCEV(Delta, Denominator);
13296 BECountIfBackedgeTaken =
13297 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13298 }
13299 }
13300 }
13301 if (BECount == nullptr) {
13302 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13303 // given the start, stride and max value for the end bound of the
13304 // loop (RHS), and the fact that IV does not overflow (which is
13305 // checked above).
13306 const SCEV *MaxBECount = computeMaxBECountForLT(
13307 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13308 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13309 MaxBECount, false /*MaxOrZero*/, Predicates);
13310 }
13311 } else {
13312 // We use the expression (max(End,Start)-Start)/Stride to describe the
13313 // backedge count: if the backedge is taken at least once,
13314 // max(End,Start) is End and so the result is as above; if not,
13315 // max(End,Start) is Start, so we get a backedge count of zero.
13316 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13317 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13318 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13319 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13320 // Can we prove max(RHS,Start) > Start - Stride?
13321 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13322 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13323 // In this case, we can use a refined formula for computing backedge
13324 // taken count. The general formula remains:
13325 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13326 // We want to use the alternate formula:
13327 // "((End - 1) - (Start - Stride)) /u Stride"
13328 // Let's do a quick case analysis to show these are equivalent under
13329 // our precondition that max(RHS,Start) > Start - Stride.
13330 // * For RHS <= Start, the backedge-taken count must be zero.
13331 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13332 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13333 // "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
13334 // of Stride. For 0 stride, we've used umax(Stride,1) above,
13335 // reducing this to the stride of 1 case.
13336 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13337 // Stride".
13338 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13339 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13340 // "((RHS - (Start - Stride) - 1) /u Stride".
13341 // Our preconditions trivially imply no overflow in that form.
13342 const SCEV *MinusOne = getMinusOne(Stride->getType());
13343 const SCEV *Numerator =
13344 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13345 BECount = getUDivExpr(Numerator, Stride);
13346 }
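// Quick numeric check of the refined formula above (illustrative values):
// Start = 4, Stride = 4, RHS = 14 gives ((14 - 1) - (4 - 4)) /u 4 = 13 /u 4
// = 3, which matches ceil((14 - 4) / 4) = 3 backedges.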
13347
13348 if (!BECount) {
13349 auto canProveRHSGreaterThanEqualStart = [&]() {
13350 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13351 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13352 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13353
13354 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13355 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13356 return true;
13357
13358 // (RHS > Start - 1) implies RHS >= Start.
13359 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13360 // "Start - 1" doesn't overflow.
13361 // * For signed comparison, if Start - 1 does overflow, it's equal
13362 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13363 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13364 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13365 //
13366 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13367 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13368 auto *StartMinusOne =
13369 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13370 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13371 };
13372
13373 // If we know that RHS >= Start in the context of loop, then we know
13374 // that max(RHS, Start) = RHS at this point.
13375 if (canProveRHSGreaterThanEqualStart()) {
13376 End = RHS;
13377 } else {
13378 // If RHS < Start, the backedge will be taken zero times. So in
13379 // general, we can write the backedge-taken count as:
13380 //
13381 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13382 //
13383 // We convert it to the following to make it more convenient for SCEV:
13384 //
13385 // ceil(max(RHS, Start) - Start) / Stride
13386 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13387
13388 // See what would happen if we assume the backedge is taken. This is
13389 // used to compute MaxBECount.
13390 BECountIfBackedgeTaken =
13391 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13392 }
13393
13394 // At this point, we know:
13395 //
13396 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13397 // 2. The index variable doesn't overflow.
13398 //
13399 // Therefore, we know N exists such that
13400 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13401 // doesn't overflow.
13402 //
13403 // Using this information, try to prove whether the addition in
13404 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13405 const SCEV *One = getOne(Stride->getType());
13406 bool MayAddOverflow = [&] {
13407 if (isKnownToBeAPowerOfTwo(Stride)) {
13408 // Suppose Stride is a power of two, and Start/End are unsigned
13409 // integers. Let UMAX be the largest representable unsigned
13410 // integer.
13411 //
13412 // By the preconditions of this function, we know
13413 // "(Start + Stride * N) >= End", and this doesn't overflow.
13414 // As a formula:
13415 //
13416 // End <= (Start + Stride * N) <= UMAX
13417 //
13418 // Subtracting Start from all the terms:
13419 //
13420 // End - Start <= Stride * N <= UMAX - Start
13421 //
13422 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13423 //
13424 // End - Start <= Stride * N <= UMAX
13425 //
13426 // Stride * N is a multiple of Stride. Therefore,
13427 //
13428 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13429 //
13430 // Since Stride is a power of two, UMAX + 1 is divisible by
13431 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13432 // write:
13433 //
13434 // End - Start <= Stride * N <= UMAX - Stride + 1
13435 //
13436 // Dropping the middle term:
13437 //
13438 // End - Start <= UMAX - Stride + 1
13439 //
13440 // Adding Stride - 1 to both sides:
13441 //
13442 // (End - Start) + (Stride - 1) <= UMAX
13443 //
13444 // In other words, the addition doesn't have unsigned overflow.
13445 //
13446 // A similar proof works if we treat Start/End as signed values.
13447 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13448 // to use signed max instead of unsigned max. Note that we're
13449 // trying to prove a lack of unsigned overflow in either case.
13450 return false;
13451 }
13452 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13453 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13454 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13455 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13456 // 1 <s End.
13457 //
13458 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13459 // End.
13460 return false;
13461 }
13462 return true;
13463 }();
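// Concrete instances of the two no-overflow cases above (illustrative, i8):
// with Stride = 4 (a power of two), the chain gives
// End - Start <= 255 - 4 + 1 = 252, so adding Stride - 1 = 3 cannot wrap;
// and with Start == Stride, (End - Start) + (Stride - 1) == End - 1 < End.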
13464
13465 const SCEV *Delta = getMinusSCEV(End, Start);
13466 if (!MayAddOverflow) {
13467 // floor((D + (S - 1)) / S)
13468 // We prefer this formulation if it's legal because it's fewer
13469 // operations.
13470 BECount =
13471 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13472 } else {
13473 BECount = getUDivCeilSCEV(Delta, Stride);
13474 }
13475 }
13476 }
13477
13478 const SCEV *ConstantMaxBECount;
13479 bool MaxOrZero = false;
13480 if (isa<SCEVConstant>(BECount)) {
13481 ConstantMaxBECount = BECount;
13482 } else if (BECountIfBackedgeTaken &&
13483 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13484 // If we know exactly how many times the backedge will be taken if it's
13485 // taken at least once, then the backedge count will either be that or
13486 // zero.
13487 ConstantMaxBECount = BECountIfBackedgeTaken;
13488 MaxOrZero = true;
13489 } else {
13490 ConstantMaxBECount = computeMaxBECountForLT(
13491 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13492 }
13493
13494 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13495 !isa<SCEVCouldNotCompute>(BECount))
13496 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13497
13498 const SCEV *SymbolicMaxBECount =
13499 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13500 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13501 Predicates);
13502}
13503
13504ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13505 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13506 bool ControlsOnlyExit, bool AllowPredicates) {
13507  SmallVector<const SCEVPredicate *, 4> Predicates;
13508 // We handle only IV > Invariant
13509 if (!isLoopInvariant(RHS, L))
13510 return getCouldNotCompute();
13511
13512 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13513 if (!IV && AllowPredicates)
13514 // Try to make this an AddRec using runtime tests, in the first X
13515 // iterations of this loop, where X is the SCEV expression found by the
13516 // algorithm below.
13517 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13518
13519 // Avoid weird loops
13520 if (!IV || IV->getLoop() != L || !IV->isAffine())
13521 return getCouldNotCompute();
13522
13523 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13524 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13525  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13526
13527 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13528
13529 // Avoid negative or zero stride values
13530 if (!isKnownPositive(Stride))
13531 return getCouldNotCompute();
13532
13533 // Avoid proven overflow cases: this will ensure that the backedge taken count
13534 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13535 // exploit NoWrapFlags, allowing optimization in the presence of undefined
13536 // behavior, as in the case of the C language.
13537 if (!Stride->isOne() && !NoWrap)
13538 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13539 return getCouldNotCompute();
13540
13541 const SCEV *Start = IV->getStart();
13542 const SCEV *End = RHS;
13543 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13544 // If we know that Start >= RHS in the context of loop, then we know that
13545 // min(RHS, Start) = RHS at this point.
13546    if (isLoopEntryGuardedByCond(
13547            L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13548 End = RHS;
13549 else
13550 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13551 }
13552
13553 if (Start->getType()->isPointerTy()) {
13554    Start = getLosslessPtrToIntExpr(Start);
13555    if (isa<SCEVCouldNotCompute>(Start))
13556 return Start;
13557 }
13558 if (End->getType()->isPointerTy()) {
13559 End = getLosslessPtrToIntExpr(End);
13560 if (isa<SCEVCouldNotCompute>(End))
13561 return End;
13562 }
13563
13564 // Compute ((Start - End) + (Stride - 1)) / Stride.
13565 // FIXME: This can overflow. Holding off on fixing this for now;
13566 // howManyGreaterThans will hopefully be gone soon.
13567 const SCEV *One = getOne(Stride->getType());
13568 const SCEV *BECount = getUDivExpr(
13569 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13570
13571 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13572                              : getUnsignedRangeMax(Start);
13573
13574 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13575 : getUnsignedRangeMin(Stride);
13576
13577 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13578 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13579 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13580
13581 // Although End can be a MIN expression we estimate MinEnd considering only
13582 // the case End = RHS. This is safe because in the other case (Start - End)
13583 // is zero, leading to a zero maximum backedge taken count.
13584 APInt MinEnd =
13585 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13586 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13587
13588 const SCEV *ConstantMaxBECount =
13589 isa<SCEVConstant>(BECount)
13590 ? BECount
13591 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13592 getConstant(MinStride));
13593
13594 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13595 ConstantMaxBECount = BECount;
13596 const SCEV *SymbolicMaxBECount =
13597 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13598
13599 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13600 Predicates);
13601}
13602
13603 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13604    ScalarEvolution &SE) const {
13605 if (Range.isFullSet()) // Infinite loop.
13606 return SE.getCouldNotCompute();
13607
13608 // If the start is a non-zero constant, shift the range to simplify things.
13609 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13610 if (!SC->getValue()->isZero()) {
13611      SmallVector<const SCEV *, 4> Operands(operands());
13612      Operands[0] = SE.getZero(SC->getType());
13613 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13615 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13616 return ShiftedAddRec->getNumIterationsInRange(
13617 Range.subtract(SC->getAPInt()), SE);
13618 // This is strange and shouldn't happen.
13619 return SE.getCouldNotCompute();
13620 }
13621
13622 // The only time we can solve this is when we have all constant indices.
13623 // Otherwise, we cannot determine the overflow conditions.
13624 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13625 return SE.getCouldNotCompute();
13626
13627 // Okay at this point we know that all elements of the chrec are constants and
13628 // that the start element is zero.
13629
13630 // First check to see if the range contains zero. If not, the first
13631 // iteration exits.
13632 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13633 if (!Range.contains(APInt(BitWidth, 0)))
13634 return SE.getZero(getType());
13635
13636 if (isAffine()) {
13637 // If this is an affine expression then we have this situation:
13638 // Solve {0,+,A} in Range === Ax in Range
13639
13640 // We know that zero is in the range. If A is positive then we know that
13641 // the upper value of the range must be the first possible exit value.
13642 // If A is negative then the lower of the range is the last possible loop
13643 // value. Also note that we already checked for a full range.
13644 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13645 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13646
13647 // The exit value should be (End+A)/A.
13648 APInt ExitVal = (End + A).udiv(A);
13649 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13650
13651 // Evaluate at the exit value. If we really did fall out of the valid
13652 // range, then we computed our trip count, otherwise wrap around or other
13653 // things must have happened.
13654 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13655 if (Range.contains(Val->getValue()))
13656 return SE.getCouldNotCompute(); // Something strange happened
13657
13658 // Ensure that the previous value is in the range.
13659 assert(Range.contains(
13660            EvaluateConstantChrecAtConstant(this,
13661 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13662 "Linear scev computation is off in a bad way!");
13663 return SE.getConstant(ExitValue);
13664 }
13665
13666 if (isQuadratic()) {
13667 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13668 return SE.getConstant(*S);
13669 }
13670
13671 return SE.getCouldNotCompute();
13672}
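// Affine example (illustrative): for {0,+,3} constrained to the range [0, 10),
// A = 3, End = 9 and ExitVal = (9 + 3) /u 3 = 4; evaluating the chrec at 4
// gives 12, which lies outside the range, so the result is 4 (the values
// 0, 3, 6, 9 of the first four iterations all stay in range).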
13673
13674const SCEVAddRecExpr *
13675 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13676  assert(getNumOperands() > 1 && "AddRec with zero step?");
13677 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13678 // but in this case we cannot guarantee that the value returned will be an
13679 // AddRec because SCEV does not have a fixed point where it stops
13680 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13681 // may happen if we reach arithmetic depth limit while simplifying. So we
13682 // construct the returned value explicitly.
13683  SmallVector<const SCEV *, 3> Ops;
13684  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13685 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13686 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13687 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13688 // We know that the last operand is not a constant zero (otherwise it would
13689 // have been popped out earlier). This guarantees us that if the result has
13690 // the same last operand, then it will also not be popped out, meaning that
13691 // the returned value will be an AddRec.
13692 const SCEV *Last = getOperand(getNumOperands() - 1);
13693 assert(!Last->isZero() && "Recurrency with zero step?");
13694 Ops.push_back(Last);
13695  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13696                                               SCEV::FlagAnyWrap));
13697}
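// For example (illustrative), the post-increment form of {0,+,4} is {4,+,4},
// and for {A,+,B,+,C} it is {A+B,+,B+C,+,C}: each operand is added to its
// successor and the last operand is carried over unchanged.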
13698
13699// Return true when S contains at least an undef value.
13701 return SCEVExprContains(S, [](const SCEV *S) {
13702 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13703 return isa<UndefValue>(SU->getValue());
13704 return false;
13705 });
13706}
13707
13708// Return true when S contains a value that is a nullptr.
13710 return SCEVExprContains(S, [](const SCEV *S) {
13711 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13712 return SU->getValue() == nullptr;
13713 return false;
13714 });
13715}
13716
13717/// Return the size of an element read or written by Inst.
13718 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13719  Type *Ty;
13720 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13721 Ty = Store->getValueOperand()->getType();
13722 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13723 Ty = Load->getType();
13724 else
13725 return nullptr;
13726
13728 return getSizeOfExpr(ETy, Ty);
13729}
13730
13731//===----------------------------------------------------------------------===//
13732// SCEVCallbackVH Class Implementation
13733//===----------------------------------------------------------------------===//
13734
13735 void ScalarEvolution::SCEVCallbackVH::deleted() {
13736  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13737 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13738 SE->ConstantEvolutionLoopExitValue.erase(PN);
13739 SE->eraseValueFromMap(getValPtr());
13740 // this now dangles!
13741}
13742
13743void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13744 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13745
13746 // Forget all the expressions associated with users of the old value,
13747 // so that future queries will recompute the expressions using the new
13748 // value.
13749 SE->forgetValue(getValPtr());
13750 // this now dangles!
13751}
13752
13753ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13754 : CallbackVH(V), SE(se) {}
13755
13756//===----------------------------------------------------------------------===//
13757// ScalarEvolution Class Implementation
13758//===----------------------------------------------------------------------===//
13759
13760 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13761                                  AssumptionCache &AC, DominatorTree &DT,
13762                                  LoopInfo &LI)
13763 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13764 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13765 LoopDispositions(64), BlockDispositions(64) {
13766 // To use guards for proving predicates, we need to scan every instruction in
13767 // relevant basic blocks, and not just terminators. Doing this is a waste of
13768 // time if the IR does not actually contain any calls to
13769 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13770 //
13771 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13772 // to _add_ guards to the module when there weren't any before, and wants
13773 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13774 // efficient in lieu of being smart in that rather obscure case.
13775
13776 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13777 F.getParent(), Intrinsic::experimental_guard);
13778 HasGuards = GuardDecl && !GuardDecl->use_empty();
13779}
13780
13781 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13782    : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13783 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13784 ValueExprMap(std::move(Arg.ValueExprMap)),
13785 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13786 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13787 PendingMerges(std::move(Arg.PendingMerges)),
13788 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13789 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13790 PredicatedBackedgeTakenCounts(
13791 std::move(Arg.PredicatedBackedgeTakenCounts)),
13792 BECountUsers(std::move(Arg.BECountUsers)),
13793 ConstantEvolutionLoopExitValue(
13794 std::move(Arg.ConstantEvolutionLoopExitValue)),
13795 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13796 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13797 LoopDispositions(std::move(Arg.LoopDispositions)),
13798 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13799 BlockDispositions(std::move(Arg.BlockDispositions)),
13800 SCEVUsers(std::move(Arg.SCEVUsers)),
13801 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13802 SignedRanges(std::move(Arg.SignedRanges)),
13803 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13804 UniquePreds(std::move(Arg.UniquePreds)),
13805 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13806 LoopUsers(std::move(Arg.LoopUsers)),
13807 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13808 FirstUnknown(Arg.FirstUnknown) {
13809 Arg.FirstUnknown = nullptr;
13810}
13811
13812 ScalarEvolution::~ScalarEvolution() {
13813  // Iterate through all the SCEVUnknown instances and call their
13814 // destructors, so that they release their references to their values.
13815 for (SCEVUnknown *U = FirstUnknown; U;) {
13816 SCEVUnknown *Tmp = U;
13817 U = U->Next;
13818 Tmp->~SCEVUnknown();
13819 }
13820 FirstUnknown = nullptr;
13821
13822 ExprValueMap.clear();
13823 ValueExprMap.clear();
13824 HasRecMap.clear();
13825 BackedgeTakenCounts.clear();
13826 PredicatedBackedgeTakenCounts.clear();
13827
13828 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13829 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13830 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13831 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13832 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13833}
13834
13835 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13836   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13837 }
13838
13839/// When printing a top-level SCEV for trip counts, it's helpful to include
13840/// a type for constants which are otherwise hard to disambiguate.
13841static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13842 if (isa<SCEVConstant>(S))
13843 OS << *S->getType() << " ";
13844 OS << *S;
13845}
13846
13847 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13848    const Loop *L) {
13849 // Print all inner loops first
13850 for (Loop *I : *L)
13851 PrintLoopInfo(OS, SE, I);
13852
13853 OS << "Loop ";
13854 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13855 OS << ": ";
13856
13857 SmallVector<BasicBlock *, 8> ExitingBlocks;
13858 L->getExitingBlocks(ExitingBlocks);
13859 if (ExitingBlocks.size() != 1)
13860 OS << "<multiple exits> ";
13861
13862 auto *BTC = SE->getBackedgeTakenCount(L);
13863 if (!isa<SCEVCouldNotCompute>(BTC)) {
13864 OS << "backedge-taken count is ";
13865 PrintSCEVWithTypeHint(OS, BTC);
13866 } else
13867 OS << "Unpredictable backedge-taken count.";
13868 OS << "\n";
13869
13870 if (ExitingBlocks.size() > 1)
13871 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13872 OS << " exit count for " << ExitingBlock->getName() << ": ";
13873 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13874 PrintSCEVWithTypeHint(OS, EC);
13875 if (isa<SCEVCouldNotCompute>(EC)) {
13876 // Retry with predicates.
13878 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13879 if (!isa<SCEVCouldNotCompute>(EC)) {
13880 OS << "\n predicated exit count for " << ExitingBlock->getName()
13881 << ": ";
13882 PrintSCEVWithTypeHint(OS, EC);
13883 OS << "\n Predicates:\n";
13884 for (const auto *P : Predicates)
13885 P->print(OS, 4);
13886 }
13887 }
13888 OS << "\n";
13889 }
13890
13891 OS << "Loop ";
13892 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13893 OS << ": ";
13894
13895 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13896 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13897 OS << "constant max backedge-taken count is ";
13898 PrintSCEVWithTypeHint(OS, ConstantBTC);
13900 OS << ", actual taken count either this or zero.";
13901 } else {
13902 OS << "Unpredictable constant max backedge-taken count. ";
13903 }
13904
13905 OS << "\n"
13906 "Loop ";
13907 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13908 OS << ": ";
13909
13910 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13911 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13912 OS << "symbolic max backedge-taken count is ";
13913 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13915 OS << ", actual taken count either this or zero.";
13916 } else {
13917 OS << "Unpredictable symbolic max backedge-taken count. ";
13918 }
13919 OS << "\n";
13920
13921 if (ExitingBlocks.size() > 1)
13922 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13923 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13924 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13926 PrintSCEVWithTypeHint(OS, ExitBTC);
13927 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13928 // Retry with predicates.
13930 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13932 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13933 OS << "\n predicated symbolic max exit count for "
13934 << ExitingBlock->getName() << ": ";
13935 PrintSCEVWithTypeHint(OS, ExitBTC);
13936 OS << "\n Predicates:\n";
13937 for (const auto *P : Predicates)
13938 P->print(OS, 4);
13939 }
13940 }
13941 OS << "\n";
13942 }
13943
13944  SmallVector<const SCEVPredicate *, 4> Preds;
13945  auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13946 if (PBT != BTC) {
13947 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13948 OS << "Loop ";
13949 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13950 OS << ": ";
13951 if (!isa<SCEVCouldNotCompute>(PBT)) {
13952 OS << "Predicated backedge-taken count is ";
13953 PrintSCEVWithTypeHint(OS, PBT);
13954 } else
13955 OS << "Unpredictable predicated backedge-taken count.";
13956 OS << "\n";
13957 OS << " Predicates:\n";
13958 for (const auto *P : Preds)
13959 P->print(OS, 4);
13960 }
13961 Preds.clear();
13962
13963  auto *PredConstantMax =
13964      SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13965 if (PredConstantMax != ConstantBTC) {
13966 assert(!Preds.empty() &&
13967 "different predicated constant max BTC but no predicates");
13968 OS << "Loop ";
13969 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13970 OS << ": ";
13971 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13972 OS << "Predicated constant max backedge-taken count is ";
13973 PrintSCEVWithTypeHint(OS, PredConstantMax);
13974 } else
13975 OS << "Unpredictable predicated constant max backedge-taken count.";
13976 OS << "\n";
13977 OS << " Predicates:\n";
13978 for (const auto *P : Preds)
13979 P->print(OS, 4);
13980 }
13981 Preds.clear();
13982
13983  auto *PredSymbolicMax =
13984      SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13985 if (SymbolicBTC != PredSymbolicMax) {
13986 assert(!Preds.empty() &&
13987 "Different predicated symbolic max BTC, but no predicates");
13988 OS << "Loop ";
13989 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13990 OS << ": ";
13991 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13992 OS << "Predicated symbolic max backedge-taken count is ";
13993 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13994 } else
13995 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13996 OS << "\n";
13997 OS << " Predicates:\n";
13998 for (const auto *P : Preds)
13999 P->print(OS, 4);
14000 }
14001
14002  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
14003    OS << "Loop ";
14004 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14005 OS << ": ";
14006 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14007 }
14008}
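// Illustrative sketch (not verbatim tool output): for a single-exit loop whose
// trip count SCEV can be computed, the text emitted above looks roughly like
//
//   Loop %for.body: backedge-taken count is (-1 + %n)
//   Loop %for.body: constant max backedge-taken count is i32 2147483646
//   Loop %for.body: symbolic max backedge-taken count is (-1 + %n)
//   Loop %for.body: Trip multiple is 1
//
// Multi-exit loops additionally print per-exit (and, where needed, predicated)
// exit counts, as produced by the branches above.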
14009
14010namespace llvm {
14011raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
14012  switch (LD) {
14013  case ScalarEvolution::LoopVariant:
14014    OS << "Variant";
14015    break;
14016  case ScalarEvolution::LoopInvariant:
14017    OS << "Invariant";
14018    break;
14019  case ScalarEvolution::LoopComputable:
14020    OS << "Computable";
14022 }
14023 return OS;
14024}
14025
14026raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
14027  switch (BD) {
14028  case ScalarEvolution::DoesNotDominateBlock:
14029    OS << "DoesNotDominate";
14030    break;
14031  case ScalarEvolution::DominatesBlock:
14032    OS << "Dominates";
14033    break;
14034  case ScalarEvolution::ProperlyDominatesBlock:
14035    OS << "ProperlyDominates";
14037 }
14038 return OS;
14039}
14040} // namespace llvm
14041
14042void ScalarEvolution::print(raw_ostream &OS) const {
14043  // ScalarEvolution's implementation of the print method is to print
14044  // out SCEV values of all instructions that are interesting. Doing
14045  // this potentially causes it to create new SCEV objects, which
14046  // technically conflicts with the const qualifier. This isn't
14047  // observable from outside the class, though, so casting away the
14048  // const isn't dangerous.
14049 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14050
14051 if (ClassifyExpressions) {
14052 OS << "Classifying expressions for: ";
14053 F.printAsOperand(OS, /*PrintType=*/false);
14054 OS << "\n";
14055 for (Instruction &I : instructions(F))
14056 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14057 OS << I << '\n';
14058 OS << " --> ";
14059 const SCEV *SV = SE.getSCEV(&I);
14060 SV->print(OS);
14061 if (!isa<SCEVCouldNotCompute>(SV)) {
14062 OS << " U: ";
14063 SE.getUnsignedRange(SV).print(OS);
14064 OS << " S: ";
14065 SE.getSignedRange(SV).print(OS);
14066 }
14067
14068 const Loop *L = LI.getLoopFor(I.getParent());
14069
14070 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14071 if (AtUse != SV) {
14072 OS << " --> ";
14073 AtUse->print(OS);
14074 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14075 OS << " U: ";
14076 SE.getUnsignedRange(AtUse).print(OS);
14077 OS << " S: ";
14078 SE.getSignedRange(AtUse).print(OS);
14079 }
14080 }
14081
14082 if (L) {
14083 OS << "\t\t" "Exits: ";
14084 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14085 if (!SE.isLoopInvariant(ExitValue, L)) {
14086 OS << "<<Unknown>>";
14087 } else {
14088 OS << *ExitValue;
14089 }
14090
14091 bool First = true;
14092 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14093 if (First) {
14094 OS << "\t\t" "LoopDispositions: { ";
14095 First = false;
14096 } else {
14097 OS << ", ";
14098 }
14099
14100 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14101 OS << ": " << SE.getLoopDisposition(SV, Iter);
14102 }
14103
14104 for (const auto *InnerL : depth_first(L)) {
14105 if (InnerL == L)
14106 continue;
14107 if (First) {
14108 OS << "\t\t" "LoopDispositions: { ";
14109 First = false;
14110 } else {
14111 OS << ", ";
14112 }
14113
14114 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14115 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14116 }
14117
14118 OS << " }";
14119 }
14120
14121 OS << "\n";
14122 }
14123 }
14124
14125 OS << "Determining loop execution counts for: ";
14126 F.printAsOperand(OS, /*PrintType=*/false);
14127 OS << "\n";
14128 for (Loop *I : LI)
14129 PrintLoopInfo(OS, &SE, I);
14130}
14131
14132ScalarEvolution::LoopDisposition
14133ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14134  auto &Values = LoopDispositions[S];
14135 for (auto &V : Values) {
14136 if (V.getPointer() == L)
14137 return V.getInt();
14138 }
14139 Values.emplace_back(L, LoopVariant);
14140 LoopDisposition D = computeLoopDisposition(S, L);
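  // Note: computeLoopDisposition may recurse and add entries to
  // LoopDispositions, which can invalidate the reference obtained above, so
  // the bucket is looked up again before the computed value is stored.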
14141 auto &Values2 = LoopDispositions[S];
14142 for (auto &V : llvm::reverse(Values2)) {
14143 if (V.getPointer() == L) {
14144 V.setInt(D);
14145 break;
14146 }
14147 }
14148 return D;
14149}
14150
14151ScalarEvolution::LoopDisposition
14152ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14153 switch (S->getSCEVType()) {
14154 case scConstant:
14155 case scVScale:
14156 return LoopInvariant;
14157 case scAddRecExpr: {
14158 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14159
14160 // If L is the addrec's loop, it's computable.
14161 if (AR->getLoop() == L)
14162 return LoopComputable;
14163
14164 // Add recurrences are never invariant in the function-body (null loop).
14165 if (!L)
14166 return LoopVariant;
14167
14168 // Everything that is not defined at loop entry is variant.
14169 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14170 return LoopVariant;
14171 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14172 " dominate the contained loop's header?");
14173
14174 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14175 if (AR->getLoop()->contains(L))
14176 return LoopInvariant;
14177
14178 // This recurrence is variant w.r.t. L if any of its operands
14179 // are variant.
14180 for (const auto *Op : AR->operands())
14181 if (!isLoopInvariant(Op, L))
14182 return LoopVariant;
14183
14184 // Otherwise it's loop-invariant.
14185 return LoopInvariant;
14186 }
14187 case scTruncate:
14188 case scZeroExtend:
14189 case scSignExtend:
14190 case scPtrToInt:
14191 case scAddExpr:
14192 case scMulExpr:
14193 case scUDivExpr:
14194 case scUMaxExpr:
14195 case scSMaxExpr:
14196 case scUMinExpr:
14197 case scSMinExpr:
14198 case scSequentialUMinExpr: {
14199 bool HasVarying = false;
14200 for (const auto *Op : S->operands()) {
14201      LoopDisposition D = getLoopDisposition(Op, L);
14202      if (D == LoopVariant)
14203 return LoopVariant;
14204 if (D == LoopComputable)
14205 HasVarying = true;
14206 }
14207 return HasVarying ? LoopComputable : LoopInvariant;
14208 }
14209 case scUnknown:
14210 // All non-instruction values are loop invariant. All instructions are loop
14211 // invariant if they are not contained in the specified loop.
14212 // Instructions are never considered invariant in the function body
14213 // (null loop) because they are defined within the "loop".
14214 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14215 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14216 return LoopInvariant;
14217 case scCouldNotCompute:
14218 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14219 }
14220 llvm_unreachable("Unknown SCEV kind!");
14221}
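// Illustrative example: an addrec {0,+,1}<%L> is LoopComputable with respect
// to %L itself and (operands permitting) LoopInvariant with respect to loops
// nested inside %L, while a SCEVUnknown defined by an instruction inside a
// loop is LoopVariant for that loop.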
14222
14223bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14224  return getLoopDisposition(S, L) == LoopInvariant;
14225}
14226
14228 return getLoopDisposition(S, L) == LoopComputable;
14229}
14230
14231ScalarEvolution::BlockDisposition
14232ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14233  auto &Values = BlockDispositions[S];
14234 for (auto &V : Values) {
14235 if (V.getPointer() == BB)
14236 return V.getInt();
14237 }
14238 Values.emplace_back(BB, DoesNotDominateBlock);
14239 BlockDisposition D = computeBlockDisposition(S, BB);
14240 auto &Values2 = BlockDispositions[S];
14241 for (auto &V : llvm::reverse(Values2)) {
14242 if (V.getPointer() == BB) {
14243 V.setInt(D);
14244 break;
14245 }
14246 }
14247 return D;
14248}
14249
14250ScalarEvolution::BlockDisposition
14251ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14252 switch (S->getSCEVType()) {
14253 case scConstant:
14254 case scVScale:
14255    return ProperlyDominatesBlock;
14256  case scAddRecExpr: {
14257 // This uses a "dominates" query instead of "properly dominates" query
14258 // to test for proper dominance too, because the instruction which
14259 // produces the addrec's value is a PHI, and a PHI effectively properly
14260 // dominates its entire containing block.
14261 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14262 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14263 return DoesNotDominateBlock;
14264
14265 // Fall through into SCEVNAryExpr handling.
14266 [[fallthrough]];
14267 }
14268 case scTruncate:
14269 case scZeroExtend:
14270 case scSignExtend:
14271 case scPtrToInt:
14272 case scAddExpr:
14273 case scMulExpr:
14274 case scUDivExpr:
14275 case scUMaxExpr:
14276 case scSMaxExpr:
14277 case scUMinExpr:
14278 case scSMinExpr:
14279 case scSequentialUMinExpr: {
14280 bool Proper = true;
14281 for (const SCEV *NAryOp : S->operands()) {
14282      BlockDisposition D = getBlockDisposition(NAryOp, BB);
14283      if (D == DoesNotDominateBlock)
14284 return DoesNotDominateBlock;
14285 if (D == DominatesBlock)
14286 Proper = false;
14287 }
14288 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14289 }
14290 case scUnknown:
14291 if (Instruction *I =
14292 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14293 if (I->getParent() == BB)
14294 return DominatesBlock;
14295 if (DT.properlyDominates(I->getParent(), BB))
14296        return ProperlyDominatesBlock;
14297      return DoesNotDominateBlock;
14298 }
14299    return ProperlyDominatesBlock;
14300  case scCouldNotCompute:
14301 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14302 }
14303 llvm_unreachable("Unknown SCEV kind!");
14304}
14305
14306bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14307 return getBlockDisposition(S, BB) >= DominatesBlock;
14308}
14309
14310bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14311  return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14312}
14313
14314bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14315 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14316}
14317
14318void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14319 bool Predicated) {
14320 auto &BECounts =
14321 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14322 auto It = BECounts.find(L);
14323 if (It != BECounts.end()) {
14324 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14325 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14326 if (!isa<SCEVConstant>(S)) {
14327 auto UserIt = BECountUsers.find(S);
14328 assert(UserIt != BECountUsers.end());
14329 UserIt->second.erase({L, Predicated});
14330 }
14331 }
14332 }
14333 BECounts.erase(It);
14334 }
14335}
14336
14337void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14338 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14339 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14340
14341 while (!Worklist.empty()) {
14342 const SCEV *Curr = Worklist.pop_back_val();
14343 auto Users = SCEVUsers.find(Curr);
14344 if (Users != SCEVUsers.end())
14345 for (const auto *User : Users->second)
14346 if (ToForget.insert(User).second)
14347 Worklist.push_back(User);
14348 }
14349
14350 for (const auto *S : ToForget)
14351 forgetMemoizedResultsImpl(S);
14352
14353 for (auto I = PredicatedSCEVRewrites.begin();
14354 I != PredicatedSCEVRewrites.end();) {
14355 std::pair<const SCEV *, const Loop *> Entry = I->first;
14356 if (ToForget.count(Entry.first))
14357 PredicatedSCEVRewrites.erase(I++);
14358 else
14359 ++I;
14360 }
14361}
14362
14363void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14364 LoopDispositions.erase(S);
14365 BlockDispositions.erase(S);
14366 UnsignedRanges.erase(S);
14367 SignedRanges.erase(S);
14368 HasRecMap.erase(S);
14369 ConstantMultipleCache.erase(S);
14370
14371 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14372 UnsignedWrapViaInductionTried.erase(AR);
14373 SignedWrapViaInductionTried.erase(AR);
14374 }
14375
14376 auto ExprIt = ExprValueMap.find(S);
14377 if (ExprIt != ExprValueMap.end()) {
14378 for (Value *V : ExprIt->second) {
14379 auto ValueIt = ValueExprMap.find_as(V);
14380 if (ValueIt != ValueExprMap.end())
14381 ValueExprMap.erase(ValueIt);
14382 }
14383 ExprValueMap.erase(ExprIt);
14384 }
14385
14386 auto ScopeIt = ValuesAtScopes.find(S);
14387 if (ScopeIt != ValuesAtScopes.end()) {
14388 for (const auto &Pair : ScopeIt->second)
14389 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14390 llvm::erase(ValuesAtScopesUsers[Pair.second],
14391 std::make_pair(Pair.first, S));
14392 ValuesAtScopes.erase(ScopeIt);
14393 }
14394
14395 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14396 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14397 for (const auto &Pair : ScopeUserIt->second)
14398 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14399 ValuesAtScopesUsers.erase(ScopeUserIt);
14400 }
14401
14402 auto BEUsersIt = BECountUsers.find(S);
14403 if (BEUsersIt != BECountUsers.end()) {
14404 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14405 auto Copy = BEUsersIt->second;
14406 for (const auto &Pair : Copy)
14407 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14408 BECountUsers.erase(BEUsersIt);
14409 }
14410
14411 auto FoldUser = FoldCacheUser.find(S);
14412 if (FoldUser != FoldCacheUser.end())
14413 for (auto &KV : FoldUser->second)
14414 FoldCache.erase(KV);
14415 FoldCacheUser.erase(S);
14416}
14417
14418void
14419ScalarEvolution::getUsedLoops(const SCEV *S,
14420 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14421 struct FindUsedLoops {
14422 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14423 : LoopsUsed(LoopsUsed) {}
14424 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14425 bool follow(const SCEV *S) {
14426 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14427 LoopsUsed.insert(AR->getLoop());
14428 return true;
14429 }
14430
14431 bool isDone() const { return false; }
14432 };
14433
14434 FindUsedLoops F(LoopsUsed);
14435 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14436}
14437
14438void ScalarEvolution::getReachableBlocks(
14439    SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14440  SmallVector<BasicBlock *> Worklist;
14441  Worklist.push_back(&F.getEntryBlock());
14442 while (!Worklist.empty()) {
14443 BasicBlock *BB = Worklist.pop_back_val();
14444 if (!Reachable.insert(BB).second)
14445 continue;
14446
14447 Value *Cond;
14448 BasicBlock *TrueBB, *FalseBB;
14449 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14450 m_BasicBlock(FalseBB)))) {
14451 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14452 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14453 continue;
14454 }
14455
14456 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14457 const SCEV *L = getSCEV(Cmp->getOperand(0));
14458 const SCEV *R = getSCEV(Cmp->getOperand(1));
14459 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14460 Worklist.push_back(TrueBB);
14461 continue;
14462 }
14463 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14464 R)) {
14465 Worklist.push_back(FalseBB);
14466 continue;
14467 }
14468 }
14469 }
14470
14471 append_range(Worklist, successors(BB));
14472 }
14473}
14474
14475void ScalarEvolution::verify() const {
14476  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14477 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14478
14479 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14480
14481  // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14482 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14483 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14484
14485 const SCEV *visitConstant(const SCEVConstant *Constant) {
14486 return SE.getConstant(Constant->getAPInt());
14487 }
14488
14489 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14490 return SE.getUnknown(Expr->getValue());
14491 }
14492
14493 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14494 return SE.getCouldNotCompute();
14495 }
14496 };
14497
14498 SCEVMapper SCM(SE2);
14499 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14500 SE2.getReachableBlocks(ReachableBlocks, F);
14501
14502 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14503 if (containsUndefs(Old) || containsUndefs(New)) {
14504 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14505 // not propagate undef aggressively). This means we can (and do) fail
14506 // verification in cases where a transform makes a value go from "undef"
14507 // to "undef+1" (say). The transform is fine, since in both cases the
14508 // result is "undef", but SCEV thinks the value increased by 1.
14509 return nullptr;
14510 }
14511
14512 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14513 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14514 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14515 return nullptr;
14516
14517 return Delta;
14518 };
14519
14520 while (!LoopStack.empty()) {
14521 auto *L = LoopStack.pop_back_val();
14522 llvm::append_range(LoopStack, *L);
14523
14524 // Only verify BECounts in reachable loops. For an unreachable loop,
14525 // any BECount is legal.
14526 if (!ReachableBlocks.contains(L->getHeader()))
14527 continue;
14528
14529 // Only verify cached BECounts. Computing new BECounts may change the
14530 // results of subsequent SCEV uses.
14531 auto It = BackedgeTakenCounts.find(L);
14532 if (It == BackedgeTakenCounts.end())
14533 continue;
14534
14535 auto *CurBECount =
14536 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14537 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14538
14539 if (CurBECount == SE2.getCouldNotCompute() ||
14540 NewBECount == SE2.getCouldNotCompute()) {
14541      // NB! This situation is legal, but is very suspicious -- whatever pass
14542      // changed the loop to make a trip count go from could-not-compute to
14543      // computable, or vice versa, *should have* invalidated SCEV. However, we
14544      // choose not to assert here (for now) since we don't want false
14545      // positives.
14546 continue;
14547 }
14548
14549 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14550 SE.getTypeSizeInBits(NewBECount->getType()))
14551 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14552 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14553 SE.getTypeSizeInBits(NewBECount->getType()))
14554 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14555
14556 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14557 if (Delta && !Delta->isZero()) {
14558 dbgs() << "Trip Count for " << *L << " Changed!\n";
14559 dbgs() << "Old: " << *CurBECount << "\n";
14560 dbgs() << "New: " << *NewBECount << "\n";
14561 dbgs() << "Delta: " << *Delta << "\n";
14562 std::abort();
14563 }
14564 }
14565
14566 // Collect all valid loops currently in LoopInfo.
14567 SmallPtrSet<Loop *, 32> ValidLoops;
14568 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14569 while (!Worklist.empty()) {
14570 Loop *L = Worklist.pop_back_val();
14571 if (ValidLoops.insert(L).second)
14572 Worklist.append(L->begin(), L->end());
14573 }
14574 for (const auto &KV : ValueExprMap) {
14575#ifndef NDEBUG
14576 // Check for SCEV expressions referencing invalid/deleted loops.
14577 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14578 assert(ValidLoops.contains(AR->getLoop()) &&
14579 "AddRec references invalid loop");
14580 }
14581#endif
14582
14583 // Check that the value is also part of the reverse map.
14584 auto It = ExprValueMap.find(KV.second);
14585 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14586 dbgs() << "Value " << *KV.first
14587 << " is in ValueExprMap but not in ExprValueMap\n";
14588 std::abort();
14589 }
14590
14591 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14592 if (!ReachableBlocks.contains(I->getParent()))
14593 continue;
14594 const SCEV *OldSCEV = SCM.visit(KV.second);
14595 const SCEV *NewSCEV = SE2.getSCEV(I);
14596 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14597 if (Delta && !Delta->isZero()) {
14598 dbgs() << "SCEV for value " << *I << " changed!\n"
14599 << "Old: " << *OldSCEV << "\n"
14600 << "New: " << *NewSCEV << "\n"
14601 << "Delta: " << *Delta << "\n";
14602 std::abort();
14603 }
14604 }
14605 }
14606
14607 for (const auto &KV : ExprValueMap) {
14608 for (Value *V : KV.second) {
14609 const SCEV *S = ValueExprMap.lookup(V);
14610 if (!S) {
14611 dbgs() << "Value " << *V
14612 << " is in ExprValueMap but not in ValueExprMap\n";
14613 std::abort();
14614 }
14615 if (S != KV.first) {
14616 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14617 << *KV.first << "\n";
14618 std::abort();
14619 }
14620 }
14621 }
14622
14623 // Verify integrity of SCEV users.
14624 for (const auto &S : UniqueSCEVs) {
14625 for (const auto *Op : S.operands()) {
14626 // We do not store dependencies of constants.
14627 if (isa<SCEVConstant>(Op))
14628 continue;
14629 auto It = SCEVUsers.find(Op);
14630 if (It != SCEVUsers.end() && It->second.count(&S))
14631 continue;
14632 dbgs() << "Use of operand " << *Op << " by user " << S
14633 << " is not being tracked!\n";
14634 std::abort();
14635 }
14636 }
14637
14638 // Verify integrity of ValuesAtScopes users.
14639 for (const auto &ValueAndVec : ValuesAtScopes) {
14640 const SCEV *Value = ValueAndVec.first;
14641 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14642 const Loop *L = LoopAndValueAtScope.first;
14643 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14644 if (!isa<SCEVConstant>(ValueAtScope)) {
14645 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14646 if (It != ValuesAtScopesUsers.end() &&
14647 is_contained(It->second, std::make_pair(L, Value)))
14648 continue;
14649 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14650 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14651 std::abort();
14652 }
14653 }
14654 }
14655
14656 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14657 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14658 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14659 const Loop *L = LoopAndValue.first;
14660 const SCEV *Value = LoopAndValue.second;
14662 auto It = ValuesAtScopes.find(Value);
14663 if (It != ValuesAtScopes.end() &&
14664 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14665 continue;
14666 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14667 << *ValueAtScope << " missing in ValuesAtScopes\n";
14668 std::abort();
14669 }
14670 }
14671
14672 // Verify integrity of BECountUsers.
14673 auto VerifyBECountUsers = [&](bool Predicated) {
14674 auto &BECounts =
14675 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14676 for (const auto &LoopAndBEInfo : BECounts) {
14677 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14678 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14679 if (!isa<SCEVConstant>(S)) {
14680 auto UserIt = BECountUsers.find(S);
14681 if (UserIt != BECountUsers.end() &&
14682 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14683 continue;
14684 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14685 << " missing from BECountUsers\n";
14686 std::abort();
14687 }
14688 }
14689 }
14690 }
14691 };
14692 VerifyBECountUsers(/* Predicated */ false);
14693 VerifyBECountUsers(/* Predicated */ true);
14694
14695  // Verify integrity of the loop disposition cache.
14696 for (auto &[S, Values] : LoopDispositions) {
14697 for (auto [Loop, CachedDisposition] : Values) {
14698 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14699 if (CachedDisposition != RecomputedDisposition) {
14700 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14701 << " is incorrect: cached " << CachedDisposition << ", actual "
14702 << RecomputedDisposition << "\n";
14703 std::abort();
14704 }
14705 }
14706 }
14707
14708 // Verify integrity of the block disposition cache.
14709 for (auto &[S, Values] : BlockDispositions) {
14710 for (auto [BB, CachedDisposition] : Values) {
14711 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14712 if (CachedDisposition != RecomputedDisposition) {
14713 dbgs() << "Cached disposition of " << *S << " for block %"
14714 << BB->getName() << " is incorrect: cached " << CachedDisposition
14715 << ", actual " << RecomputedDisposition << "\n";
14716 std::abort();
14717 }
14718 }
14719 }
14720
14721 // Verify FoldCache/FoldCacheUser caches.
14722 for (auto [FoldID, Expr] : FoldCache) {
14723 auto I = FoldCacheUser.find(Expr);
14724 if (I == FoldCacheUser.end()) {
14725 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14726 << "!\n";
14727 std::abort();
14728 }
14729 if (!is_contained(I->second, FoldID)) {
14730 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14731 std::abort();
14732 }
14733 }
14734 for (auto [Expr, IDs] : FoldCacheUser) {
14735 for (auto &FoldID : IDs) {
14736 const SCEV *S = FoldCache.lookup(FoldID);
14737 if (!S) {
14738 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14739 << "!\n";
14740 std::abort();
14741 }
14742 if (S != Expr) {
14743 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14744 << " != " << *Expr << "!\n";
14745 std::abort();
14746 }
14747 }
14748 }
14749
14750 // Verify that ConstantMultipleCache computations are correct. We check that
14751 // cached multiples and recomputed multiples are multiples of each other to
14752 // verify correctness. It is possible that a recomputed multiple is different
14753 // from the cached multiple due to strengthened no wrap flags or changes in
14754 // KnownBits computations.
14755 for (auto [S, Multiple] : ConstantMultipleCache) {
14756 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14757 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14758 Multiple.urem(RecomputedMultiple) != 0 &&
14759 RecomputedMultiple.urem(Multiple) != 0)) {
14760 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14761 << *S << " : Computed " << RecomputedMultiple
14762 << " but cache contains " << Multiple << "!\n";
14763 std::abort();
14764 }
14765 }
14766}
14767
14768bool ScalarEvolution::invalidate(
14769    Function &F, const PreservedAnalyses &PA,
14770 FunctionAnalysisManager::Invalidator &Inv) {
14771 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14772 // of its dependencies is invalidated.
14773 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14774 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14775 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14776 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14777 Inv.invalidate<LoopAnalysis>(F, PA);
14778}
14779
14780AnalysisKey ScalarEvolutionAnalysis::Key;
14781
14782ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14783                                             FunctionAnalysisManager &AM) {
14784  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14785 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14786 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14787 auto &LI = AM.getResult<LoopAnalysis>(F);
14788 return ScalarEvolution(F, TLI, AC, DT, LI);
14789}
14790
14796
14799 // For compatibility with opt's -analyze feature under legacy pass manager
14800 // which was not ported to NPM. This keeps tests using
14801 // update_analyze_test_checks.py working.
14802 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14803 << F.getName() << "':\n";
14805 return PreservedAnalyses::all();
14806}
14807
14809 "Scalar Evolution Analysis", false, true)
14815 "Scalar Evolution Analysis", false, true)
14816
14818
14820
14821bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14822  SE.reset(new ScalarEvolution(
14823      F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14824      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14825      getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14826      getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14827 return false;
14828}
14829
14831
14833 SE->print(OS);
14834}
14835
14837 if (!VerifySCEV)
14838 return;
14839
14840 SE->verify();
14841}
14842
14850
14851const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14852                                                        const SCEV *RHS) {
14853 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14854}
14855
14856const SCEVPredicate *
14857ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14858                                     const SCEV *LHS, const SCEV *RHS) {
14859  FoldingSetNodeID ID;
14860  assert(LHS->getType() == RHS->getType() &&
14861 "Type mismatch between LHS and RHS");
14862 // Unique this node based on the arguments
14863 ID.AddInteger(SCEVPredicate::P_Compare);
14864 ID.AddInteger(Pred);
14865 ID.AddPointer(LHS);
14866 ID.AddPointer(RHS);
14867 void *IP = nullptr;
14868 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14869 return S;
14870 SCEVComparePredicate *Eq = new (SCEVAllocator)
14871 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14872 UniquePreds.InsertNode(Eq, IP);
14873 return Eq;
14874}
14875
14876const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14877    const SCEVAddRecExpr *AR,
14878    SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14879  FoldingSetNodeID ID;
14880  // Unique this node based on the arguments
14881 ID.AddInteger(SCEVPredicate::P_Wrap);
14882 ID.AddPointer(AR);
14883 ID.AddInteger(AddedFlags);
14884 void *IP = nullptr;
14885 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14886 return S;
14887 auto *OF = new (SCEVAllocator)
14888 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14889 UniquePreds.InsertNode(OF, IP);
14890 return OF;
14891}
14892
14893namespace {
14894
14895class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14896public:
14897
14898 /// Rewrites \p S in the context of a loop L and the SCEV predication
14899 /// infrastructure.
14900 ///
14901 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14902 /// equivalences present in \p Pred.
14903 ///
14904 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14905 /// \p NewPreds such that the result will be an AddRecExpr.
14906 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14907                             SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14908                             const SCEVPredicate *Pred) {
14909 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14910 return Rewriter.visit(S);
14911 }
14912
14913 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14914 if (Pred) {
14915 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14916 for (const auto *Pred : U->getPredicates())
14917 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14918 if (IPred->getLHS() == Expr &&
14919 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14920 return IPred->getRHS();
14921 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14922 if (IPred->getLHS() == Expr &&
14923 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14924 return IPred->getRHS();
14925 }
14926 }
14927 return convertToAddRecWithPreds(Expr);
14928 }
14929
14930 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14931 const SCEV *Operand = visit(Expr->getOperand());
14932 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14933 if (AR && AR->getLoop() == L && AR->isAffine()) {
14934 // This couldn't be folded because the operand didn't have the nuw
14935 // flag. Add the nusw flag as an assumption that we could make.
14936 const SCEV *Step = AR->getStepRecurrence(SE);
14937 Type *Ty = Expr->getType();
14938 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14939 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14940 SE.getSignExtendExpr(Step, Ty), L,
14941 AR->getNoWrapFlags());
14942 }
14943 return SE.getZeroExtendExpr(Operand, Expr->getType());
14944 }
14945
14946 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14947 const SCEV *Operand = visit(Expr->getOperand());
14948 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14949 if (AR && AR->getLoop() == L && AR->isAffine()) {
14950 // This couldn't be folded because the operand didn't have the nsw
14951 // flag. Add the nssw flag as an assumption that we could make.
14952 const SCEV *Step = AR->getStepRecurrence(SE);
14953 Type *Ty = Expr->getType();
14954 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14955 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14956 SE.getSignExtendExpr(Step, Ty), L,
14957 AR->getNoWrapFlags());
14958 }
14959 return SE.getSignExtendExpr(Operand, Expr->getType());
14960 }
14961
14962private:
14963 explicit SCEVPredicateRewriter(
14964 const Loop *L, ScalarEvolution &SE,
14965 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14966 const SCEVPredicate *Pred)
14967 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14968
14969 bool addOverflowAssumption(const SCEVPredicate *P) {
14970 if (!NewPreds) {
14971 // Check if we've already made this assumption.
14972 return Pred && Pred->implies(P, SE);
14973 }
14974 NewPreds->push_back(P);
14975 return true;
14976 }
14977
14978 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14980 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14981 return addOverflowAssumption(A);
14982 }
14983
14984 // If \p Expr represents a PHINode, we try to see if it can be represented
14985 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14986 // to add this predicate as a runtime overflow check, we return the AddRec.
14987 // If \p Expr does not meet these conditions (is not a PHI node, or we
14988 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14989 // return \p Expr.
14990 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14991 if (!isa<PHINode>(Expr->getValue()))
14992 return Expr;
14993 std::optional<
14994 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14995 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14996 if (!PredicatedRewrite)
14997 return Expr;
14998 for (const auto *P : PredicatedRewrite->second){
14999 // Wrap predicates from outer loops are not supported.
15000 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15001 if (L != WP->getExpr()->getLoop())
15002 return Expr;
15003 }
15004 if (!addOverflowAssumption(P))
15005 return Expr;
15006 }
15007 return PredicatedRewrite->first;
15008 }
15009
15010 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15011 const SCEVPredicate *Pred;
15012 const Loop *L;
15013};
15014
15015} // end anonymous namespace
15016
15017const SCEV *
15018ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
15019                                       const SCEVPredicate &Preds) {
15020 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15021}
15022
15023const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
15024    const SCEV *S, const Loop *L,
15025    SmallVectorImpl<const SCEVPredicate *> &Preds) {
15026  SmallVector<const SCEVPredicate *> TransformPreds;
15027  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15028 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15029
15030 if (!AddRec)
15031 return nullptr;
15032
15033 // Check if any of the transformed predicates is known to be false. In that
15034 // case, it doesn't make sense to convert to a predicated AddRec, as the
15035 // versioned loop will never execute.
15036 for (const SCEVPredicate *Pred : TransformPreds) {
15037 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15038 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15039 continue;
15040
15041 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15042 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15043 if (isa<SCEVCouldNotCompute>(ExitCount))
15044 continue;
15045
15046 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15047 if (!Step->isOne())
15048 continue;
15049
15050 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15051 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15052 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15053 return nullptr;
15054 }
15055
15056 // Since the transformation was successful, we can now transfer the SCEV
15057 // predicates.
15058 Preds.append(TransformPreds.begin(), TransformPreds.end());
15059
15060 return AddRec;
15061}
15062
15063/// SCEV predicates
15067
15068SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15069                                           const ICmpInst::Predicate Pred,
15070 const SCEV *LHS, const SCEV *RHS)
15071 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15072 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15073 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15074}
15075
15076bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15077                                   ScalarEvolution &SE) const {
15078 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15079
15080 if (!Op)
15081 return false;
15082
15083 if (Pred != ICmpInst::ICMP_EQ)
15084 return false;
15085
15086 return Op->LHS == LHS && Op->RHS == RHS;
15087}
15088
15089bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15090
15092 if (Pred == ICmpInst::ICMP_EQ)
15093 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15094 else
15095 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15096 << *RHS << "\n";
15097
15098}
15099
15100SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15101                                     const SCEVAddRecExpr *AR,
15102 IncrementWrapFlags Flags)
15103 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15104
15105const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15106
15107bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15108                                ScalarEvolution &SE) const {
15109 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15110 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15111 return false;
15112
15113 if (Op->AR == AR)
15114 return true;
15115
15116 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15117      Flags != SCEVWrapPredicate::IncrementNUSW)
15118    return false;
15119
15120 const SCEV *Start = AR->getStart();
15121 const SCEV *OpStart = Op->AR->getStart();
15122 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15123 return false;
15124
15125 // Reject pointers to different address spaces.
15126 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15127 return false;
15128
15129 const SCEV *Step = AR->getStepRecurrence(SE);
15130 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15131 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15132 return false;
15133
15134  // If both steps are positive, this predicate implies N if N's start and
15135  // step are ULE/SLE (for NUSW/NSSW, respectively) compared to this one's.
15136 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15137 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15138 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15139
15140 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15141 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15142 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15143 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15144 : SE.getNoopOrSignExtend(Start, WiderTy);
15145  ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15146  return SE.isKnownPredicate(Pred, OpStep, Step) &&
15147 SE.isKnownPredicate(Pred, OpStart, Start);
15148}
15149
15151 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15152 IncrementWrapFlags IFlags = Flags;
15153
15154 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15155 IFlags = clearFlags(IFlags, IncrementNSSW);
15156
15157 return IFlags == IncrementAnyWrap;
15158}
15159
15160void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15161 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15163 OS << "<nusw>";
15165 OS << "<nssw>";
15166 OS << "\n";
15167}
15168
15169SCEVWrapPredicate::IncrementWrapFlags
15170SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15171                                   ScalarEvolution &SE) {
15172 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15173 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15174
15175 // We can safely transfer the NSW flag as NSSW.
15176 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15177 ImpliedFlags = IncrementNSSW;
15178
15179 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15180 // If the increment is positive, the SCEV NUW flag will also imply the
15181 // WrapPredicate NUSW flag.
15182 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15183 if (Step->getValue()->getValue().isNonNegative())
15184 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15185 }
15186
15187 return ImpliedFlags;
15188}
15189
15190/// Union predicates don't get cached so create a dummy set ID for it.
15191SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15192                                       ScalarEvolution &SE)
15193 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15194 for (const auto *P : Preds)
15195 add(P, SE);
15196}
15197
15199 return all_of(Preds,
15200 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15201}
15202
15203bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15204                                 ScalarEvolution &SE) const {
15205 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15206 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15207 return this->implies(I, SE);
15208 });
15209
15210 return any_of(Preds,
15211 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15212}
15213
15215 for (const auto *Pred : Preds)
15216 Pred->print(OS, Depth);
15217}
15218
15219void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15220 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15221 for (const auto *Pred : Set->Preds)
15222 add(Pred, SE);
15223 return;
15224 }
15225
15226 // Implication checks are quadratic in the number of predicates. Stop doing
15227 // them if there are many predicates, as they should be too expensive to use
15228 // anyway at that point.
15229 bool CheckImplies = Preds.size() < 16;
15230
15231 // Only add predicate if it is not already implied by this union predicate.
15232 if (CheckImplies && implies(N, SE))
15233 return;
15234
15235 // Build a new vector containing the current predicates, except the ones that
15236 // are implied by the new predicate N.
15237  SmallVector<const SCEVPredicate *> PrunedPreds;
15238  for (auto *P : Preds) {
15239 if (CheckImplies && N->implies(P, SE))
15240 continue;
15241 PrunedPreds.push_back(P);
15242 }
15243 Preds = std::move(PrunedPreds);
15244 Preds.push_back(N);
15245}
15246
15247PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15248                                                     Loop &L)
15249    : SE(SE), L(L) {
15250  SmallVector<const SCEVPredicate *, 4> Empty;
15251  Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15252}
15253
15254void ScalarEvolution::registerUser(const SCEV *User,
15255                                   ArrayRef<const SCEV *> Ops) {
15256  for (const auto *Op : Ops)
15257 // We do not expect that forgetting cached data for SCEVConstants will ever
15258 // open any prospects for sharpening or introduce any correctness issues,
15259 // so we don't bother storing their dependencies.
15260 if (!isa<SCEVConstant>(Op))
15261 SCEVUsers[Op].insert(User);
15262}
15263
15265 const SCEV *Expr = SE.getSCEV(V);
15266 RewriteEntry &Entry = RewriteMap[Expr];
15267
15268 // If we already have an entry and the version matches, return it.
15269 if (Entry.second && Generation == Entry.first)
15270 return Entry.second;
15271
15272 // We found an entry but it's stale. Rewrite the stale entry
15273 // according to the current predicate.
15274 if (Entry.second)
15275 Expr = Entry.second;
15276
15277 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15278 Entry = {Generation, NewSCEV};
15279
15280 return NewSCEV;
15281}
15282
15283const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15284  if (!BackedgeCount) {
15285    SmallVector<const SCEVPredicate *, 4> Preds;
15286    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15287 for (const auto *P : Preds)
15288 addPredicate(*P);
15289 }
15290 return BackedgeCount;
15291}
15292
15293const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15294  if (!SymbolicMaxBackedgeCount) {
15295    SmallVector<const SCEVPredicate *, 4> Preds;
15296    SymbolicMaxBackedgeCount =
15297 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15298 for (const auto *P : Preds)
15299 addPredicate(*P);
15300 }
15301 return SymbolicMaxBackedgeCount;
15302}
15303
15305 if (!SmallConstantMaxTripCount) {
15307 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15308 for (const auto *P : Preds)
15309 addPredicate(*P);
15310 }
15311 return *SmallConstantMaxTripCount;
15312}
15313
15315 if (Preds->implies(&Pred, SE))
15316 return;
15317
15318 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15319 NewPreds.push_back(&Pred);
15320 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15321 updateGeneration();
15322}
15323
15325 return *Preds;
15326}
15327
15328void PredicatedScalarEvolution::updateGeneration() {
15329 // If the generation number wrapped recompute everything.
15330 if (++Generation == 0) {
15331 for (auto &II : RewriteMap) {
15332 const SCEV *Rewritten = II.second.second;
15333 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15334 }
15335 }
15336}
15337
15338void PredicatedScalarEvolution::setNoOverflow(
15339    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15340  const SCEV *Expr = getSCEV(V);
15341 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15342
15343 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15344
15345 // Clear the statically implied flags.
15346 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15347 addPredicate(*SE.getWrapPredicate(AR, Flags));
15348
15349 auto II = FlagsMap.insert({V, Flags});
15350 if (!II.second)
15351 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15352}
15353
15356 const SCEV *Expr = getSCEV(V);
15357 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15358
15359  Flags = SCEVWrapPredicate::clearFlags(
15360      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15361
15362 auto II = FlagsMap.find(V);
15363
15364 if (II != FlagsMap.end())
15365 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15366
15367  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15368}
15369
15370const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15371  const SCEV *Expr = this->getSCEV(V);
15372  SmallVector<const SCEVPredicate *, 4> NewPreds;
15373  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15374
15375 if (!New)
15376 return nullptr;
15377
15378 for (const auto *P : NewPreds)
15379 addPredicate(*P);
15380
15381 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15382 return New;
15383}
15384
15387 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15388 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15389 SE)),
15390 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15391 for (auto I : Init.FlagsMap)
15392 FlagsMap.insert(I);
15393}
15394
15396 // For each block.
15397 for (auto *BB : L.getBlocks())
15398 for (auto &I : *BB) {
15399 if (!SE.isSCEVable(I.getType()))
15400 continue;
15401
15402 auto *Expr = SE.getSCEV(&I);
15403 auto II = RewriteMap.find(Expr);
15404
15405 if (II == RewriteMap.end())
15406 continue;
15407
15408 // Don't print things that are not interesting.
15409 if (II->second.second == Expr)
15410 continue;
15411
15412 OS.indent(Depth) << "[PSE]" << I << ":\n";
15413 OS.indent(Depth + 2) << *Expr << "\n";
15414 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15415 }
15416}
15417
15418// Match the mathematical pattern A - (A / B) * B, where A and B can be
15419// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15420// for URem with constant power-of-2 second operands.
15421// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15422// 4, A / B becomes X / 8).
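// As a concrete illustration of the zext/trunc form handled first below:
// zext i8 (trunc i32 %a to i8) to i32 is recognized as %a urem 256, i.e.
// LHS = %a and RHS = 256 (1 << 8).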
15423bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15424 const SCEV *&RHS) {
15425 if (Expr->getType()->isPointerTy())
15426 return false;
15427
15428 // Try to match 'zext (trunc A to iB) to iY', which is used
15429 // for URem with constant power-of-2 second operands. Make sure the size of
15430 // the operand A matches the size of the whole expressions.
15431 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15432 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15433 LHS = Trunc->getOperand();
15434 // Bail out if the type of the LHS is larger than the type of the
15435 // expression for now.
15436 if (getTypeSizeInBits(LHS->getType()) >
15437 getTypeSizeInBits(Expr->getType()))
15438 return false;
15439 if (LHS->getType() != Expr->getType())
15440 LHS = getZeroExtendExpr(LHS, Expr->getType());
15441 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15442 << getTypeSizeInBits(Trunc->getType()));
15443 return true;
15444 }
15445 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15446 if (Add == nullptr || Add->getNumOperands() != 2)
15447 return false;
15448
15449 const SCEV *A = Add->getOperand(1);
15450 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15451
15452 if (Mul == nullptr)
15453 return false;
15454
15455 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15456 // (SomeExpr + (-(SomeExpr / B) * B)).
15457 if (Expr == getURemExpr(A, B)) {
15458 LHS = A;
15459 RHS = B;
15460 return true;
15461 }
15462 return false;
15463 };
15464
15465 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15466 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15467 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15468 MatchURemWithDivisor(Mul->getOperand(2));
15469
15470 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15471 if (Mul->getNumOperands() == 2)
15472 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15473 MatchURemWithDivisor(Mul->getOperand(0)) ||
15474 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15475 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15476 return false;
15477}
15478
15479ScalarEvolution::LoopGuards
15480ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15481  BasicBlock *Header = L->getHeader();
15482 BasicBlock *Pred = L->getLoopPredecessor();
15483 LoopGuards Guards(SE);
15484 if (!Pred)
15485 return Guards;
15486  SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15487  collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15488 return Guards;
15489}
15490
15491void ScalarEvolution::LoopGuards::collectFromPHI(
15495 unsigned Depth) {
15496 if (!SE.isSCEVable(Phi.getType()))
15497 return;
15498
15499 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15500 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15501 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15502 if (!VisitedBlocks.insert(InBlock).second)
15503 return {nullptr, scCouldNotCompute};
15504
15505 // Avoid analyzing unreachable blocks so that we don't get trapped
15506 // traversing cycles with ill-formed dominance or infinite cycles
15507 if (!SE.DT.isReachableFromEntry(InBlock))
15508 return {nullptr, scCouldNotCompute};
15509
15510 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15511 if (Inserted)
15512 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15513 Depth + 1);
15514 auto &RewriteMap = G->second.RewriteMap;
15515 if (RewriteMap.empty())
15516 return {nullptr, scCouldNotCompute};
15517 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15518 if (S == RewriteMap.end())
15519 return {nullptr, scCouldNotCompute};
15520 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15521 if (!SM)
15522 return {nullptr, scCouldNotCompute};
15523 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15524 return {C0, SM->getSCEVType()};
15525 return {nullptr, scCouldNotCompute};
15526 };
15527 auto MergeMinMaxConst = [](MinMaxPattern P1,
15528 MinMaxPattern P2) -> MinMaxPattern {
15529 auto [C1, T1] = P1;
15530 auto [C2, T2] = P2;
15531 if (!C1 || !C2 || T1 != T2)
15532 return {nullptr, scCouldNotCompute};
15533 switch (T1) {
15534 case scUMaxExpr:
15535 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15536 case scSMaxExpr:
15537 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15538 case scUMinExpr:
15539 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15540 case scSMinExpr:
15541 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15542 default:
15543 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15544 }
15545 };
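  // Note: merging keeps the weaker constant bound across incoming edges (the
  // smaller constant for a max, the larger one for a min), so the combined
  // fact recorded below holds regardless of which predecessor the PHI value
  // came from.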
15546 auto P = GetMinMaxConst(0);
15547 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15548 if (!P.first)
15549 break;
15550 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15551 }
15552 if (P.first) {
15553 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15555 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15556 Guards.RewriteMap.insert({LHS, RHS});
15557 }
15558}
15559
15560void ScalarEvolution::LoopGuards::collectFromBlock(
15561 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15562 const BasicBlock *Block, const BasicBlock *Pred,
15563 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15564
15566
15567 SmallVector<const SCEV *> ExprsToRewrite;
15568 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15569 const SCEV *RHS,
15570 DenseMap<const SCEV *, const SCEV *>
15571 &RewriteMap) {
15572 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15573 // replacement SCEV which isn't directly implied by the structure of that
15574 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15575 // legal. See the scoping rules for flags in the header to understand why.
15576
15577 // If LHS is a constant, apply information to the other expression.
15578 if (isa<SCEVConstant>(LHS)) {
15579 std::swap(LHS, RHS);
15580      Predicate = ICmpInst::getSwappedPredicate(Predicate);
15581    }
15582
15583 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15584 // create this form when combining two checks of the form (X u< C2 + C1) and
15585 // (X >=u C1).
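    // Worked instance (illustrative): for a guard built from %x u>= 3 and
    // %x u< 19, InstCombine may produce (-3 + %x) u< 16. Here C1 = -3 and
    // C2 = 16, the exact region is [3, 19), and %x is rewritten below to
    // umax(3, umin(%x, 18)).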
15586 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15587 &ExprsToRewrite]() {
15588 const SCEVConstant *C1;
15589 const SCEVUnknown *LHSUnknown;
15590 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15591 if (!match(LHS,
15592 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15593 !C2)
15594 return false;
15595
15596 auto ExactRegion =
15597 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15598 .sub(C1->getAPInt());
15599
15600 // Bail out, unless we have a non-wrapping, monotonic range.
15601 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15602 return false;
15603 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15604 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15605 I->second = SE.getUMaxExpr(
15606 SE.getConstant(ExactRegion.getUnsignedMin()),
15607 SE.getUMinExpr(RewrittenLHS,
15608 SE.getConstant(ExactRegion.getUnsignedMax())));
15609 ExprsToRewrite.push_back(LHSUnknown);
15610 return true;
15611 };
15612 if (MatchRangeCheckIdiom())
15613 return;
15614
15615 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15616 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15617 // the non-constant operand and in \p LHS the constant operand.
15618 auto IsMinMaxSCEVWithNonNegativeConstant =
15619 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15620 const SCEV *&RHS) {
15621 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15622 if (MinMax->getNumOperands() != 2)
15623 return false;
15624 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15625 if (C->getAPInt().isNegative())
15626 return false;
15627 SCTy = MinMax->getSCEVType();
15628 LHS = MinMax->getOperand(0);
15629 RHS = MinMax->getOperand(1);
15630 return true;
15631 }
15632 }
15633 return false;
15634 };
15635
15636    // Checks whether Expr is a non-negative constant and Divisor is a positive
15637    // constant; if so, returns their APInt values in ExprVal and DivisorVal.
15638 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15639 APInt &ExprVal, APInt &DivisorVal) {
15640 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15641 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15642 if (!ConstExpr || !ConstDivisor)
15643 return false;
15644 ExprVal = ConstExpr->getAPInt();
15645 DivisorVal = ConstDivisor->getAPInt();
15646 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15647 };
15648
15649    // Return a new SCEV that modifies \p Expr to the closest number that is
15650    // divisible by \p Divisor and greater than or equal to Expr.
15651    // For now, only handle constant Expr and Divisor.
15652 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15653 const SCEV *Divisor) {
15654 APInt ExprVal;
15655 APInt DivisorVal;
15656 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15657 return Expr;
15658 APInt Rem = ExprVal.urem(DivisorVal);
15659 if (!Rem.isZero())
15660 // return the SCEV: Expr + Divisor - Expr % Divisor
15661 return SE.getConstant(ExprVal + DivisorVal - Rem);
15662 return Expr;
15663 };
15664
15665    // Return a new SCEV that modifies \p Expr to the closest number that is
15666    // divisible by \p Divisor and less than or equal to Expr.
15667    // For now, only handle constant Expr and Divisor.
15668 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15669 const SCEV *Divisor) {
15670 APInt ExprVal;
15671 APInt DivisorVal;
15672 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15673 return Expr;
15674 APInt Rem = ExprVal.urem(DivisorVal);
15675 // return the SCEV: Expr - Expr % Divisor
15676 return SE.getConstant(ExprVal - Rem);
15677 };
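    // Worked example for the two helpers above (constants only): with
    // Expr = 5 and Divisor = 3, GetNextSCEVDividesByDivisor yields 6
    // (5 + 3 - 5 % 3) and GetPreviousSCEVDividesByDivisor yields 3 (5 - 5 % 3).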
15678
15679    // Apply divisibility by \p Divisor to a MinMaxExpr with constant operands,
15680    // recursively. This is done by aligning the constant value up/down to the
15681    // Divisor.
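    // For instance (illustrative), with Divisor = 4: umin(7, %x) becomes
    // umin(4, %x) (the constant is aligned down), while umax(7, %x) would
    // become umax(8, %x) (aligned up).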
15682 std::function<const SCEV *(const SCEV *, const SCEV *)>
15683 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15684 const SCEV *Divisor) {
15685 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15686 SCEVTypes SCTy;
15687 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15688 MinMaxRHS))
15689 return MinMaxExpr;
15690 auto IsMin =
15691 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15692 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15693 "Expected non-negative operand!");
15694 auto *DivisibleExpr =
15695 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15696 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15697          SmallVector<const SCEV *> Ops = {
15698              ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15699 return SE.getMinMaxExpr(SCTy, Ops);
15700 };
15701
15702 // If we have LHS == 0, check if LHS is computing a property of some unknown
15703 // SCEV %v, and if so rewrite %v to express that property explicitly.
15704 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15705 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15706 // explicitly express that.
15707 const SCEV *URemLHS = nullptr;
15708 const SCEV *URemRHS = nullptr;
15709 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15710 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15711 auto I = RewriteMap.find(LHSUnknown);
15712 const SCEV *RewrittenLHS =
15713 I != RewriteMap.end() ? I->second : LHSUnknown;
15714 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15715 const auto *Multiple =
15716 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15717 RewriteMap[LHSUnknown] = Multiple;
15718 ExprsToRewrite.push_back(LHSUnknown);
15719 return;
15720 }
15721 }
15722 }
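  // For instance, a dominating guard "(%a urem 8) == 0" records the rewrite
  // %a -> (%a /u 8) * 8 (after first applying the divisibility to any
  // existing min/max rewrite of %a), making the divisibility of %a explicit
  // to later queries.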
15723
15724 // Do not apply information for constants or if RHS contains an AddRec.
15725 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15726 return;
15727
15728 // If RHS is SCEVUnknown, make sure the information is applied to it.
15729 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15730 std::swap(LHS, RHS);
15731 Predicate = CmpInst::getSwappedPredicate(Predicate);
15732 }
15733
15734 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15735 // and \p FromRewritten are the same (i.e. there has been no rewrite
15736 // registered for \p From), then puts this value in the list of rewritten
15737 // expressions.
15738 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15739 const SCEV *To) {
15740 if (From == FromRewritten)
15741 ExprsToRewrite.push_back(From);
15742 RewriteMap[From] = To;
15743 };
15744
15745 // Checks whether \p S has already been rewritten. In that case returns the
15746 // existing rewrite because we want to chain further rewrites onto the
15747 // already rewritten value. Otherwise returns \p S.
15748 auto GetMaybeRewritten = [&](const SCEV *S) {
15749 return RewriteMap.lookup_or(S, S);
15750 };
15751
15752 // Check for the SCEV expression (A /u B) * B, where B is a constant, inside
15753 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15754 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15755 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15756 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15757 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15758 // DividesBy.
15759 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15760 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15761 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15762 if (Mul->getNumOperands() != 2)
15763 return false;
15764 auto *MulLHS = Mul->getOperand(0);
15765 auto *MulRHS = Mul->getOperand(1);
15766 if (isa<SCEVConstant>(MulLHS))
15767 std::swap(MulLHS, MulRHS);
15768 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15769 if (Div->getOperand(1) == MulRHS) {
15770 DividesBy = MulRHS;
15771 return true;
15772 }
15773 }
15774 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15775 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15776 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15777 return false;
15778 };
15779
15780 // Return true if \p Expr is known to be divisible by \p DividesBy.
15781 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15782 [&](const SCEV *Expr, const SCEV *DividesBy) {
15783 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15784 return true;
15785 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15786 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15787 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15788 return false;
15789 };
15790
15791 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15792 const SCEV *DividesBy = nullptr;
15793 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15794 // Check that the whole expression is divisible by DividesBy.
15795 DividesBy =
15796 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15797
15798 // Collect rewrites for LHS and its transitive operands based on the
15799 // condition.
15800 // For min/max expressions, also apply the guard to its operands:
15801 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15802 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15803 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15804 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15805
15806 // We cannot express strict predicates in SCEV, so instead we replace them
15807 // with non-strict ones against plus or minus one of RHS depending on the
15808 // predicate.
15809 const SCEV *One = SE.getOne(RHS->getType());
15810 switch (Predicate) {
15811 case CmpInst::ICMP_ULT:
15812 if (RHS->getType()->isPointerTy())
15813 return;
15814 RHS = SE.getUMaxExpr(RHS, One);
15815 [[fallthrough]];
15816 case CmpInst::ICMP_SLT: {
15817 RHS = SE.getMinusSCEV(RHS, One);
15818 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15819 break;
15820 }
15821 case CmpInst::ICMP_UGT:
15822 case CmpInst::ICMP_SGT:
15823 RHS = SE.getAddExpr(RHS, One);
15824 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15825 break;
15826 case CmpInst::ICMP_ULE:
15827 case CmpInst::ICMP_SLE:
15828 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15829 break;
15830 case CmpInst::ICMP_UGE:
15831 case CmpInst::ICMP_SGE:
15832 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15833 break;
15834 default:
15835 break;
15836 }
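  // Example of the adjustment above (illustrative values): for a guard
  // "%x u< 20" with DividesBy == 8, RHS becomes umax(20, 1) - 1 = 19 and is
  // then rounded down to 16, so the loop below records %x -> umin(%x, 16)
  // rather than the weaker umin(%x, 19).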
15837
15838 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15839 SmallPtrSet<const SCEV *, 16> Visited;
15840
15841 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15842 append_range(Worklist, S->operands());
15843 };
15844
15845 while (!Worklist.empty()) {
15846 const SCEV *From = Worklist.pop_back_val();
15847 if (isa<SCEVConstant>(From))
15848 continue;
15849 if (!Visited.insert(From).second)
15850 continue;
15851 const SCEV *FromRewritten = GetMaybeRewritten(From);
15852 const SCEV *To = nullptr;
15853
15854 switch (Predicate) {
15855 case CmpInst::ICMP_ULT:
15856 case CmpInst::ICMP_ULE:
15857 To = SE.getUMinExpr(FromRewritten, RHS);
15858 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15859 EnqueueOperands(UMax);
15860 break;
15861 case CmpInst::ICMP_SLT:
15862 case CmpInst::ICMP_SLE:
15863 To = SE.getSMinExpr(FromRewritten, RHS);
15864 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15865 EnqueueOperands(SMax);
15866 break;
15867 case CmpInst::ICMP_UGT:
15868 case CmpInst::ICMP_UGE:
15869 To = SE.getUMaxExpr(FromRewritten, RHS);
15870 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15871 EnqueueOperands(UMin);
15872 break;
15873 case CmpInst::ICMP_SGT:
15874 case CmpInst::ICMP_SGE:
15875 To = SE.getSMaxExpr(FromRewritten, RHS);
15876 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15877 EnqueueOperands(SMin);
15878 break;
15879 case CmpInst::ICMP_EQ:
15880 if (isa<SCEVConstant>(RHS))
15881 To = RHS;
15882 break;
15883 case CmpInst::ICMP_NE:
15884 if (match(RHS, m_scev_Zero())) {
15885 const SCEV *OneAlignedUp =
15886 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15887 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15888 }
15889 break;
15890 default:
15891 break;
15892 }
15893
15894 if (To)
15895 AddRewrite(From, FromRewritten, To);
15896 }
15897 };
15898
15899 SmallVector<PointerIntPair<Value *, 1, bool>, 8> Terms;
15900 // First, collect information from assumptions dominating the loop.
15901 for (auto &AssumeVH : SE.AC.assumptions()) {
15902 if (!AssumeVH)
15903 continue;
15904 auto *AssumeI = cast<CallInst>(AssumeVH);
15905 if (!SE.DT.dominates(AssumeI, Block))
15906 continue;
15907 Terms.emplace_back(AssumeI->getOperand(0), true);
15908 }
15909
15910 // Second, collect information from llvm.experimental.guards dominating the loop.
15911 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15912 SE.F.getParent(), Intrinsic::experimental_guard);
15913 if (GuardDecl)
15914 for (const auto *GU : GuardDecl->users())
15915 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15916 if (Guard->getFunction() == Block->getParent() &&
15917 SE.DT.dominates(Guard, Block))
15918 Terms.emplace_back(Guard->getArgOperand(0), true);
15919
15920 // Third, collect conditions from dominating branches. Starting at the loop
15921 // predecessor, climb up the predecessor chain, as long as there are
15922 // predecessors that can be found that have unique successors leading to the
15923 // original header.
15924 // TODO: share this logic with isLoopEntryGuardedByCond.
15925 unsigned NumCollectedConditions = 0;
15926 VisitedBlocks.insert(Block);
15927 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15928 for (; Pair.first;
15929 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15930 VisitedBlocks.insert(Pair.second);
15931 const BranchInst *LoopEntryPredicate =
15932 dyn_cast<BranchInst>(Pair.first->getTerminator());
15933 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15934 continue;
15935
15936 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15937 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15938 NumCollectedConditions++;
15939
15940 // If we are recursively collecting guards, stop after 2
15941 // conditions to limit compile-time impact for now.
15942 if (Depth > 0 && NumCollectedConditions == 2)
15943 break;
15944 }
15945 // Finally, if we stopped climbing the predecessor chain because
15946 // there wasn't a unique one to continue, try to collect conditions
15947 // for PHINodes by recursively following all of their incoming
15948 // blocks and try to merge the found conditions to build a new one
15949 // for the Phi.
15950 if (Pair.second->hasNPredecessorsOrMore(2) &&
15951 Depth < MaxLoopGuardCollectionDepth) {
15952 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15953 for (auto &Phi : Pair.second->phis())
15954 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15955 }
15956
15957 // Now apply the information from the collected conditions to
15958 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15959 // earliest condition is processed first. This ensures the SCEVs with the
15960 // shortest dependency chains are constructed first.
15961 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15962 SmallVector<Value *, 8> Worklist;
15963 SmallPtrSet<Value *, 8> Visited;
15964 Worklist.push_back(Term);
15965 while (!Worklist.empty()) {
15966 Value *Cond = Worklist.pop_back_val();
15967 if (!Visited.insert(Cond).second)
15968 continue;
15969
15970 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15971 auto Predicate =
15972 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15973 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15974 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15975 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15976 continue;
15977 }
15978
15979 Value *L, *R;
15980 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15981 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15982 Worklist.push_back(L);
15983 Worklist.push_back(R);
15984 }
15985 }
15986 }
15987
15988 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15989 // the replacement expressions are contained in the ranges of the replaced
15990 // expressions.
15991 Guards.PreserveNUW = true;
15992 Guards.PreserveNSW = true;
15993 for (const SCEV *Expr : ExprsToRewrite) {
15994 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15995 Guards.PreserveNUW &=
15996 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15997 Guards.PreserveNSW &=
15998 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15999 }
16000
16001 // Now that all rewrite information is collected, rewrite the collected
16002 // expressions with the information in the map. This applies information to
16003 // sub-expressions.
16004 if (ExprsToRewrite.size() > 1) {
16005 for (const SCEV *Expr : ExprsToRewrite) {
16006 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16007 Guards.RewriteMap.erase(Expr);
16008 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16009 }
16010 }
16011}
16012
16013const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
16014 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16015 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16016 /// replacement is loop invariant in the loop of the AddRec.
16017 class SCEVLoopGuardRewriter
16018 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16019 const DenseMap<const SCEV *, const SCEV *> &Map;
16020
16021 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
16022
16023 public:
16024 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16025 const ScalarEvolution::LoopGuards &Guards)
16026 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
16027 if (Guards.PreserveNUW)
16028 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16029 if (Guards.PreserveNSW)
16030 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16031 }
16032
16033 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16034
16035 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16036 return Map.lookup_or(Expr, Expr);
16037 }
16038
16039 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16040 if (const SCEV *S = Map.lookup(Expr))
16041 return S;
16042
16043 // If we didn't find the exact ZExt expr in the map, check if there's
16044 // an entry for a smaller ZExt we can use instead.
16045 Type *Ty = Expr->getType();
16046 const SCEV *Op = Expr->getOperand(0);
16047 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16048 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16049 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16050 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16051 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16052 if (const SCEV *S = Map.lookup(NarrowExt))
16053 return SE.getZeroExtendExpr(S, Ty);
16054 Bitwidth = Bitwidth / 2;
16055 }
16056
16057 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
16058 Expr);
16059 }
16060
16061 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16062 if (const SCEV *S = Map.lookup(Expr))
16063 return S;
16064 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16065 Expr);
16066 }
16067
16068 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16069 if (const SCEV *S = Map.lookup(Expr))
16070 return S;
16071 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16072 }
16073
16074 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16075 if (const SCEV *S = Map.lookup(Expr))
16076 return S;
16077 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16078 }
16079
16080 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16081 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16082 // (Const + A + B). There may be guard info for A + B, and if so, apply
16083 // it.
16084 // TODO: Could more generally apply guards to Add sub-expressions.
16085 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16086 Expr->getNumOperands() == 3) {
16087 if (const SCEV *S = Map.lookup(
16088 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
16089 return SE.getAddExpr(Expr->getOperand(0), S);
16090 }
16091 SmallVector<const SCEV *, 2> Operands;
16092 bool Changed = false;
16093 for (const auto *Op : Expr->operands()) {
16094 Operands.push_back(
16095 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16096 Changed |= Op != Operands.back();
16097 }
16098 // We are only replacing operands with equivalent values, so transfer the
16099 // flags from the original expression.
16100 return !Changed ? Expr
16101 : SE.getAddExpr(Operands,
16102 ScalarEvolution::maskFlags(
16103 Expr->getNoWrapFlags(), FlagMask));
16104 }
16105
16106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16107 SmallVector<const SCEV *, 2> Operands;
16108 bool Changed = false;
16109 for (const auto *Op : Expr->operands()) {
16110 Operands.push_back(
16111 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16112 Changed |= Op != Operands.back();
16113 }
16114 // We are only replacing operands with equivalent values, so transfer the
16115 // flags from the original expression.
16116 return !Changed ? Expr
16117 : SE.getMulExpr(Operands,
16118 ScalarEvolution::maskFlags(
16119 Expr->getNoWrapFlags(), FlagMask));
16120 }
16121 };
16122
16123 if (RewriteMap.empty())
16124 return Expr;
16125
16126 SCEVLoopGuardRewriter Rewriter(SE, *this);
16127 return Rewriter.visit(Expr);
16128}
16129
16130const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16131 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16132}
16133
16134const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16135 const LoopGuards &Guards) {
16136 return Guards.rewrite(Expr);
16137}
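// A minimal usage sketch (hypothetical client code, not part of this file):
// guard information collected once can be applied to several expressions.
//
//   ScalarEvolution &SE = ...;   // analysis result for the function
//   const Loop *L = ...;         // loop of interest
//   auto Guards = ScalarEvolution::LoopGuards::collect(L, SE);
//   const SCEV *BTC = SE.getBackedgeTakenCount(L);
//   const SCEV *Guarded = SE.applyLoopGuards(BTC, Guards);
//
// If the loop is guarded by, e.g., "%n u< 100", an expression derived from %n
// is tightened accordingly in Guarded.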
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define G(x, y, z)
Definition MD5.cpp:56
mir Rename Register Operands
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1971
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1540
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1391
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1512
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:206
APInt abs() const
Get the absolute value.
Definition APInt.h:1795
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1201
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:371
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:466
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1488
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:219
unsigned countTrailingZeros() const
Definition APInt.h:1647
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:356
unsigned logBase2() const
Definition APInt.h:1761
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:827
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1150
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:873
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:440
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1221
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
iterator end() const
Definition ArrayRef.h:136
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
iterator begin() const
Definition ArrayRef.h:135
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:482
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition Allocator.h:149
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1274
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:163
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:154
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:760
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:167
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:237
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:180
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:163
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:158
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:222
void swap(DenseMap &RHS)
Definition DenseMap.h:747
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:293
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:330
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1078
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:107
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visit(const SCEV *S)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
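An illustrative subclass sketch, modeled on rewriters used elsewhere in LLVM; the class name is hypothetical and only the SCEVRewriteVisitor interface listed above is assumed:
// Replace one SCEVUnknown with another expression while recursively
// rebuilding the rest of the tree.
struct ReplaceUnknownRewriter
    : public llvm::SCEVRewriteVisitor<ReplaceUnknownRewriter> {
  const llvm::SCEV *Old, *New;
  ReplaceUnknownRewriter(llvm::ScalarEvolution &SE, const llvm::SCEV *Old,
                         const llvm::SCEV *New)
      : SCEVRewriteVisitor(SE), Old(Old), New(New) {}
  const llvm::SCEV *visitUnknown(const llvm::SCEVUnknown *Expr) {
    return Expr == Old ? New : Expr;
  }
};
// Usage: const SCEV *Rewritten = ReplaceUnknownRewriter(SE, Old, New).visit(Expr);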
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check the state of analysis information.
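The analysis is available under both pass managers; a minimal new-pass-manager sketch (the pass name MyPass is hypothetical) assuming ScalarEvolutionAnalysis is registered with the analysis manager:
llvm::PreservedAnalyses MyPass::run(llvm::Function &F,
                                    llvm::FunctionAnalysisManager &AM) {
  // Obtain (and lazily compute) ScalarEvolution for F.
  auto &SE = AM.getResult<llvm::ScalarEvolutionAnalysis>(F);
  // ... query SE here ...
  return llvm::PreservedAnalyses::all();
}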
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
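A sketch of collecting loop guards once and reusing them, assuming SE, a const Loop *L and a const SCEV *Expr are in scope; the one-shot SE.applyLoopGuards(Expr, L) below performs the same collect-then-rewrite internally:
llvm::ScalarEvolution::LoopGuards Guards =
    llvm::ScalarEvolution::LoopGuards::collect(L, SE);
// Rewrite Expr using information from the loop's guarding conditions.
const llvm::SCEV *Tightened = Guards.rewrite(Expr);
(void)Tightened;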
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCouldNotCompute object.
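A basic trip-count query sketch, assuming SE and a const Loop *L are in scope:
const llvm::SCEV *BTC = SE.getBackedgeTakenCount(L);
if (!llvm::isa<llvm::SCEVCouldNotCompute>(BTC)) {
  // Trip count = backedge-taken count + 1, evaluated in a wide enough type.
  const llvm::SCEV *TC = SE.getTripCountFromExitCount(BTC);
  unsigned SmallTC = SE.getSmallConstantTripCount(L); // 0 if unknown or too big
  (void)TC; (void)SmallTC;
}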
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Check whether the operation BinOp between LHS and RHS is provably free of signed/unsigned overflow (as selected by Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
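A sketch of the conservative value-range queries, assuming SE and a const SCEV *S are in scope:
llvm::ConstantRange SR = SE.getSignedRange(S);
llvm::ConstantRange UR = SE.getUnsignedRange(S);
llvm::APInt UMin = SE.getUnsignedRangeMin(S);
llvm::APInt SMax = SE.getSignedRangeMax(S);
bool NonNeg = SE.isKnownNonNegative(S);
(void)SR; (void)UR; (void)UMin; (void)SMax; (void)NonNeg;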
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type that cannot result in overflow.
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect ScalarEvolution's ability to compute a trip count.
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its value.
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
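A sketch of composing no-wrap flags with the static helpers listed above:
llvm::SCEV::NoWrapFlags Flags = llvm::SCEV::FlagAnyWrap;
Flags = llvm::ScalarEvolution::setFlags(Flags, llvm::SCEV::FlagNUW);   // add NUW
Flags = llvm::ScalarEvolution::clearFlags(Flags, llvm::SCEV::FlagNSW); // drop NSW
bool HasNUW = llvm::ScalarEvolution::hasFlags(Flags, llvm::SCEV::FlagNUW);
(void)HasNUW;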
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the given type.
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
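A sketch of the three-way predicate evaluation, assuming SE and two expressions LHS and RHS are in scope:
if (std::optional<bool> Res =
        SE.evaluatePredicate(llvm::ICmpInst::ICMP_ULT, LHS, RHS)) {
  // *Res is true if LHS u< RHS always holds, false if it never holds.
  (void)*Res;
} else {
  // Unknown; fall back to the one-sided isKnownPredicate query.
  bool Known = SE.isKnownPredicate(llvm::ICmpInst::ICMP_ULT, LHS, RHS);
  (void)Known;
}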
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, or std::nullopt if it isn't.
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:712
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:743
TypeSize getSizeInBits() const
Definition DataLayout.h:723
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:295
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:294
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:301
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:292
Use & Op()
Definition User.h:196
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1099
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:169
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2248
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2253
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2258
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2263
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
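A small sketch of the APIntOps helpers referenced here:
llvm::APInt A(32, 12), B(32, 18);
const llvm::APInt &M = llvm::APIntOps::umax(A, B);           // 18
llvm::APInt G = llvm::APIntOps::GreatestCommonDivisor(A, B); // 6
(void)M; (void)G;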
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const SCEV > m_SCEV()
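A sketch of structural matching on an expression S using only the SCEVPatternMatch helpers listed above (S is assumed to be a const SCEV * in scope):
using namespace llvm::SCEVPatternMatch;
const llvm::APInt *C;
if (match(S, m_scev_APInt(C))) {
  // S is a constant integer; *C holds its value.
}
if (match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One()))) {
  // S is an affine AddRec whose step is the constant 1 (a canonical IV shape).
}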
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
Definition MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:318
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2060
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1727
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:644
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:733
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2138
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:252
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2055
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:677
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:186
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:95
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:754
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2130
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1734
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:336
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:339
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:548
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:71
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:1956
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2032
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:1963
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:560
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:257
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1899
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
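A sketch of predicate-driven traversal with SCEVExprContains, assuming a const SCEV *S is in scope:
bool HasAddRec = llvm::SCEVExprContains(
    S, [](const llvm::SCEV *N) { return llvm::isa<llvm::SCEVAddRecExpr>(N); });
(void)HasAddRec;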
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:869
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:301
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:196
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:145
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:129
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
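A sketch of combining KnownBits facts with the helpers listed above:
llvm::KnownBits K = llvm::KnownBits::makeConstant(llvm::APInt(8, 5));
llvm::KnownBits Amt = llvm::KnownBits::makeConstant(llvm::APInt(8, 1));
llvm::KnownBits Shifted = llvm::KnownBits::shl(K, Amt); // all bits known: 10
llvm::APInt Hi = Shifted.getMaxValue();
llvm::APInt Lo = Shifted.getMinValue();
(void)Hi; (void)Lo;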
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.