Skip to content

Commit

Permalink
Cleaned up block and branch handling. Now we select
Browse files Browse the repository at this point in the history
return values correctly.
  • Loading branch information
markkurossi committed Feb 15, 2020
1 parent 1497923 commit 1812b25
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 37 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,41 @@
# mpc
Secure Multi-Party Computation

# SSA (Static single assignment form)

```go
package main

func main(a, b int) int {
if a > b {
return a
}
return b
}
```

The compiler creates the following SSA form assembly:

```
l0:
igt a@1,0/i b@1,0/i $_@0,0/b
jump l2
l2:
if $_@0,0/b l3
jump l4
l4:
mov b@1,0/i $ret0@1,2/i
jump l1
l1:
phi $_@0,0/b $ret0@1,1/i $ret0@1,2/i $_@0,1/i
ret $_@0,1/i
l3:
mov a@1,0/i $ret0@1,1/i
jump l1
```

![If-else SSA form](ifelse.png)

# Syntax

## Mathematical operations
Expand Down
19 changes: 10 additions & 9 deletions compiler/ast/ssagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,13 @@ func (ast *Func) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
}

// Select return variables.
// XXX this is broken, we must use decision tree for the return values.
var vars []ssa.Variable
for _, ret := range ast.Return {
b, err := block.Bindings.Get(ret.Name)
v, err := ctx.BlockHead.ReturnBinding(ret.Name, ctx.BlockTail, gen)
if err != nil {
return nil, err
}
vars = append(vars, b.Value(block, gen))
vars = append(vars, v)
}
ctx.BlockTail.AddInstr(ssa.NewRetInstr(vars))

Expand Down Expand Up @@ -143,12 +142,13 @@ func (ast *If) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
}

branchBlock := gen.NextBlock(block)
branchBlock.BranchCond = e
block.AddInstr(ssa.NewJumpInstr(branchBlock))

block = branchBlock

// Branch.
tBlock := gen.NextBlock(block)
tBlock := gen.BranchBlock(block)
block.AddInstr(ssa.NewIfInstr(e, tBlock))

// True branch.
Expand All @@ -165,7 +165,7 @@ func (ast *If) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
tNext = gen.NextBlock(block)
} else {
tNext.Bindings = tNext.Bindings.Merge(e, block.Bindings)
block.AddTo(tNext)
block.SetNext(tNext)
}
block.AddInstr(ssa.NewJumpInstr(tNext))

Expand Down Expand Up @@ -196,10 +196,10 @@ func (ast *If) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
// Both branches continue.

next := gen.Block()
tNext.AddTo(next)
tNext.SetNext(next)
tNext.AddInstr(ssa.NewJumpInstr(next))

fNext.AddTo(next)
fNext.SetNext(next)
fNext.AddInstr(ssa.NewJumpInstr(next))

next.Bindings = tNext.Bindings.Merge(e, fNext.Bindings)
Expand Down Expand Up @@ -238,14 +238,15 @@ func (ast *Return) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
if err != nil {
return nil, err
}
_, err = ctx.Pop()
v, err = ctx.Pop()
if err != nil {
return nil, err
}
block.Bindings.Set(v)
}

block.AddInstr(ssa.NewJumpInstr(ctx.BlockTail))
block.AddTo(ctx.BlockTail)
block.SetNext(ctx.BlockTail)
block.Dead = true

return block, nil
Expand Down
100 changes: 78 additions & 22 deletions compiler/ssa/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ import (
)

type Block struct {
ID string
From []*Block
To []*Block
Instr []Instr
Bindings Bindings
Dead bool
ID string
From []*Block
Next *Block
BranchCond Variable
Branch *Block
Instr []Instr
Bindings Bindings
Dead bool
}

func (b *Block) String() string {
Expand All @@ -30,18 +32,22 @@ func (b *Block) Equals(o *Block) bool {
return b.ID == o.ID
}

func (b *Block) AddTo(o *Block) {
b.addTo(o)
func (b *Block) SetNext(o *Block) {
if b.Next != nil && b.Next.ID != o.ID {
panic(fmt.Sprintf("%s.Next already set to %s, now setting to %s",
b.ID, b.Next.ID, o.ID))
}
b.Next = o
o.addFrom(b)
}

func (b *Block) addTo(o *Block) {
for _, f := range b.To {
if f.Equals(o) {
return
}
func (b *Block) SetBranch(o *Block) {
if b.Branch != nil && b.Branch.ID != o.ID {
panic(fmt.Sprintf("%s.Branch already set to %s, now setting to %s",
b.ID, b.Next.ID, o.ID))
}
b.To = append(b.To, o)
b.Branch = o
o.addFrom(b)
}

func (b *Block) addFrom(o *Block) {
Expand All @@ -57,6 +63,42 @@ func (b *Block) AddInstr(instr Instr) {
b.Instr = append(b.Instr, instr)
}

func (b *Block) ReturnBinding(name string, retBlock *Block, gen *Generator) (
v Variable, err error) {

if b.Branch == nil {
// Sequential block, return latest value
if b.Next != nil {
v, err = b.Next.ReturnBinding(name, retBlock, gen)
if err == nil {
return v, nil
}
// Next didn't have value, take ours below.
}
bind, err := b.Bindings.Get(name)
if err != nil {
return v, err
}
return bind.Value(retBlock, gen), nil
}
vTrue, err := b.Branch.ReturnBinding(name, retBlock, gen)
if err != nil {
return v, err
}
vFalse, err := b.Next.ReturnBinding(name, retBlock, gen)
if err != nil {
return v, err
}
if vTrue.Equal(&vFalse) {
return vTrue, nil
}

v = gen.AnonVar(vTrue.Type)
retBlock.AddInstr(NewPhiInstr(b.BranchCond, vTrue, vFalse, v))

return v, nil
}

func (b *Block) PP(out io.Writer, seen map[string]bool) {
if seen[b.ID] {
return
Expand All @@ -67,8 +109,11 @@ func (b *Block) PP(out io.Writer, seen map[string]bool) {
for _, i := range b.Instr {
i.PP(out)
}
for _, to := range b.To {
to.PP(out, seen)
if b.Next != nil {
b.Next.PP(out, seen)
}
if b.Branch != nil {
b.Branch.PP(out, seen)
}
}

Expand All @@ -94,8 +139,11 @@ func (b *Block) DotNodes(out io.Writer, seen map[string]bool) {

fmt.Fprintf(out, " %s [label=\"%s\"]\n", b.ID, label)

for _, to := range b.To {
to.DotNodes(out, seen)
if b.Next != nil {
b.Next.DotNodes(out, seen)
}
if b.Branch != nil {
b.Branch.DotNodes(out, seen)
}
}

Expand All @@ -104,12 +152,20 @@ func (b *Block) DotLinks(out io.Writer, seen map[string]bool) {
return
}
seen[b.ID] = true
for _, to := range b.To {
fmt.Fprintf(out, " %s -> %s [label=\"%s\"];\n", b.ID, to.ID, to.ID)
if b.Next != nil {
fmt.Fprintf(out, " %s -> %s [label=\"%s\"];\n",
b.ID, b.Next.ID, b.Next.ID)
}
if b.Branch != nil {
fmt.Fprintf(out, " %s -> %s [label=\"%s\"];\n",
b.ID, b.Branch.ID, b.Branch.ID)
}

for _, to := range b.To {
to.DotLinks(out, seen)
if b.Next != nil {
b.Next.DotLinks(out, seen)
}
if b.Branch != nil {
b.Branch.DotLinks(out, seen)
}
}

Expand Down
9 changes: 8 additions & 1 deletion compiler/ssa/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ func (gen *Generator) Block() *Block {
func (gen *Generator) NextBlock(b *Block) *Block {
n := gen.Block()
n.Bindings = b.Bindings.Clone()
b.AddTo(n)
b.SetNext(n)
return n
}

func (gen *Generator) BranchBlock(b *Block) *Block {
n := gen.Block()
n.Bindings = b.Bindings.Clone()
b.SetBranch(n)
return n
}
13 changes: 8 additions & 5 deletions compiler/ssagen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ type SSAGenTest struct {

var ssagenTests = []SSAGenTest{
SSAGenTest{
Name: "add",
Enabled: true,
Name: "add",
Code: `
package main
func main(a, b int) int {
Expand All @@ -28,7 +29,8 @@ func main(a, b int) int {
`,
},
SSAGenTest{
Name: "ret2",
Enabled: true,
Name: "ret2",
Code: `
package main
func main(a, b int) (int, int) {
Expand All @@ -50,7 +52,8 @@ func main(a, b int) int {
`,
},
SSAGenTest{
Name: "ifelse",
Enabled: true,
Name: "ifelse",
Code: `
package main
func main(a, b int) int {
Expand All @@ -63,7 +66,7 @@ func main(a, b int) int {
`,
},
SSAGenTest{
Enabled: false,
Enabled: true,
Name: "if-else-assign",
Code: `
package main
Expand All @@ -83,7 +86,7 @@ func main(a, b int) (int, int) {
`,
},
SSAGenTest{
Enabled: false,
Enabled: true,
Name: "max3",
Code: `
package main
Expand Down

0 comments on commit 1812b25

Please sign in to comment.