Skip to content

Commit 417182e

Browse files
committed
Backport of rL326666 and rL326668 for PR36607 and PR36608.
[CallSiteSplitting] properly split musttail calls. The original author was Fedor Indutny <[email protected]>. `musttail` calls can't be naively splitted. The split blocks must include not only the call instruction itself, but also (optional) `bitcast` and `return` instructions that follow it. Clone `bitcast` and `ret`, place them into the split blocks, and remove the tail block when done. Reviewers: junbuml, mcrosier, davidxl, davide, fhahn Reviewed By: fhahn Subscribers: JDevlieghere, llvm-commits Differential Revision: https://reviews.llvm.org/D43729 git-svn-id: https://llvm.org/svn/llvm-project/llvm/branches/release_60@329793 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent d88ca29 commit 417182e

File tree

2 files changed

+184
-2
lines changed

2 files changed

+184
-2
lines changed

lib/Transforms/Scalar/CallSiteSplitting.cpp

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,46 @@ static bool canSplitCallSite(CallSite CS) {
201201
return CallSiteBB->canSplitPredecessors();
202202
}
203203

204+
static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
205+
Value *V) {
206+
Instruction *Copy = I->clone();
207+
Copy->setName(I->getName());
208+
Copy->insertBefore(Before);
209+
if (V)
210+
Copy->setOperand(0, V);
211+
return Copy;
212+
}
213+
214+
/// Copy mandatory `musttail` return sequence that follows original `CI`, and
215+
/// link it up to `NewCI` value instead:
216+
///
217+
/// * (optional) `bitcast NewCI to ...`
218+
/// * `ret bitcast or NewCI`
219+
///
220+
/// Insert this sequence right before `SplitBB`'s terminator, which will be
221+
/// cleaned up later in `splitCallSite` below.
222+
static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
223+
Instruction *NewCI) {
224+
bool IsVoid = SplitBB->getParent()->getReturnType()->isVoidTy();
225+
auto II = std::next(CI->getIterator());
226+
227+
BitCastInst *BCI = dyn_cast<BitCastInst>(&*II);
228+
if (BCI)
229+
++II;
230+
231+
ReturnInst *RI = dyn_cast<ReturnInst>(&*II);
232+
assert(RI && "`musttail` call must be followed by `ret` instruction");
233+
234+
TerminatorInst *TI = SplitBB->getTerminator();
235+
Value *V = NewCI;
236+
if (BCI)
237+
V = cloneInstForMustTail(BCI, TI, V);
238+
cloneInstForMustTail(RI, TI, IsVoid ? nullptr : V);
239+
240+
// FIXME: remove TI here, `DuplicateInstructionsInSplitBetween` has a bug
241+
// that prevents doing this now.
242+
}
243+
204244
/// Return true if the CS is split into its new predecessors which are directly
205245
/// hooked to each of its original predecessors pointed by PredBB1 and PredBB2.
206246
/// CallInst1 and CallInst2 will be the new call-sites placed in the new
@@ -245,6 +285,7 @@ static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
245285
Instruction *CallInst1, Instruction *CallInst2) {
246286
Instruction *Instr = CS.getInstruction();
247287
BasicBlock *TailBB = Instr->getParent();
288+
bool IsMustTailCall = CS.isMustTailCall();
248289
assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site");
249290

250291
BasicBlock *SplitBlock1 =
@@ -276,9 +317,14 @@ static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
276317
++ArgNo;
277318
}
278319
}
320+
// Clone and place bitcast and return instructions before `TI`
321+
if (IsMustTailCall) {
322+
copyMustTailReturn(SplitBlock1, CS.getInstruction(), CallInst1);
323+
copyMustTailReturn(SplitBlock2, CS.getInstruction(), CallInst2);
324+
}
279325

280326
// Replace users of the original call with a PHI mering call-sites split.
281-
if (Instr->getNumUses()) {
327+
if (!IsMustTailCall && Instr->getNumUses()) {
282328
PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call",
283329
TailBB->getFirstNonPHI());
284330
PN->addIncoming(CallInst1, SplitBlock1);
@@ -290,8 +336,25 @@ static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
290336
<< "\n");
291337
DEBUG(dbgs() << " " << *CallInst2 << " in " << SplitBlock2->getName()
292338
<< "\n");
293-
Instr->eraseFromParent();
339+
294340
NumCallSiteSplit++;
341+
342+
// FIXME: remove TI in `copyMustTailReturn`
343+
if (IsMustTailCall) {
344+
// Remove superfluous `br` terminators from the end of the Split blocks
345+
// NOTE: Removing terminator removes the SplitBlock from the TailBB's
346+
// predecessors. Therefore we must get complete list of Splits before
347+
// attempting removal.
348+
SmallVector<BasicBlock *, 2> Splits(predecessors((TailBB)));
349+
assert(Splits.size() == 2 && "Expected exactly 2 splits!");
350+
for (unsigned i = 0; i < Splits.size(); i++)
351+
Splits[i]->getTerminator()->eraseFromParent();
352+
353+
// Erase the tail block once done with musttail patching
354+
TailBB->eraseFromParent();
355+
return;
356+
}
357+
Instr->eraseFromParent();
295358
}
296359

297360
// Return true if the call-site has an argument which is a PHI with only
@@ -369,7 +432,17 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) {
369432
Function *Callee = CS.getCalledFunction();
370433
if (!Callee || Callee->isDeclaration())
371434
continue;
435+
436+
// Successful musttail call-site splits result in erased CI and erased BB.
437+
// Check if such path is possible before attempting the splitting.
438+
bool IsMustTail = CS.isMustTailCall();
439+
372440
Changed |= tryToSplitCallSite(CS);
441+
442+
// There're no interesting instructions after this. The call site
443+
// itself might have been erased on splitting.
444+
if (IsMustTail)
445+
break;
373446
}
374447
}
375448
return Changed;
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -callsite-splitting -S | FileCheck %s
3+
4+
define i8* @caller(i8* %a, i8* %b) {
5+
; CHECK-LABEL: @caller(
6+
; CHECK-NEXT: Top:
7+
; CHECK-NEXT: [[C:%.*]] = icmp eq i8* [[A:%.*]], null
8+
; CHECK-NEXT: br i1 [[C]], label [[TAIL_PREDBB1_SPLIT:%.*]], label [[TBB:%.*]]
9+
; CHECK: TBB:
10+
; CHECK-NEXT: [[C2:%.*]] = icmp eq i8* [[B:%.*]], null
11+
; CHECK-NEXT: br i1 [[C2]], label [[TAIL_PREDBB2_SPLIT:%.*]], label [[END:%.*]]
12+
; CHECK: Tail.predBB1.split:
13+
; CHECK-NEXT: [[TMP0:%.*]] = musttail call i8* @callee(i8* null, i8* [[B]])
14+
; CHECK-NEXT: [[CB1:%.*]] = bitcast i8* [[TMP0]] to i8*
15+
; CHECK-NEXT: ret i8* [[CB1]]
16+
; CHECK: Tail.predBB2.split:
17+
; CHECK-NEXT: [[TMP1:%.*]] = musttail call i8* @callee(i8* nonnull [[A]], i8* null)
18+
; CHECK-NEXT: [[CB2:%.*]] = bitcast i8* [[TMP1]] to i8*
19+
; CHECK-NEXT: ret i8* [[CB2]]
20+
; CHECK: End:
21+
; CHECK-NEXT: ret i8* null
22+
;
23+
Top:
24+
%c = icmp eq i8* %a, null
25+
br i1 %c, label %Tail, label %TBB
26+
TBB:
27+
%c2 = icmp eq i8* %b, null
28+
br i1 %c2, label %Tail, label %End
29+
Tail:
30+
%ca = musttail call i8* @callee(i8* %a, i8* %b)
31+
%cb = bitcast i8* %ca to i8*
32+
ret i8* %cb
33+
End:
34+
ret i8* null
35+
}
36+
37+
define i8* @callee(i8* %a, i8* %b) noinline {
38+
; CHECK-LABEL: define i8* @callee(
39+
; CHECK-NEXT: ret i8* [[A:%.*]]
40+
;
41+
ret i8* %a
42+
}
43+
44+
define i8* @no_cast_caller(i8* %a, i8* %b) {
45+
; CHECK-LABEL: @no_cast_caller(
46+
; CHECK-NEXT: Top:
47+
; CHECK-NEXT: [[C:%.*]] = icmp eq i8* [[A:%.*]], null
48+
; CHECK-NEXT: br i1 [[C]], label [[TAIL_PREDBB1_SPLIT:%.*]], label [[TBB:%.*]]
49+
; CHECK: TBB:
50+
; CHECK-NEXT: [[C2:%.*]] = icmp eq i8* [[B:%.*]], null
51+
; CHECK-NEXT: br i1 [[C2]], label [[TAIL_PREDBB2_SPLIT:%.*]], label [[END:%.*]]
52+
; CHECK: Tail.predBB1.split:
53+
; CHECK-NEXT: [[TMP0:%.*]] = musttail call i8* @callee(i8* null, i8* [[B]])
54+
; CHECK-NEXT: ret i8* [[TMP0]]
55+
; CHECK: Tail.predBB2.split:
56+
; CHECK-NEXT: [[TMP1:%.*]] = musttail call i8* @callee(i8* nonnull [[A]], i8* null)
57+
; CHECK-NEXT: ret i8* [[TMP1]]
58+
; CHECK: End:
59+
; CHECK-NEXT: ret i8* null
60+
;
61+
Top:
62+
%c = icmp eq i8* %a, null
63+
br i1 %c, label %Tail, label %TBB
64+
TBB:
65+
%c2 = icmp eq i8* %b, null
66+
br i1 %c2, label %Tail, label %End
67+
Tail:
68+
%ca = musttail call i8* @callee(i8* %a, i8* %b)
69+
ret i8* %ca
70+
End:
71+
ret i8* null
72+
}
73+
74+
define void @void_caller(i8* %a, i8* %b) {
75+
; CHECK-LABEL: @void_caller(
76+
; CHECK-NEXT: Top:
77+
; CHECK-NEXT: [[C:%.*]] = icmp eq i8* [[A:%.*]], null
78+
; CHECK-NEXT: br i1 [[C]], label [[TAIL_PREDBB1_SPLIT:%.*]], label [[TBB:%.*]]
79+
; CHECK: TBB:
80+
; CHECK-NEXT: [[C2:%.*]] = icmp eq i8* [[B:%.*]], null
81+
; CHECK-NEXT: br i1 [[C2]], label [[TAIL_PREDBB2_SPLIT:%.*]], label [[END:%.*]]
82+
; CHECK: Tail.predBB1.split:
83+
; CHECK-NEXT: musttail call void @void_callee(i8* null, i8* [[B]])
84+
; CHECK-NEXT: ret void
85+
; CHECK: Tail.predBB2.split:
86+
; CHECK-NEXT: musttail call void @void_callee(i8* nonnull [[A]], i8* null)
87+
; CHECK-NEXT: ret void
88+
; CHECK: End:
89+
; CHECK-NEXT: ret void
90+
;
91+
Top:
92+
%c = icmp eq i8* %a, null
93+
br i1 %c, label %Tail, label %TBB
94+
TBB:
95+
%c2 = icmp eq i8* %b, null
96+
br i1 %c2, label %Tail, label %End
97+
Tail:
98+
musttail call void @void_callee(i8* %a, i8* %b)
99+
ret void
100+
End:
101+
ret void
102+
}
103+
104+
define void @void_callee(i8* %a, i8* %b) noinline {
105+
; CHECK-LABEL: define void @void_callee(
106+
; CHECK-NEXT: ret void
107+
;
108+
ret void
109+
}

0 commit comments

Comments
 (0)