Skip to content

Commit

Permalink
RSA OT implementation of the OT interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Feb 18, 2023
1 parent 00a9899 commit 8d860ec
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 2 deletions.
18 changes: 16 additions & 2 deletions ot/ot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

func testOT(sender, receiver OT, t *testing.T) {
const size int = 1024
const size int = 64

wires := make([]Wire, size)
flags := make([]bool, size)
Expand Down Expand Up @@ -68,7 +68,6 @@ func testOT(sender, receiver OT, t *testing.T) {
err := fmt.Errorf("label %d mismatch %v %v,%v", i,
labels[i], wires[i].L0, wires[i].L1)
pipe.Close()
pipe.Drain()
done <- err
return
}
Expand Down Expand Up @@ -96,6 +95,10 @@ func TestOTCO(t *testing.T) {
testOT(NewCO(), NewCO(), t)
}

func TestOTRSA(t *testing.T) {
testOT(NewRSA(2048), NewRSA(2048), t)
}

func benchmarkOT(sender, receiver OT, batchSize int, b *testing.B) {
wires := make([]Wire, batchSize)
flags := make([]bool, batchSize)
Expand Down Expand Up @@ -206,3 +209,14 @@ func BenchmarkOTCO64(b *testing.B) {
func XBenchmarkOTCO128(b *testing.B) {
benchmarkOT(NewCO(), NewCO(), 128, b)
}

func benchmarkOTRSA(keySize, batchSize int, b *testing.B) {
benchmarkOT(NewRSA(keySize), NewRSA(keySize), batchSize, b)
}

func BenchmarkOTRSA2048_1(b *testing.B) {
benchmarkOTRSA(2048, 1, b)
}
func BenchmarkOTRSA2048_8(b *testing.B) {
benchmarkOTRSA(2048, 8, b)
}
194 changes: 194 additions & 0 deletions ot/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package ot
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"math/big"

"github.com/markkurossi/mpc/ot/mpint"
Expand Down Expand Up @@ -212,3 +213,196 @@ func (r *ReceiverXfer) ReceiveMessages(m0p, m1p []byte, err error) error {
func (r *ReceiverXfer) Message() (m []byte, bit uint) {
return r.mb, r.bit
}

// RSA implements RSA OT as the OT interface.
type RSA struct {
keyBits int
name string
io IO
priv *rsa.PrivateKey
pub *rsa.PublicKey
}

// NewRSA creates a new RSA OT implementing the OT interface. The
// argument specifies the RSA key size in bits.
func NewRSA(keyBits int) *RSA {
return &RSA{
keyBits: keyBits,
name: fmt.Sprintf("RSA-%v", keyBits),
}
}

func (r *RSA) messageSize() int {
return (r.keyBits + 7) / 8
}

// InitSender initializes the OT sender.
func (r *RSA) InitSender(io IO) error {
r.io = io

priv, err := rsa.GenerateKey(rand.Reader, r.keyBits)
if err != nil {
return err
}
r.priv = priv
r.pub = &priv.PublicKey

if err := SendString(io, r.name); err != nil {
return err
}
if err := io.SendData(r.pub.N.Bytes()); err != nil {
return err
}
if err := io.SendUint32(r.pub.E); err != nil {
return err
}

return io.Flush()
}

// InitReceiver initializes the OT receiver.
func (r *RSA) InitReceiver(io IO) error {
r.io = io

name, err := ReceiveString(io)
if err != nil {
return err
}
if name != r.name {
return fmt.Errorf("invalid algorithm %s, expected %s", name, r.name)
}
pubN, err := ReceiveBigInt(io)
if err != nil {
return err
}
pubE, err := io.ReceiveUint32()
if err != nil {
return err
}
r.pub = &rsa.PublicKey{
N: pubN,
E: pubE,
}
return nil
}

// Send sends the wire labels with OT.
func (r *RSA) Send(wires []Wire) error {
for i := 0; i < len(wires); i++ {
// Send random messages.
x0, err := RandomData(r.messageSize())
if err != nil {
return err
}
x1, err := RandomData(r.messageSize())
if err != nil {
return err
}
if err := r.io.SendData(x0); err != nil {
return err
}
if err := r.io.SendData(x1); err != nil {
return err
}
if err := r.io.Flush(); err != nil {
return err
}

// Receive V.
v, err := ReceiveBigInt(r.io)
if err != nil {
return err
}
x0i := mpint.FromBytes(x0)
x1i := mpint.FromBytes(x1)
k0 := mpint.Exp(mpint.Sub(v, x0i), r.priv.D, r.pub.N)
k1 := mpint.Exp(mpint.Sub(v, x1i), r.priv.D, r.pub.N)

// Create transfer messages.
var ld LabelData
wires[i].L0.GetData(&ld)
m0, err := pkcs1.NewEncryptionBlock(pkcs1.BT1, r.messageSize(), ld[:])
if err != nil {
return err
}
m0p := mpint.Add(mpint.FromBytes(m0), k0)
if err := r.io.SendData(m0p.Bytes()); err != nil {
return err
}
wires[i].L1.GetData(&ld)
m1, err := pkcs1.NewEncryptionBlock(pkcs1.BT1, r.messageSize(), ld[:])
if err != nil {
return err
}
m1p := mpint.Add(mpint.FromBytes(m1), k1)
if err := r.io.SendData(m1p.Bytes()); err != nil {
return err
}
if err := r.io.Flush(); err != nil {
return err
}
}
return nil
}

// Receive receives the wire labels with OT based on the flag values.
func (r *RSA) Receive(flags []bool, result []Label) error {
for i := 0; i < len(flags); i++ {
k, err := rand.Int(rand.Reader, r.pub.N)
if err != nil {
return err
}
// Receive random messages.
x0, err := ReceiveBigInt(r.io)
if err != nil {
return err
}
x1, err := ReceiveBigInt(r.io)
if err != nil {
return err
}
var xb *big.Int
if flags[i] {
xb = x1
} else {
xb = x0
}

// Create and send V.
e := big.NewInt(int64(r.pub.E))
v := mpint.Mod(mpint.Add(xb, mpint.Exp(k, e, r.pub.N)), r.pub.N)
if err := r.io.SendData(v.Bytes()); err != nil {
return err
}
if err := r.io.Flush(); err != nil {
return err
}

// Receive transfer messages.
m0p, err := ReceiveBigInt(r.io)
if err != nil {
return err
}
m1p, err := ReceiveBigInt(r.io)
if err != nil {
return err
}
var mbp *big.Int
if flags[i] {
mbp = m1p
} else {
mbp = m0p
}
mbBytes := make([]byte, r.messageSize())
mbIntBytes := mpint.Sub(mbp, k).Bytes()
ofs := len(mbBytes) - len(mbIntBytes)
copy(mbBytes[ofs:], mbIntBytes)

mb, err := pkcs1.ParseEncryptionBlock(mbBytes)
if err != nil {
return err
}
result[i].SetBytes(mb)
}
return nil
}

0 comments on commit 8d860ec

Please sign in to comment.