Skip to content

Commit

Permalink
Cleaned up API and implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Feb 16, 2023
1 parent 3d543f5 commit 6625400
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 110 deletions.
117 changes: 49 additions & 68 deletions ot/co.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
package ot

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
Expand All @@ -27,7 +26,7 @@ var (
)

type COSender struct {
Priv *ecdsa.PrivateKey
curve elliptic.Curve
}

// NewCOSender creates a new CO OT sender. The Sender implements the
Expand All @@ -41,55 +40,43 @@ type COSender struct {
// | |
// |-------send e{0,1}--->|
// | |
func NewCOSender() (*COSender, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
func NewCOSender() *COSender {
return &COSender{
Priv: priv,
}, err
curve: elliptic.P256(),
}
}

func (s *COSender) CurveParams() *elliptic.CurveParams {
return s.Priv.Params()
func (s *COSender) Curve() elliptic.Curve {
return s.curve
}

func (s *COSender) NewTransfer(m0, m1 []byte) (*COSenderXfer, error) {
curveParams := s.Priv.Params()
curveParams := s.curve.Params()

// a <= Zp
// a <- Zp
a, err := rand.Int(rand.Reader, curveParams.N)
if err != nil {
return nil, err
}

// A = G->mul_gen(a)
//
// Point res(this)
// EC_POINT_mul(ec_group, res.point, a.n, Null, NUll, ctx)
// r = res.point
// n = a.n
// int EC_POINT_mul(const EC_GROUP *group, EC_POINT *r, const BIGNUM *n,
// const EC_POINT *q, const BIGNUM *m, BN_CTX *ctx);
//
// => gen*n + q*m => r=gen*n

Ax, Ay := curveParams.ScalarBaseMult(a.Bytes())
Aax, Aay := curveParams.ScalarMult(Ax, Ay, a.Bytes())
// A = G^a
Ax, Ay := s.curve.ScalarBaseMult(a.Bytes())

// BN_usub(point->y, group->field, point->y)
// => result = group->field - point->y
// Aa = A^a
Aax, Aay := s.curve.ScalarMult(Ax, Ay, a.Bytes())

// a: {x,y}
// a^-1: {x,-y}
// AaInv = {Aax, -Aay}
AaInvx := big.NewInt(0).Set(Aax)
AaInvy := big.NewInt(0).Sub(curveParams.P, Aay)

return &COSenderXfer{
sender: s,
curve: s.curve,
hash: sha256.New(),
a: a,
m0: m0,
m1: m1,
a: a,
Ax: Ax,
Ay: Ay,
AaInvx: AaInvx,
Expand All @@ -98,11 +85,11 @@ func (s *COSender) NewTransfer(m0, m1 []byte) (*COSenderXfer, error) {
}

type COSenderXfer struct {
sender *COSender
curve elliptic.Curve
hash hash.Hash
a *big.Int
m0 []byte
m1 []byte
a *big.Int
Ax *big.Int
Ay *big.Int
AaInvx *big.Int
Expand All @@ -115,19 +102,15 @@ func (s *COSenderXfer) A() (x, y []byte) {
return s.Ax.Bytes(), s.Ay.Bytes()
}

func (s *COSenderXfer) ReceiveB(x, y []byte) error {
curveParams := s.sender.Priv.Params()

func (s *COSenderXfer) ReceiveB(x, y []byte) {
bx := big.NewInt(0).SetBytes(x)
by := big.NewInt(0).SetBytes(y)

bx, by = curveParams.ScalarMult(bx, by, s.a.Bytes())
bax, bay := curveParams.Add(bx, by, s.AaInvx, s.AaInvy)
bx, by = s.curve.ScalarMult(bx, by, s.a.Bytes())
bax, bay := s.curve.Add(bx, by, s.AaInvx, s.AaInvy)

s.e0 = xor(s.kdf(bx, by, 0), s.m0)
s.e1 = xor(s.kdf(bax, bay, 0), s.m1)

return nil
}

func (s *COSenderXfer) E() (e0, e1 []byte) {
Expand Down Expand Up @@ -158,66 +141,64 @@ func xor(a, b []byte) []byte {
}

type COReceiver struct {
curveParams *elliptic.CurveParams
curve elliptic.Curve
}

func NewCOReceiver(curveParams *elliptic.CurveParams) (*COReceiver, error) {
func NewCOReceiver(curve elliptic.Curve) *COReceiver {
return &COReceiver{
curveParams: curveParams,
}, nil
curve: curve,
}
}

func (r *COReceiver) NewTransfer(bit uint) (*COReceiverXfer, error) {
curveParams := r.curve.Params()

// b <= Zp
b, err := rand.Int(rand.Reader, r.curveParams.N)
b, err := rand.Int(rand.Reader, curveParams.N)
if err != nil {
return nil, err
}

return &COReceiverXfer{
receiver: r,
curveParams: r.curveParams,
hash: sha256.New(),
bit: bit,
b: b,
curve: r.curve,
hash: sha256.New(),
bit: bit,
b: b,
}, nil
}

type COReceiverXfer struct {
receiver *COReceiver
curveParams *elliptic.CurveParams
hash hash.Hash
bit uint
b *big.Int
Bx *big.Int
By *big.Int
Asx *big.Int
Asy *big.Int
}

func (r *COReceiverXfer) ReceiveA(x, y []byte) error {
curve elliptic.Curve
hash hash.Hash
bit uint
b *big.Int
Bx *big.Int
By *big.Int
Asx *big.Int
Asy *big.Int
}

func (r *COReceiverXfer) ReceiveA(x, y []byte) {
Ax := big.NewInt(0).SetBytes(x)
Ay := big.NewInt(0).SetBytes(y)

Bx, By := r.curveParams.ScalarBaseMult(r.b.Bytes())
Bx, By := r.curve.ScalarBaseMult(r.b.Bytes())
if r.bit != 0 {
Bx, By = r.curveParams.Add(Bx, By, Ax, Ay)
Bx, By = r.curve.Add(Bx, By, Ax, Ay)
}
r.Bx = Bx
r.By = By

Asx, Asy := r.curveParams.ScalarMult(Ax, Ay, r.b.Bytes())
Asx, Asy := r.curve.ScalarMult(Ax, Ay, r.b.Bytes())
r.Asx = Asx
r.Asy = Asy

return nil
}

func (r *COReceiverXfer) B() (x, y []byte) {
return r.Bx.Bytes(), r.By.Bytes()
}

func (r *COReceiverXfer) ReceiveE(e0, e1 []byte) ([]byte, error) {
func (r *COReceiverXfer) ReceiveE(e0, e1 []byte) []byte {
var result []byte

kdf := r.kdf(r.Asx, r.Asy, 0)
Expand All @@ -227,7 +208,7 @@ func (r *COReceiverXfer) ReceiveE(e0, e1 []byte) ([]byte, error) {
} else {
result = xor(kdf, e0)
}
return result, nil
return result
}

func (r *COReceiverXfer) kdf(x, y *big.Int, id uint64) []byte {
Expand Down
54 changes: 12 additions & 42 deletions ot/co_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,8 @@ func TestCO(t *testing.T) {
l0, _ := NewLabel(rand.Reader)
l1, _ := NewLabel(rand.Reader)

sender, err := NewCOSender()
if err != nil {
t.Fatalf("NewCOSender: %v", err)
}

receiver, err := NewCOReceiver(sender.CurveParams())
if err != nil {
t.Fatalf("NewCOReceiver: %v", err)
}
sender := NewCOSender()
receiver := NewCOReceiver(sender.Curve())

var l0Buf, l1Buf LabelData
l0Data := l0.Bytes(&l0Buf)
Expand All @@ -41,18 +34,10 @@ func TestCO(t *testing.T) {
if err != nil {
t.Fatalf("COReceiver.NewTransfer: %v", err)
}
err = rXfer.ReceiveA(sXfer.A())
if err != nil {
t.Fatalf("rXfer.ReceiveA: %v", err)
}
err = sXfer.ReceiveB(rXfer.B())
if err != nil {
t.Fatalf("sXfer.ReceiveB: %v", err)
}
result, err := rXfer.ReceiveE(sXfer.E())
if err != nil {
t.Fatalf("rXfer.ReceiveE: %v", err)
}
rXfer.ReceiveA(sXfer.A())
sXfer.ReceiveB(rXfer.B())
result := rXfer.ReceiveE(sXfer.E())

fmt.Printf("data0: %x\n", l0Data)
fmt.Printf("data1: %x\n", l1Data)
fmt.Printf("result: %x\n", result)
Expand All @@ -62,15 +47,8 @@ func BenchmarkCO(b *testing.B) {
l0, _ := NewLabel(rand.Reader)
l1, _ := NewLabel(rand.Reader)

sender, err := NewCOSender()
if err != nil {
b.Fatalf("NewCOSender: %v", err)
}

receiver, err := NewCOReceiver(sender.CurveParams())
if err != nil {
b.Fatalf("NewCOReceiver: %v", err)
}
sender := NewCOSender()
receiver := NewCOReceiver(sender.Curve())

b.ResetTimer()

Expand All @@ -88,18 +66,10 @@ func BenchmarkCO(b *testing.B) {
if err != nil {
b.Fatalf("COReceiver.NewTransfer: %v", err)
}
err = rXfer.ReceiveA(sXfer.A())
if err != nil {
b.Fatalf("rXfer.ReceiveA: %v", err)
}
err = sXfer.ReceiveB(rXfer.B())
if err != nil {
b.Fatalf("sXfer.ReceiveB: %v", err)
}
result, err := rXfer.ReceiveE(sXfer.E())
if err != nil {
b.Fatalf("rXfer.ReceiveE: %v", err)
}
rXfer.ReceiveA(sXfer.A())
sXfer.ReceiveB(rXfer.B())
result := rXfer.ReceiveE(sXfer.E())

var ret int
if bit == 0 {
ret = bytes.Compare(l0Data[:], result)
Expand Down

0 comments on commit 6625400

Please sign in to comment.