Skip to content

Commit

Permalink
Constants.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Apr 11, 2020
1 parent a95536a commit c825a06
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 92 deletions.
5 changes: 2 additions & 3 deletions apps/garbled/examples/bug.mpcl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

package main

func main(a, b int4) int {
tmp := a / b
return tmp + b
func main(a, b uint8) uint {
return b - a
}
3 changes: 2 additions & 1 deletion compiler/ast/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ func (pkg *Package) Compile(packages map[string]*Package, logger *utils.Logger,

steps := ctx.Start().Serialize()

program, err := ssa.NewProgram(inputs, outputs, gen.Constants(), steps)
program, err := ssa.NewProgram(params, inputs, outputs, gen.Constants(),
steps)
if err != nil {
return nil, nil, err
}
Expand Down
45 changes: 1 addition & 44 deletions compiler/ssa/circuitgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ package ssa

import (
"fmt"
"sort"
"strings"

"github.com/markkurossi/mpc/circuit"
"github.com/markkurossi/mpc/compiler/circuits"
Expand All @@ -25,7 +23,7 @@ func (prog *Program) CompileCircuit(params *utils.Params) (
return nil, err
}

err = prog.DefineConstants(cc)
err = prog.DefineConstants(cc.ZeroWire(), cc.OneWire())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -71,47 +69,6 @@ func (prog *Program) CompileCircuit(params *utils.Params) (
return circ, nil
}

func (prog *Program) DefineConstants(cc *circuits.Compiler) error {

var consts []Variable
for _, c := range prog.Constants {
consts = append(consts, c.Const)
}
sort.Slice(consts, func(i, j int) bool {
return strings.Compare(consts[i].Name, consts[j].Name) == -1
})

if len(consts) > 0 && cc.Params.Verbose {
fmt.Printf("Defining constants:\n")
}
for _, c := range consts {
msg := fmt.Sprintf(" - %v(%d)", c, c.Type.MinBits)

var wires []*circuits.Wire
var bitString string
for bit := 0; bit < c.Type.MinBits; bit++ {
var w *circuits.Wire
if c.Bit(bit) {
bitString = "1" + bitString
w = cc.OneWire()
} else {
bitString = "0" + bitString
w = cc.ZeroWire()
}
wires = append(wires, w)
}
if cc.Params.Verbose {
fmt.Printf("%s\t%s\n", msg, bitString)
}

err := prog.SetWires(c.String(), wires)
if err != nil {
return err
}
}
return nil
}

func (prog *Program) Circuit(cc *circuits.Compiler) error {

for _, step := range prog.Steps {
Expand Down
60 changes: 58 additions & 2 deletions compiler/ssa/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ package ssa
import (
"fmt"
"io"
"sort"
"strings"

"github.com/markkurossi/mpc/circuit"
"github.com/markkurossi/mpc/compiler/circuits"
"github.com/markkurossi/mpc/compiler/utils"
)

type Program struct {
Params *utils.Params
Inputs circuit.IO
Outputs circuit.IO
InputWires []*circuits.Wire
Expand All @@ -24,12 +28,17 @@ type Program struct {
wires map[string][]*circuits.Wire
freeWires map[int][][]*circuits.Wire
nextWireID uint32
zeroWire *circuits.Wire
oneWire *circuits.Wire
numGates uint64
numNonXOR uint64
}

func NewProgram(in, out circuit.IO, consts map[string]ConstantInst,
steps []Step) (*Program, error) {
func NewProgram(params *utils.Params, in, out circuit.IO,
consts map[string]ConstantInst, steps []Step) (*Program, error) {

prog := &Program{
Params: params,
Inputs: in,
Outputs: out,
Constants: consts,
Expand Down Expand Up @@ -181,6 +190,53 @@ func (prog *Program) GC() {
prog.Steps = steps
}

func (prog *Program) DefineConstants(zero, one *circuits.Wire) error {

var consts []Variable
for _, c := range prog.Constants {
consts = append(consts, c.Const)
}
sort.Slice(consts, func(i, j int) bool {
return strings.Compare(consts[i].Name, consts[j].Name) == -1
})

if len(consts) > 0 && prog.Params.Verbose {
fmt.Printf("Defining constants:\n")
}
for _, c := range consts {
msg := fmt.Sprintf(" - %v(%d)", c, c.Type.MinBits)

_, ok := prog.wires[c.String()]
if ok {
fmt.Printf("%s\talready defined\n", msg)
continue
}

var wires []*circuits.Wire
var bitString string
for bit := 0; bit < c.Type.MinBits; bit++ {
var w *circuits.Wire
if c.Bit(bit) {
bitString = "1" + bitString
w = one
} else {
bitString = "0" + bitString
w = zero
}
wires = append(wires, w)
}
if prog.Params.Verbose {
fmt.Printf("%s\t%s\n", msg, bitString)
}

err := prog.SetWires(c.String(), wires)
if err != nil {
return err
}
}
return nil
}

func (prog *Program) PP(out io.Writer) {
for i, in := range prog.Inputs {
fmt.Fprintf(out, "# Input%d: %s\n", i, in)
Expand Down
Loading

0 comments on commit c825a06

Please sign in to comment.