Skip to content

Commit

Permalink
transform (coroutines): remove map iteration from coroutine lowering …
Browse files Browse the repository at this point in the history
…pass

The coroutine lowering pass had issues where it iterated over maps, sometimes resulting in non-deterministic output.
This change removes many of the maps and ensures that the transformations are deterministic.
  • Loading branch information
niaow authored and aykevl committed Apr 12, 2020
1 parent 3862d6e commit bb5f753
Showing 1 changed file with 101 additions and 53 deletions.
154 changes: 101 additions & 53 deletions transform/coroutines.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,22 @@ type asyncFunc struct {
// callers is a set of all functions which call this async function.
callers map[llvm.Value]struct{}

// returns is a map of terminal basic blocks to their return kinds.
returns map[llvm.BasicBlock]returnKind
// returns is a list of returns in the function, along with metadata.
returns []asyncReturn

// calls is the set of all calls in the asyncFunc.
// normalCalls is the set of all intermideate suspending calls in the asyncFunc.
// tailCalls is the set of all tail calls in the asyncFunc.
calls, normalCalls, tailCalls map[llvm.Value]struct{}
// calls is a list of all calls in the asyncFunc.
// normalCalls is a list of all intermideate suspending calls in the asyncFunc.
// tailCalls is a list of all tail calls in the asyncFunc.
calls, normalCalls, tailCalls []llvm.Value
}

// asyncReturn is a metadata container for a return from an asynchronous function.
type asyncReturn struct {
// block is the basic block terminated by the return.
block llvm.BasicBlock

// kind is the kind of the return.
kind returnKind
}

// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
Expand All @@ -135,6 +144,8 @@ type coroutineLoweringPass struct {
// The map keys are function pointers.
asyncFuncs map[llvm.Value]*asyncFunc

asyncFuncsOrdered []*asyncFunc

// calls is a slice of all of the async calls in the module.
calls []llvm.Value

Expand All @@ -159,14 +170,15 @@ type coroutineLoweringPass struct {
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
func (c *coroutineLoweringPass) findAsyncFuncs() {
asyncs := map[llvm.Value]*asyncFunc{}
asyncsOrdered := []llvm.Value{}
calls := []llvm.Value{}

// Use a breadth-first search to find all async functions.
worklist := []llvm.Value{c.pause}
for len(worklist) > 0 {
// Pop a function off the worklist.
fn := worklist[len(worklist)-1]
worklist = worklist[:len(worklist)-1]
fn := worklist[0]
worklist = worklist[1:]

// Get task pointer argument.
task := fn.LastParam()
Expand Down Expand Up @@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
// Mark the caller as async.
// Use nil as a temporary value. It will be replaced later.
asyncs[caller] = nil
asyncsOrdered = append(asyncsOrdered, caller)

// Put the caller on the worklist.
worklist = append(worklist, caller)
Expand All @@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
}
}

// Flip the order of the async functions so that the top ones are lowered first.
for i := 0; i < len(asyncsOrdered)/2; i++ {
asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
}

// Map the elements of asyncsOrdered to *asyncFunc.
asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
for i, v := range asyncsOrdered {
asyncFuncsOrdered[i] = asyncs[v]
}

c.asyncFuncs = asyncs
c.asyncFuncsOrdered = asyncFuncsOrdered
c.calls = calls
}

Expand Down Expand Up @@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {

// analyzeFuncReturns analyzes and classifies the returns of a function.
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
returns := map[llvm.BasicBlock]returnKind{}
returns := []asyncReturn{}
if fn.fn == c.pause {
// Skip pause.
fn.returns = returns
Expand All @@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
case !c.isAsyncCall(prev):
// This is not any form of asynchronous tail call.
if isVoid {
returns[bb] = returnVoid
returns = append(returns, asyncReturn{
block: bb,
kind: returnVoid,
})
} else {
returns[bb] = returnNormal
returns = append(returns, asyncReturn{
block: bb,
kind: returnNormal,
})
}
case isVoid:
if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
// This is a tail call to a void-returning function from a function with a void return.
returns[bb] = returnVoidTail
returns = append(returns, asyncReturn{
block: bb,
kind: returnVoidTail,
})
} else {
// This is a tail call to a value-returning function from a function with a void return.
// The returned value will be ditched.
returns[bb] = returnDitchedTail
returns = append(returns, asyncReturn{
block: bb,
kind: returnDitchedTail,
})
}
case last.Operand(0) == prev:
// This is a regular tail call. The return of the callee is returned to the parent.
returns[bb] = returnTail
returns = append(returns, asyncReturn{
block: bb,
kind: returnTail,
})
case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
// This is a tail call that returns a previous value after waiting on a void function.
returns[bb] = returnDelayedValue
returns = append(returns, asyncReturn{
block: bb,
kind: returnDelayedValue,
})
default:
// This is a tail call that returns a value that is available before the function call.
returns[bb] = returnAlternateTail
returns = append(returns, asyncReturn{
block: bb,
kind: returnAlternateTail,
})
}
case llvm.Unreachable:
prev := llvm.PrevInstruction(last)
Expand All @@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
}

// This is an asyncnhronous tail call to function that does not return.
returns[bb] = returnDeadTail
returns = append(returns, asyncReturn{
block: bb,
kind: returnDeadTail,
})
}
}

Expand All @@ -451,46 +500,45 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {

// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
func (c *coroutineLoweringPass) returnAnalysisPass() {
for _, async := range c.asyncFuncs {
for _, async := range c.asyncFuncsOrdered {
c.analyzeFuncReturns(async)
}
}

// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
func (c *coroutineLoweringPass) categorizeCalls() {
// Sort calls into their respective callers.
for _, async := range c.asyncFuncs {
async.calls = map[llvm.Value]struct{}{}
}
for _, call := range c.calls {
c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{}
caller := c.asyncFuncs[call.InstructionParent().Parent()]
caller.calls = append(caller.calls, call)
}

// Seperate regular and tail calls.
for _, async := range c.asyncFuncs {
// Find all tail calls (of any kind).
for _, async := range c.asyncFuncsOrdered {
// Search returns for tail calls.
tails := map[llvm.Value]struct{}{}
for ret, kind := range async.returns {
switch kind {
for _, ret := range async.returns {
switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
// This is a tail return. The previous instruction is a tail call.
tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{}
tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
}
}

// Find all regular calls.
regulars := map[llvm.Value]struct{}{}
for call := range async.calls {
// Seperate tail calls and regular calls.
normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
for _, call := range async.calls {
if _, ok := tails[call]; ok {
// This is a tail call.
continue
tailCalls = append(tailCalls, call)
} else {
// This is a regular call.
normalCalls = append(normalCalls, call)
}

regulars[call] = struct{}{}
}

async.tailCalls = tails
async.normalCalls = regulars
async.normalCalls = normalCalls
async.tailCalls = tailCalls
}
}

Expand All @@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
}

func (async *asyncFunc) hasValueStoreReturn() bool {
for _, kind := range async.returns {
switch kind {
for _, ret := range async.returns {
switch ret.kind {
case returnNormal, returnAlternateTail, returnDelayedValue:
return true
}
Expand Down Expand Up @@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
}

// Lower returns.
for ret, kind := range fn.returns {
for _, ret := range fn.returns {
// Get terminator.
terminator := ret.LastInstruction()
terminator := ret.block.LastInstruction()

// Get tail call if applicable.
var call llvm.Value
switch kind {
switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
call = llvm.PrevInstruction(terminator)
}

switch kind {
switch ret.kind {
case returnNormal:
c.builder.SetInsertPointBefore(terminator)

Expand Down Expand Up @@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.CreateBr(suspend)

// Restore old state before tail calls.
for call := range fn.tailCalls {
if fn.returns[call.InstructionParent()] == returnDeadTail {
for _, call := range fn.tailCalls {
if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
// Callee never returns, so the state restore is ineffectual.
continue
}
Expand All @@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
}

// Lower returns.
for ret, kind := range fn.returns {
for _, ret := range fn.returns {
// Get terminator instruction.
terminator := ret.LastInstruction()
terminator := ret.block.LastInstruction()

// Get tail call if applicable.
var call llvm.Value
switch kind {
switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
call = llvm.PrevInstruction(terminator)
}

switch kind {
switch ret.kind {
case returnNormal:
c.builder.SetInsertPointBefore(terminator)

Expand All @@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.SetInsertPointBefore(call)

// Store return value.
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
c.builder.CreateStore(terminator.Operand(0), retPtr)

// Heap-allocate a return buffer for the discarded return.
alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
Expand All @@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.SetInsertPointBefore(call)

// Store return value.
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
c.builder.CreateStore(terminator.Operand(0), retPtr)
}

// Delete call if it is a pause, because it has already been lowered.
Expand All @@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {

// Replace terminator with branch to cleanup.
terminator.EraseFromParentAsInstruction()
c.builder.SetInsertPointAtEnd(ret)
c.builder.SetInsertPointAtEnd(ret.block)
c.builder.CreateBr(cleanup)
}

// Lower regular calls.
for call := range fn.normalCalls {
for _, call := range fn.normalCalls {
// Lower return value of call.
c.lowerCallReturn(fn, call)

Expand Down Expand Up @@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
} else {
// Check for any undead returns.
var undead bool
for _, kind := range c.asyncFuncs[fn].returns {
if kind != returnDeadTail {
for _, ret := range c.asyncFuncs[fn].returns {
if ret.kind != returnDeadTail {
// This return results in a value being eventually stored.
undead = true
break
Expand Down

0 comments on commit bb5f753

Please sign in to comment.