Skip to content

Commit

Permalink
OT interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Feb 17, 2023
1 parent ab2abcd commit 7ef1efb
Show file tree
Hide file tree
Showing 9 changed files with 602 additions and 12 deletions.
184 changes: 183 additions & 1 deletion ot/co.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"fmt"
"hash"
"math/big"
)

var (
bo = binary.BigEndian
bo = binary.BigEndian
_ OT = &CO{}
)

// COSender implements CO OT sender.
Expand Down Expand Up @@ -224,3 +226,183 @@ func xor(a, b []byte) []byte {
}
return a[:l]
}

// CO implements CO OT as the OT interface.
type CO struct {
curve elliptic.Curve
hash hash.Hash
io IO
}

// NewCO creates a new CO OT implementing the OT interface.
func NewCO() *CO {
return &CO{
curve: elliptic.P256(),
hash: sha256.New(),
}
}

// InitSender initializes the OT sender.
func (co *CO) InitSender(io IO) error {
co.io = io

return io.SendData([]byte(co.curve.Params().Name))
}

// InitReceiver initializes the OT receiver.
func (co *CO) InitReceiver(io IO) error {
co.io = io

name, err := io.ReceiveData()
if err != nil {
return err
}
if string(name) != co.curve.Params().Name {
return fmt.Errorf("invalid curve %s, expected %s",
string(name), co.curve.Params().Name)
}
return nil
}

// Send sends the wire labels with OT.
func (co *CO) Send(wires []Wire) error {
curveParams := co.curve.Params()

// a <- Zp
a, err := rand.Int(rand.Reader, curveParams.N)
if err != nil {
return err
}
aBytes := a.Bytes()

// A = G^a
Ax, Ay := co.curve.ScalarBaseMult(aBytes)

err = co.io.SendData(Ax.Bytes())
if err != nil {
return err
}
err = co.io.SendData(Ay.Bytes())
if err != nil {
return err
}

// Aa = A^a
Aax, Aay := co.curve.ScalarMult(Ax, Ay, aBytes)

// a: {x,y}
// a^-1: {x,-y}
// AaInv = {Aax, -Aay}
AaInvx := big.NewInt(0).Set(Aax)
AaInvy := big.NewInt(0).Sub(curveParams.P, Aay)

for i := 0; i < len(wires); i++ {
data, err := co.io.ReceiveData()
if err != nil {
return err
}
Bx := big.NewInt(0).SetBytes(data)

data, err = co.io.ReceiveData()
if err != nil {
return err
}
By := big.NewInt(0).SetBytes(data)

Bx, By = co.curve.ScalarMult(Bx, By, aBytes)
Bax, Bay := co.curve.Add(Bx, By, AaInvx, AaInvy)

var labelData LabelData

wires[i].L0.GetData(&labelData)
e0 := xor(kdf(co.hash, Bx, By, uint64(i)), labelData[:])
err = co.io.SendData(e0)
if err != nil {
return err
}
wires[i].L1.GetData(&labelData)
e1 := xor(kdf(co.hash, Bax, Bay, uint64(i)), labelData[:])
err = co.io.SendData(e1)
if err != nil {
return err
}
}
return nil
}

// Receive receives the wire labels with OT based on the flag values.
func (co *CO) Receive(flags []bool) ([]Label, error) {
curveParams := co.curve.Params()

result := make([]Label, len(flags))

data, err := co.io.ReceiveData()
if err != nil {
return nil, err
}
Ax := big.NewInt(0).SetBytes(data)

data, err = co.io.ReceiveData()
if err != nil {
return nil, err
}
Ay := big.NewInt(0).SetBytes(data)

for i := 0; i < len(flags); i++ {
// b <= Zp
b, err := rand.Int(rand.Reader, curveParams.N)
if err != nil {
return nil, err
}
bBytes := b.Bytes()

Bx, By := co.curve.ScalarBaseMult(bBytes)
if flags[i] {
Bx, By = co.curve.Add(Bx, By, Ax, Ay)
}
err = co.io.SendData(Bx.Bytes())
if err != nil {
return nil, err
}
err = co.io.SendData(By.Bytes())
if err != nil {
return nil, err
}

Asx, Asy := co.curve.ScalarMult(Ax, Ay, bBytes)

// Receive E

data := kdf(co.hash, Asx, Asy, uint64(i))

var e []byte
if flags[i] {
_, err = co.io.ReceiveData()
if err != nil {
return nil, err
}

e, err := co.io.ReceiveData()
if err != nil {
return nil, err
}
data = xor(data, e)
} else {
e, err = co.io.ReceiveData()
if err != nil {
return nil, err
}
data = xor(data, e)

_, err := co.io.ReceiveData()
if err != nil {
return nil, err
}
}
var labelData LabelData
copy(labelData[:], data)
result[i].SetData(&labelData)
}

return result, nil
}
19 changes: 13 additions & 6 deletions ot/co_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ package ot
import (
"bytes"
"crypto/rand"
"fmt"
"testing"
)

Expand All @@ -30,17 +29,25 @@ func TestCO(t *testing.T) {
if err != nil {
t.Fatalf("COSender.NewTransfer: %v", err)
}
rXfer, err := receiver.NewTransfer(1)
var bit uint = 1

rXfer, err := receiver.NewTransfer(bit)
if err != nil {
t.Fatalf("COReceiver.NewTransfer: %v", err)
}
rXfer.ReceiveA(sXfer.A())
sXfer.ReceiveB(rXfer.B())
result := rXfer.ReceiveE(sXfer.E())

fmt.Printf("data0: %x\n", l0Data)
fmt.Printf("data1: %x\n", l1Data)
fmt.Printf("result: %x\n", result)
var ret int
if bit == 0 {
ret = bytes.Compare(result, l0Data[:])
} else {
ret = bytes.Compare(result, l1Data[:])
}
if ret != 0 {
t.Errorf("Verify failed")
}
}

func BenchmarkCO(b *testing.B) {
Expand All @@ -60,7 +67,7 @@ func BenchmarkCO(b *testing.B) {
if err != nil {
b.Fatalf("COSender.NewTransfer: %v", err)
}
var bit uint = 1
bit := uint(i % 2)

rXfer, err := receiver.NewTransfer(bit)
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions ot/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//
// io.go
//
// Copyright (c) 2023 Markku Rossi
//
// All rights reserved.

package ot

// IO defines an I/O interface to communicate between peers.
type IO interface {
// SendData sends binary data.
SendData(val []byte) error

// SendUint32 sends an uint32 value.
SendUint32(val int) error

// Flush flushed any pending data in the connection.
Flush() error

// ReceiveData receives binary data.
ReceiveData() ([]byte, error)

// ReceiveUint32 receives an uint32 value.
ReceiveUint32() (int, error)
}
23 changes: 23 additions & 0 deletions ot/ot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//
// ot.go
//
// Copyright (c) 2023 Markku Rossi
//
// All rights reserved.

package ot

// OT defines Oblivious Transfer protocol.
type OT interface {
// InitSender initializes the OT sender.
InitSender(io IO) error

// InitReceiver initializes the OT receiver.
InitReceiver(io IO) error

// Send sends the wire labels with OT.
Send(wires []Wire) error

// Receive receives the wire labels with OT based on the flag values.
Receive(flags []bool) ([]Label, error)
}
Loading

0 comments on commit 7ef1efb

Please sign in to comment.