Skip to content

Commit

Permalink
starlark: capture free variables by reference (google#172)
Browse files Browse the repository at this point in the history
This change causes closures for nested functions to capture their
enclosing functions' variables by reference. Even though an inner
function cannot update outer variables, it must observe updates to
them made by the outer function.

A special case of this is a nested recursive function f:

   def outer():
       def f(): ...f()...

The def f statement constructs a closure which captures f, and then
binds the closure value to f. If the closure captures by value (as
before this change), the def statement will fail because f is
undefined (see issue google#170). Now, the closure captures a reference
to f, so it is safe to execute before f has been assigned.

This is implemented as follows. During resolving, captured local
variables such as f are marked as as "cells". The compiler assumes and
guarantees that such locals are values of a special internal type
called 'cell', and it emits explicit instructions to load from and
store into the cell. At runtime, cells are created on entry to the function;
parameters may be "spilled" into cells as needed.
Each cell variable gets its own allocation to avoid spurious liveness.
A function's tuple of free variables contains only cells.

Fixes google#170
  • Loading branch information
adonovan authored Mar 8, 2019
1 parent f763f8b commit 3d5a061
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 33 deletions.
72 changes: 51 additions & 21 deletions internal/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
// - an stack of active iterators.
// - an array of local variables.
// The number of local variables and their indices are computed by the resolver.
// Locals (possibly including parameters) that are shared with nested functions
// are 'cells': their locals array slot will contain a value of type 'cell',
// an indirect value in a box that is explicitly read/updated by instructions.
// - an array of free variables, for nested functions.
// As with locals, these are computed by the resolver.
// Free variables are a subset of the ancestors' cell variables.
// As with locals and cells, these are computed by the resolver.
// - an array of global variables, shared among all functions in the same module.
// All elements are initially nil.
// - two maps of predeclared and universal identifiers.
Expand All @@ -36,7 +40,7 @@ import (
const debug = false // TODO(adonovan): use a bitmap of options; and regexp to match files

// Increment this to force recompilation of saved bytecode files.
const Version = 7
const Version = 8

type Opcode uint8

Expand Down Expand Up @@ -89,18 +93,20 @@ const (
FALSE // - FALSE False
MANDATORY // - MANDATORY Mandatory [sentinel value for required kwonly args]

ITERPUSH // iterable ITERPUSH - [pushes the iterator stack]
ITERPOP // - ITERPOP - [pops the iterator stack]
NOT // value NOT bool
RETURN // value RETURN -
SETINDEX // a i new SETINDEX -
INDEX // a i INDEX elem
SETDICT // dict key value SETDICT -
ITERPUSH // iterable ITERPUSH - [pushes the iterator stack]
ITERPOP // - ITERPOP - [pops the iterator stack]
NOT // value NOT bool
RETURN // value RETURN -
SETINDEX // a i new SETINDEX -
INDEX // a i INDEX elem
SETDICT // dict key value SETDICT -
SETDICTUNIQ // dict key value SETDICTUNIQ -
APPEND // list elem APPEND -
SLICE // x lo hi step SLICE slice
APPEND // list elem APPEND -
SLICE // x lo hi step SLICE slice
INPLACE_ADD // x y INPLACE_ADD z where z is x+y or x.extend(y)
MAKEDICT // - MAKEDICT dict
MAKEDICT // - MAKEDICT dict
SETCELL // value cell SETCELL -
CELL // cell CELL value

// --- opcodes with an argument must go below this line ---

Expand All @@ -118,7 +124,7 @@ const (
SETLOCAL // value SETLOCAL<local> -
SETGLOBAL // value SETGLOBAL<global> -
LOCAL // - LOCAL<local> value
FREE // - FREE<freevar> value
FREE // - FREE<freevar> cell
GLOBAL // - GLOBAL<global> value
PREDECLARED // - PREDECLARED<name> value
UNIVERSAL // - UNIVERSAL<name> value
Expand Down Expand Up @@ -146,6 +152,7 @@ var opcodeNames = [...]string{
CALL_KW: "call_kw ",
CALL_VAR: "call_var",
CALL_VAR_KW: "call_var_kw",
CELL: "cell",
CIRCUMFLEX: "circumflex",
CJMP: "cjmp",
CONSTANT: "constant",
Expand Down Expand Up @@ -187,6 +194,7 @@ var opcodeNames = [...]string{
POP: "pop",
PREDECLARED: "predeclared",
RETURN: "return",
SETCELL: "setcell",
SETDICT: "setdict",
SETDICTUNIQ: "setdictuniq",
SETFIELD: "setfield",
Expand Down Expand Up @@ -217,6 +225,7 @@ var stackEffect = [...]int8{
CALL_KW: variableStackEffect,
CALL_VAR: variableStackEffect,
CALL_VAR_KW: variableStackEffect,
CELL: 0,
CIRCUMFLEX: -1,
CJMP: -1,
CONSTANT: +1,
Expand Down Expand Up @@ -257,6 +266,7 @@ var stackEffect = [...]int8{
POP: -1,
PREDECLARED: +1,
RETURN: -1,
SETCELL: -2,
SETDICT: -3,
SETDICTUNIQ: -3,
SETFIELD: -2,
Expand Down Expand Up @@ -308,6 +318,7 @@ type Funcode struct {
Code []byte // the byte code
pclinetab []uint16 // mapping from pc to linenum
Locals []Binding // locals, parameters first
Cells []int // indices of Locals that require cells
Freevars []Binding // for tracing
MaxStack int
NumParams int
Expand Down Expand Up @@ -446,6 +457,13 @@ func (pcomp *pcomp) function(name string, pos syntax.Position, stmts []syntax.St
},
}

// Record indices of locals that require cells.
for i, local := range locals {
if local.Scope == syntax.CellScope {
fcomp.fn.Cells = append(fcomp.fn.Cells, i)
}
}

if debug {
fmt.Fprintf(os.Stderr, "start function(%s @ %s)\n", name, pos)
}
Expand Down Expand Up @@ -895,33 +913,43 @@ func (fcomp *fcomp) setPos(pos syntax.Position) {
}

// set emits code to store the top-of-stack value
// to the specified local or global variable.
// to the specified local, cell, or global variable.
func (fcomp *fcomp) set(id *syntax.Ident) {
bind := id.Binding
switch bind.Scope {
case syntax.LocalScope:
fcomp.emit1(SETLOCAL, uint32(bind.Index))
case syntax.CellScope:
// TODO(adonovan): opt: make a single op for LOCAL<n>, SETCELL.
fcomp.emit1(LOCAL, uint32(bind.Index))
fcomp.emit(SETCELL)
case syntax.GlobalScope:
fcomp.emit1(SETGLOBAL, uint32(bind.Index))
default:
log.Fatalf("%s: set(%s): neither global nor local (%d)", id.NamePos, id.Name, bind.Scope)
log.Fatalf("%s: set(%s): not global/local/cell (%d)", id.NamePos, id.Name, bind.Scope)
}
}

// lookup emits code to push the value of the specified variable.
func (fcomp *fcomp) lookup(id *syntax.Ident) {
bind := id.Binding
if bind.Scope != syntax.UniversalScope { // (universal lookup can't fail)
fcomp.setPos(id.NamePos)
}
switch bind.Scope {
case syntax.LocalScope:
fcomp.setPos(id.NamePos)
fcomp.emit1(LOCAL, uint32(bind.Index))
case syntax.FreeScope:
// TODO(adonovan): opt: make a single op for FREE<n>, CELL.
fcomp.emit1(FREE, uint32(bind.Index))
fcomp.emit(CELL)
case syntax.CellScope:
// TODO(adonovan): opt: make a single op for LOCAL<n>, CELL.
fcomp.emit1(LOCAL, uint32(bind.Index))
fcomp.emit(CELL)
case syntax.GlobalScope:
fcomp.setPos(id.NamePos)
fcomp.emit1(GLOBAL, uint32(bind.Index))
case syntax.PredeclaredScope:
fcomp.setPos(id.NamePos)
fcomp.emit1(PREDECLARED, fcomp.pcomp.nameIndex(id.Name))
case syntax.UniversalScope:
fcomp.emit1(UNIVERSAL, fcomp.pcomp.nameIndex(id.Name))
Expand Down Expand Up @@ -1706,14 +1734,16 @@ func (fcomp *fcomp) function(pos syntax.Position, name string, f *syntax.Functio
}
fcomp.emit1(MAKETUPLE, uint32(n))

// Capture the values of the function's
// Capture the cells of the function's
// free variables from the lexical environment.
for _, freevar := range f.FreeVars {
// Don't call fcomp.lookup because we want
// the cell itself, not its content.
switch freevar.Scope {
case syntax.LocalScope:
fcomp.emit1(LOCAL, uint32(freevar.Index))
case syntax.FreeScope:
fcomp.emit1(FREE, uint32(freevar.Index))
case syntax.CellScope:
fcomp.emit1(LOCAL, uint32(freevar.Index))
}
}
fcomp.emit1(MAKETUPLE, uint32(len(f.FreeVars)))
Expand Down
16 changes: 16 additions & 0 deletions internal/compile/serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ package compile
// pclinetab []varint
// numlocals varint
// locals []Ident
// numcells varint
// cells []int
// numfreevars varint
// freevar []Ident
// maxstack varint
Expand Down Expand Up @@ -183,6 +185,10 @@ func (e *encoder) function(fn *Funcode) {
e.int64(int64(x))
}
e.bindings(fn.Locals)
e.int(len(fn.Cells))
for _, index := range fn.Cells {
e.int(index)
}
e.bindings(fn.Freevars)
e.int(fn.MaxStack)
e.int(fn.NumParams)
Expand Down Expand Up @@ -338,6 +344,14 @@ func (d *decoder) bindings() []Binding {
return bindings
}

func (d *decoder) ints() []int {
ints := make([]int, d.int())
for i := range ints {
ints[i] = d.int()
}
return ints
}

func (d *decoder) bool() bool { return d.int() != 0 }

func (d *decoder) function() *Funcode {
Expand All @@ -349,6 +363,7 @@ func (d *decoder) function() *Funcode {
pclinetab[i] = uint16(d.int())
}
locals := d.bindings()
cells := d.ints()
freevars := d.bindings()
maxStack := d.int()
numParams := d.int()
Expand All @@ -363,6 +378,7 @@ func (d *decoder) function() *Funcode {
Code: code,
pclinetab: pclinetab,
Locals: locals,
Cells: cells,
Freevars: freevars,
MaxStack: maxStack,
NumParams: numParams,
Expand Down
12 changes: 8 additions & 4 deletions resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ func (r *resolver) spellcheck(use use) string {
}

// resolveLocalUses is called when leaving a container (function/module)
// block. It resolves all uses of locals within that block.
// block. It resolves all uses of locals/cells within that block.
func (b *block) resolveLocalUses() {
unresolved := b.uses[:0]
for _, use := range b.uses {
if bind := lookupLocal(use); bind != nil && bind.Scope == syntax.LocalScope {
if bind := lookupLocal(use); bind != nil && (bind.Scope == syntax.LocalScope || bind.Scope == syntax.CellScope) {
use.id.Binding = bind
} else {
unresolved = append(unresolved, use)
Expand Down Expand Up @@ -877,10 +877,14 @@ func (r *resolver) lookupLexical(use use, env *block) (bind *syntax.Binding) {
if !ok {
// Defined in parent block?
bind = r.lookupLexical(use, env.parent)
if env.function != nil && (bind.Scope == syntax.LocalScope || bind.Scope == syntax.FreeScope) {
if env.function != nil && (bind.Scope == syntax.LocalScope || bind.Scope == syntax.FreeScope || bind.Scope == syntax.CellScope) {
// Found in parent block, which belongs to enclosing function.
// Add the parent's binding to the function's freevars,
// and add a new 'free' binding to the inner function's block.
// and add a new 'free' binding to the inner function's block,
// and turn the parent's local into cell.
if bind.Scope == syntax.LocalScope {
bind.Scope = syntax.CellScope
}
index := len(env.function.FreeVars)
env.function.FreeVars = append(env.function.FreeVars, bind)
bind = &syntax.Binding{
Expand Down
3 changes: 1 addition & 2 deletions starlark/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,7 @@ func TestInt(t *testing.T) {

func TestBacktrace(t *testing.T) {
// This test ensures continuity of the stack of active Starlark
// functions, including propagation through built-ins such as 'min'
// (though min does not itself appear in the stack).
// functions, including propagation through built-ins such as 'min'.
const src = `
def f(x): return 1//x
def g(x): f(x)
Expand Down
34 changes: 34 additions & 0 deletions starlark/interp.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ func (fn *Function) CallInternal(thread *Thread, args Tuple, kwargs []Tuple) (Va
defer fmt.Println("Leaving ", f.Name)
}

// Spill indicated locals to cells.
// Each cell is a separate alloc to avoid spurious liveness.
for _, index := range f.Cells {
locals[index] = &cell{locals[index]}
}

// TODO(adonovan): add static check that beneath this point
// - there is exactly one return statement
// - there is no redefinition of 'err'.
Expand Down Expand Up @@ -502,6 +508,12 @@ loop:
locals[arg] = stack[sp-1]
sp--

case compile.SETCELL:
x := stack[sp-2]
y := stack[sp-1]
sp -= 2
y.(*cell).v = x

case compile.SETGLOBAL:
fn.globals[arg] = stack[sp-1]
sp--
Expand All @@ -519,6 +531,10 @@ loop:
stack[sp] = fn.freevars[arg]
sp++

case compile.CELL:
x := stack[sp-1]
stack[sp-1] = x.(*cell).v

case compile.GLOBAL:
x := fn.globals[arg]
if x == nil {
Expand Down Expand Up @@ -573,3 +589,21 @@ func (mandatory) Type() string { return "mandatory" }
func (mandatory) Freeze() {} // immutable
func (mandatory) Truth() Bool { return False }
func (mandatory) Hash() (uint32, error) { return 0, nil }

// A cell is a box containing a Value.
// Local variables marked as cells hold their value indirectly
// so that they may be shared by outer and inner nested functions.
// Cells are always accessed using indirect CELL/SETCELL instructions.
// The FreeVars tuple contains only cells.
// The FREE instruction always yields a cell.
type cell struct{ v Value }

func (c *cell) String() string { return "cell" }
func (c *cell) Type() string { return "cell" }
func (c *cell) Freeze() {
if c.v != nil {
c.v.Freeze()
}
}
func (c *cell) Truth() Bool { panic("unreachable") }
func (c *cell) Hash() (uint32, error) { panic("unreachable") }
8 changes: 5 additions & 3 deletions starlark/testdata/assign.star
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,16 @@ g = 1

---
# option:nesteddef
# free variable captured before assignment
# Free variables are captured by reference, so this is ok.
load("assert.star", "assert")

def f():
def g(): ### "local variable outer referenced before assignment"
def g():
return outer
outer = 1
return g()

f()
assert.eq(f(), 1)

---
load("assert.star", "assert")
Expand Down
Loading

0 comments on commit 3d5a061

Please sign in to comment.