Skip to content

Commit

Permalink
Optimized multiplier.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Mar 23, 2020
1 parent 459597e commit 981cded
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
apps/circuit/circuit
apps/garbled/garbled
apps/ot/ot
apps/iter/iter
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ form assembly:
- [ ] Row reduction
- [ ] Half AND
- [ ] Oblivious transfer extensions
- [ ] Plot multiplication circuit size f(ArrayTreshold)
- Misc:
- [ ] TLS for garbler-evaluator protocol
- [X] Session-specific circuit encryption key
Expand Down
14 changes: 0 additions & 14 deletions apps/garbled/examples/bug.mpcl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,3 @@ package main
func main(a, b uint1024) uint {
return a * b
}

//func main(a, b int8) int {
// return a * b
//}

// 8 301 301
// 16 1365 770
// 32 5797 1972
// 64 23877 5049
// 128 96901 12927
// 256 390405 33095
// 512 1567237 84723
// 1024 6280197 216893
// 2048 25143301 555246
11 changes: 11 additions & 0 deletions apps/garbled/examples/sha256.mpcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// -*- go -*-

package main

import (
"crypto/sha256"
)

func main(block1, block2 uint512) uint256 {
return sha256.Block(block1, sha256.H0)
}
49 changes: 49 additions & 0 deletions apps/iter/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//
// main.go
//
// Copyright (c) 2019 Markku Rossi
//
// All rights reserved.
//

package main

import (
"fmt"
"log"

"github.com/markkurossi/mpc/compiler"
"github.com/markkurossi/mpc/compiler/utils"
)

var template = `
package main
func main(a, b int%d) int {
return a * b
}
`

func main() {
for bits := 8; bits < 1024; bits += 8 {
code := fmt.Sprintf(template, bits)
var bestLimit int
var bestCost int

for limit := 4; limit < 64; limit += 2 {
circ, _, err := compiler.NewCompiler(&utils.Params{
CircMultArrayTreshold: limit,
}).Compile(code)
if err != nil {
log.Fatalf("Compilation %d:%d failed: %s", bits, limit, err)
}
cost := circ.Cost()

if bestCost == 0 || cost < bestCost {
bestCost = cost
bestLimit = limit
}
}
fmt.Printf("%d\t%d\t%d\n", bits, bestLimit, bestCost)
}
}
4 changes: 4 additions & 0 deletions circuit/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ func (c *Circuit) String() string {
c.NumGates, stats, c.NumWires, c.N1.Size(), c.N2.Size(), c.N3.Size())
}

func (c *Circuit) Cost() int {
return (c.Stats[AND]+c.Stats[OR])*4 + c.Stats[INV]*2
}

func (c *Circuit) Dump() {
fmt.Printf("circuit %s\n", c)
for id, gate := range c.Gates {
Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (pkg *Package) Compile(packages map[string]*Package, logger *utils.Logger,
"no main function defined")
}

gen := ssa.NewGenerator(params.Verbose)
gen := ssa.NewGenerator(params)
ctx := NewCodegen(logger, pkg, packages, params.Verbose)

// Init package.
Expand Down
40 changes: 22 additions & 18 deletions compiler/circuits/circ_multiplier.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ import (
"github.com/markkurossi/mpc/circuit"
)

func NewMultiplier(compiler *Compiler, x, y, z []*Wire) error {
func NewMultiplier(c *Compiler, arrayTreshold int, x, y, z []*Wire) error {
if false {
return NewArrayMultiplier(compiler, x, y, z)
return NewArrayMultiplier(c, x, y, z)
} else {
return NewKaratsubaMultiplier(compiler, x, y, z)
if arrayTreshold < 12 {
arrayTreshold = 18
}
return NewKaratsubaMultiplier(c, arrayTreshold, x, y, z)
}
}

Expand Down Expand Up @@ -142,16 +145,17 @@ func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error {
// 1024 6280197 4183044 2097153 3423717 2611364 812353
// 2048 25143301 16754692 8388609 10350993 7899604 2451389
//
func NewKaratsubaMultiplier(compiler *Compiler, a, b, r []*Wire) error {
a, b = compiler.ZeroPad(a, b)
func NewKaratsubaMultiplier(cc *Compiler, limit int, a, b, r []*Wire) error {

a, b = cc.ZeroPad(a, b)
if len(a) > len(r) {
return fmt.Errorf("Invalid multiplier arguments: a=%d, b=%d, r=%d",
len(a), len(b), len(r))
}

// Compute smaller multiplications with array multiplier.
if len(a) <= 32 {
return NewArrayMultiplier(compiler, a, b, r)
if len(a) <= limit {
return NewArrayMultiplier(cc, a, b, r)
}

mid := len(a) / 2
Expand All @@ -163,46 +167,46 @@ func NewKaratsubaMultiplier(compiler *Compiler, a, b, r []*Wire) error {
bHigh := b[mid:]

z0 := MakeWires(min(max(len(aLow), len(bLow))*2, len(r)))
if err := NewKaratsubaMultiplier(compiler, aLow, bLow, z0); err != nil {
if err := NewKaratsubaMultiplier(cc, limit, aLow, bLow, z0); err != nil {
return err
}
aSumLen := max(len(aLow), len(aHigh)) + 1
aSum := MakeWires(aSumLen)
if err := NewAdder(compiler, aLow, aHigh, aSum); err != nil {
if err := NewAdder(cc, aLow, aHigh, aSum); err != nil {
return err
}
bSumLen := max(len(bLow), len(bHigh)) + 1
bSum := MakeWires(bSumLen)
if err := NewAdder(compiler, bLow, bHigh, bSum); err != nil {
if err := NewAdder(cc, bLow, bHigh, bSum); err != nil {
return err
}
z1 := MakeWires(min(max(aSumLen, bSumLen)*2, len(r)))
if err := NewKaratsubaMultiplier(compiler, aSum, bSum, z1); err != nil {
if err := NewKaratsubaMultiplier(cc, limit, aSum, bSum, z1); err != nil {
return err
}
z2 := MakeWires(min(max(len(aHigh), len(bHigh))*2, len(r)))
if err := NewKaratsubaMultiplier(compiler, aHigh, bHigh, z2); err != nil {
if err := NewKaratsubaMultiplier(cc, limit, aHigh, bHigh, z2); err != nil {
return err
}

sub1 := MakeWires(len(r))
if err := NewSubtractor(compiler, z1, z2, sub1); err != nil {
if err := NewSubtractor(cc, z1, z2, sub1); err != nil {
return err
}
sub2 := MakeWires(len(r))
if err := NewSubtractor(compiler, sub1, z0, sub2); err != nil {
if err := NewSubtractor(cc, sub1, z0, sub2); err != nil {
return err
}

shift1 := compiler.ShiftLeft(z2, len(r), mid*2)
shift2 := compiler.ShiftLeft(sub2, len(r), mid)
shift1 := cc.ShiftLeft(z2, len(r), mid*2)
shift2 := cc.ShiftLeft(sub2, len(r), mid)

add1 := MakeWires(len(r))
if err := NewAdder(compiler, shift1, shift2, add1); err != nil {
if err := NewAdder(cc, shift1, shift2, add1); err != nil {
return err
}

return NewAdder(compiler, add1, z0, r)
return NewAdder(cc, add1, z0, r)
}

func max(a, b int) int {
Expand Down
4 changes: 2 additions & 2 deletions compiler/circuits/circuits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestMultiply1(t *testing.T) {

outputs := makeWires(2)

err = NewMultiplier(c, c.Inputs[0:1], c.Inputs[1:2], outputs)
err = NewMultiplier(c, 0, c.Inputs[0:1], c.Inputs[1:2], outputs)
if err != nil {
t.Error(err)
}
Expand All @@ -102,7 +102,7 @@ func TestMultiply(t *testing.T) {

outputs := makeWires(bits * 2)

err = NewMultiplier(c, c.Inputs[0:bits], c.Inputs[bits:2*bits], outputs)
err = NewMultiplier(c, 0, c.Inputs[0:bits], c.Inputs[bits:2*bits], outputs)
if err != nil {
t.Error(err)
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/ssa/circuitgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (gen *Generator) DefineConstants(cc *circuits.Compiler) error {
return strings.Compare(consts[i].Name, consts[j].Name) == -1
})

if len(consts) > 0 && gen.verbose {
if len(consts) > 0 && gen.Params.Verbose {
fmt.Printf("Defining constants:\n")
}
for _, c := range consts {
Expand All @@ -43,7 +43,7 @@ func (gen *Generator) DefineConstants(cc *circuits.Compiler) error {
}
wires = append(wires, w)
}
if gen.verbose {
if gen.Params.Verbose {
fmt.Printf("%s\t%s\n", msg, bitString)
}

Expand Down Expand Up @@ -102,7 +102,8 @@ func (b *Block) Circuit(gen *Generator, cc *circuits.Compiler) error {
if err != nil {
return err
}
err = circuits.NewMultiplier(cc, wires[0], wires[1], o)
err = circuits.NewMultiplier(cc, gen.Params.CircMultArrayTreshold,
wires[0], wires[1], o)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/ssa/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import (
"fmt"

"github.com/markkurossi/mpc/compiler/types"
"github.com/markkurossi/mpc/compiler/utils"
)

const (
anon = "%_"
)

type Generator struct {
verbose bool
Params *utils.Params
versions map[string]Variable
blockID int
constants map[string]ConstantInst
Expand All @@ -28,9 +29,9 @@ type ConstantInst struct {
Const Variable
}

func NewGenerator(verbose bool) *Generator {
func NewGenerator(params *utils.Params) *Generator {
return &Generator{
verbose: verbose,
Params: params,
versions: make(map[string]Variable),
constants: make(map[string]ConstantInst),
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/utils/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type Params struct {
SSADotOut io.WriteCloser
CircOut io.WriteCloser
CircDotOut io.WriteCloser

CircMultArrayTreshold int
}

func (p *Params) Close() {
Expand Down

0 comments on commit 981cded

Please sign in to comment.