87#define DEBUG_TYPE "tailcallelim"
89STATISTIC(NumEliminated,
"Number of tail calls removed");
90STATISTIC(NumRetDuped,
"Number of return duplicated");
91STATISTIC(NumAccumAdded,
"Number of accumulators introduced");
95 cl::desc(
"Force disabling recomputing of function entry count, on "
96 "successful tail recursion elimination."));
106 auto *AI = dyn_cast<AllocaInst>(&
I);
107 return !AI || AI->isStaticAlloca();
112struct AllocaDerivedValueTracker {
116 void walk(
Value *Root) {
120 auto AddUsesToWorklist = [&](
Value *
V) {
121 for (
auto &U :
V->uses()) {
122 if (!Visited.
insert(&U).second)
128 AddUsesToWorklist(Root);
130 while (!Worklist.
empty()) {
134 switch (
I->getOpcode()) {
135 case Instruction::Call:
136 case Instruction::Invoke: {
137 auto &CB = cast<CallBase>(*
I);
142 if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))
145 CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U));
146 callUsesLocalStack(CB, IsNocapture);
154 case Instruction::Load: {
159 case Instruction::Store: {
160 if (
U->getOperandNo() == 0)
161 EscapePoints.insert(
I);
164 case Instruction::BitCast:
165 case Instruction::GetElementPtr:
166 case Instruction::PHI:
167 case Instruction::Select:
168 case Instruction::AddrSpaceCast:
171 EscapePoints.insert(
I);
175 AddUsesToWorklist(
I);
179 void callUsesLocalStack(
CallBase &CB,
bool IsNocapture) {
181 AllocaUsers.insert(&CB);
189 EscapePoints.insert(&CB);
198 if (
F.callsFunctionThatReturnsTwice())
202 AllocaDerivedValueTracker Tracker;
204 if (Arg.hasByValAttr())
240 VisitType Escaped = UNESCAPED;
242 for (
auto &
I : *BB) {
243 if (Tracker.EscapePoints.count(&
I))
250 if (!CI || CI->
isTailCall() || isa<PseudoProbeInst>(&
I))
255 if (
auto *
II = dyn_cast<IntrinsicInst>(CI))
256 if (
II->getIntrinsicID() == Intrinsic::stackrestore)
274 bool SafeToTail =
true;
275 for (
auto &Arg : CI->
args()) {
276 if (isa<Constant>(Arg.getUser()))
278 if (
Argument *
A = dyn_cast<Argument>(Arg.getUser()))
279 if (!
A->hasByValAttr())
288 <<
"marked as tail call candidate (readnone)";
296 if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI))
301 auto &State = Visited[SuccBB];
302 if (State < Escaped) {
304 if (State == ESCAPED)
311 if (!WorklistEscaped.
empty()) {
316 while (!WorklistUnescaped.
empty()) {
318 if (Visited[NextBB] == UNESCAPED) {
327 for (
CallInst *CI : DeferredTails) {
328 if (Visited[CI->getParent()] != ESCAPED) {
331 LLVM_DEBUG(
dbgs() <<
"Marked as tail call candidate: " << *CI <<
"\n");
346 if (
II->getIntrinsicID() == Intrinsic::lifetime_end)
351 if (
I->mayHaveSideEffects())
354 if (
LoadInst *L = dyn_cast<LoadInst>(
I)) {
364 L->getAlign(),
DL, L))
378 if (!
I->isAssociative() || !
I->isCommutative())
381 assert(
I->getNumOperands() >= 2 &&
382 "Associative/commutative operations should have at least 2 args!");
391 if ((
I->getOperand(0) == CI &&
I->getOperand(1) == CI) ||
392 (
I->getOperand(0) != CI &&
I->getOperand(1) != CI))
396 if (!
I->hasOneUse() || !isa<ReturnInst>(
I->user_back()))
403class TailRecursionEliminator {
444 BFI ?
BFI->getBlockFreq(&
F.getEntryBlock()).getFrequency() : 0
U),
445 OrigEntryCount(
F.getEntryCount() ?
F.getEntryCount()->getCount() : 0) {
448 assert((OrigEntryCount != 0 && OrigEntryBBFreq != 0) &&
449 "If a BFI was provided, the function should have both an entry "
450 "count that is non-zero and an entry basic block with a non-zero "
457 void createTailRecurseLoopHeader(
CallInst *CI);
463 void cleanupAndFinalize();
467 void copyByValueOperandIntoLocalTemp(
CallInst *CI,
int OpndIdx);
469 void copyLocalTempOfByValueOperandIntoArguments(
CallInst *CI,
int OpndIdx);
481 if (&BB->
front() == TI)
489 CI = dyn_cast<CallInst>(BBI);
493 if (BBI == BB->
begin())
499 "Incompatible call site attributes(Tail,NoTail)");
507 if (BB == &
F.getEntryBlock() && &BB->
front() == CI &&
514 for (;
I != E && FI != FE; ++
I, ++FI)
515 if (*
I != &*FI)
break;
516 if (
I == E && FI == FE)
523void TailRecursionEliminator::createTailRecurseLoopHeader(
CallInst *CI) {
524 HeaderBB = &
F.getEntryBlock();
527 HeaderBB->
setName(
"tailrecurse");
536 NEBI = NewEntry->
begin();
538 if (
AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
539 if (isa<ConstantInt>(AI->getArraySize()))
540 AI->moveBefore(NEBI);
550 I->replaceAllUsesWith(PN);
559 Type *RetType =
F.getReturnType();
577void TailRecursionEliminator::insertAccumulator(
Instruction *AccRecInstr) {
578 assert(!AccPN &&
"Trying to insert multiple accumulators");
580 AccumulatorRecursionInstr = AccRecInstr;
596 if (
P == &
F.getEntryBlock()) {
610void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(
CallInst *CI,
622 AggTy,
DL.getAllocaAddrSpace(),
nullptr, Alignment,
626 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
629 Builder.CreateMemCpy(NewAlloca, Alignment,
637void TailRecursionEliminator::copyLocalTempOfByValueOperandIntoArguments(
647 Value *
Size = Builder.getInt64(
DL.getTypeAllocSize(AggTy));
651 Builder.CreateMemCpy(
F.getArg(OpndIdx), Alignment,
656bool TailRecursionEliminator::eliminateCall(
CallInst *CI) {
665 for (++BBI; &*BBI !=
Ret; ++BBI) {
686 <<
"transforming tail recursion into loop";
692 createTailRecurseLoopHeader(CI);
695 for (
unsigned I = 0, E = CI->
arg_size();
I != E; ++
I) {
697 copyByValueOperandIntoLocalTemp(CI,
I);
703 for (
unsigned I = 0, E = CI->
arg_size();
I != E; ++
I) {
705 copyLocalTempOfByValueOperandIntoArguments(CI,
I);
711 F.removeParamAttr(
I, Attribute::ReadOnly);
712 ArgumentPHIs[
I]->addIncoming(
F.getArg(
I), BB);
718 insertAccumulator(AccRecInstr);
728 if (
Ret->getReturnValue() == CI || AccRecInstr) {
738 "current.ret.tr",
Ret->getIterator());
739 SI->setDebugLoc(
Ret->getDebugLoc());
747 AccPN->
addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
755 Ret->eraseFromParent();
757 DTU.
applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
759 if (OrigEntryBBFreq) {
760 assert(
F.getEntryCount().has_value());
764 assert(&
F.getEntryBlock() != BB);
765 auto RelativeBBFreq =
766 static_cast<double>(
BFI->getBlockFreq(BB).getFrequency()) /
767 static_cast<double>(OrigEntryBBFreq);
769 static_cast<uint64_t>(std::round(RelativeBBFreq * OrigEntryCount));
770 auto OldEntryCount =
F.getEntryCount()->getCount();
771 if (OldEntryCount <= ToSubtract) {
773 errs() <<
"[TRE] The entrycount attributable to the recursive call, "
775 <<
", should be strictly lower than the function entry count, "
776 << OldEntryCount <<
"\n");
778 F.setEntryCount(OldEntryCount - ToSubtract,
F.getEntryCount()->getType());
784void TailRecursionEliminator::cleanupAndFinalize() {
790 for (
PHINode *PN : ArgumentPHIs) {
799 if (RetSelects.
empty()) {
811 Instruction *AccRecInstr = AccumulatorRecursionInstr;
818 AccRecInstrNew->
setName(
"accumulator.ret.tr");
845 Instruction *AccRecInstr = AccumulatorRecursionInstr;
848 AccRecInstrNew->
setName(
"accumulator.ret.tr");
850 SI->getFalseValue());
853 SI->setFalseValue(AccRecInstrNew);
860bool TailRecursionEliminator::processBlock(
BasicBlock &BB) {
863 if (
BranchInst *BI = dyn_cast<BranchInst>(TI)) {
864 if (BI->isConditional())
873 CallInst *CI = findTRECandidate(&BB);
879 <<
"INTO UNCOND BRANCH PRED: " << BB);
893 }
else if (isa<ReturnInst>(TI)) {
894 CallInst *CI = findTRECandidate(&BB);
897 return eliminateCall(CI);
903bool TailRecursionEliminator::eliminate(
Function &
F,
909 if (
F.getFnAttribute(
"disable-tail-calls").getValueAsBool())
912 bool MadeChange =
false;
917 if (
F.getFunctionType()->isVarArg())
924 TailRecursionEliminator TRE(
F,
TTI, AA, ORE, DTU, BFI);
927 MadeChange |= TRE.processBlock(BB);
929 TRE.cleanupAndFinalize();
954 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
955 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
956 auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
957 auto *PDT = PDTWP ? &PDTWP->getPostDomTree() :
nullptr;
961 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
963 return TailRecursionEliminator::eliminate(
964 F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F),
965 &getAnalysis<AAResultsWrapperPass>().getAAResults(),
966 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
972char TailCallElim::ID = 0;
982 return new TailCallElim();
994 F.getEntryCount().has_value() &&
F.getEntryCount()->getCount())
1003 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
1005 TailRecursionEliminator::eliminate(
F, &
TTI, &AA, &ORE, DTU, BFI);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This is the interface for a simple mod/ref and alias analysis over globals.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
uint64_t IntrinsicInst * II
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file defines the SmallPtrSet class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static cl::opt< bool > ForceDisableBFI("tre-disable-entrycount-recompute", cl::init(false), cl::Hidden, cl::desc("Force disabling recomputing of function entry count, on " "successful tail recursion elimination."))
static bool canTRE(Function &F)
Scan the specified function for alloca instructions.
static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA)
Return true if it is safe to move the specified instruction from after the call to before the call,...
static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI)
static bool markTails(Function &F, OptimizationRemarkEmitter *ORE)
A manager for alias analyses.
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
A private abstract base class describing the concept of an individual alias analysis implementation.
ModRefInfo getModRefInfo(const Instruction *I, const std::optional< MemoryLocation > &OptLoc)
Check whether or not an instruction may read or write the optionally specified memory location.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
This class represents an incoming formal argument to a Function.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
const Instruction & front() const
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI InstListType::const_iterator getFirstNonPHIOrDbg(bool SkipPseudoOp=true) const
Returns a pointer to the first instruction in this block that is not a PHINode or a debug intrinsic,...
const Function * getParent() const
Return the enclosing method, or null if none.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Analysis pass which computes BlockFrequencyInfo.
BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequen...
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
bool isByValArgument(unsigned ArgNo) const
Determine whether this argument is passed by value.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
bool onlyReadsMemory(unsigned OpNo) const
Type * getParamByValType(unsigned ArgNo) const
Extract the byval type for a call or parameter.
bool hasOperandBundlesOtherThan(ArrayRef< uint32_t > IDs) const
Return true if this operand bundle user contains operand bundles with tags other than those specified...
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned arg_size() const
This class represents a function call, abstracting a target machine's calling convention.
bool isNoTailCall() const
void setTailCall(bool IsTc=true)
static LLVM_ABI Constant * getIdentity(Instruction *I, Type *Ty, bool AllowRHSConstant=false, bool NSZ=false)
Return the identity constant for a binary or intrinsic Instruction.
static LLVM_ABI Constant * getIntrinsicIdentity(Intrinsic::ID, Type *Ty)
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
This is an important base class in LLVM.
A parsed version of the target data layout string, and methods for querying it.
static DebugLoc getCompilerGenerated()
LLVM_ABI void deleteBB(BasicBlock *DelBB)
Delete DelBB.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
bool skipFunction(const Function &F) const
Optional passes call this function to check whether the pass should be skipped.
void applyUpdates(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
void recalculate(FuncT &F)
Notify DTU that the entry block was replaced.
Legacy wrapper pass to provide the GlobalsAAResult object.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
LLVM_ABI Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
LLVM_ABI void dropLocation()
Drop the instruction's debug location.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI void insertBefore(InstListType::iterator InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified position.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
A wrapper class for inspecting calls to intrinsic functions.
@ OB_clang_arc_attachedcall
An instruction for reading from memory.
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Analysis pass which computes a PostDominatorTree.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Return a value (possibly void), from a function.
This class represents the LLVM 'select' instruction.
static SelectInst * Create(Value *C, Value *S1, Value *S2, const Twine &NameStr="", InsertPosition InsertBefore=nullptr, Instruction *MDFrom=nullptr)
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
A Use represents the edge between a Value definition and its users.
void dropAllReferences()
Drop all references to operands.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
const ParentTy * getParent() const
self_iterator getIterator()
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
@ Tail
Attempts to make calls as fast as possible while guaranteeing that tail call optimization can always b...
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI FunctionPass * createTailCallEliminationPass()
auto pred_end(const MachineBasicBlock *BB)
auto successors(const MachineBasicBlock *BB)
LLVM_ABI ReturnInst * FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, BasicBlock *Pred, DomTreeUpdater *DTU=nullptr)
This method duplicates the specified return instruction into a predecessor which ends in an unconditi...
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isModSet(const ModRefInfo MRI)
LLVM_ABI bool isSafeToLoadUnconditionally(Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, Instruction *ScanFrom, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
auto pred_begin(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool pred_empty(const BasicBlock *BB)
LLVM_ABI void initializeTailCallElimPass(PassRegistry &)
This struct is a compact representation of a valid (non-zero power of two) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.