Skip to content

Commit

Permalink
Pipe fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Feb 17, 2023
1 parent b7ae910 commit 03c5614
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 24 deletions.
9 changes: 4 additions & 5 deletions ot/co.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,21 +246,20 @@ func NewCO() *CO {
// InitSender initializes the OT sender.
func (co *CO) InitSender(io IO) error {
co.io = io

return io.SendData([]byte(co.curve.Params().Name))
return SendString(io, co.curve.Params().Name)
}

// InitReceiver initializes the OT receiver.
func (co *CO) InitReceiver(io IO) error {
co.io = io

name, err := io.ReceiveData()
name, err := ReceiveString(io)
if err != nil {
return err
}
if string(name) != co.curve.Params().Name {
if name != co.curve.Params().Name {
return fmt.Errorf("invalid curve %s, expected %s",
string(name), co.curve.Params().Name)
name, co.curve.Params().Name)
}
return nil
}
Expand Down
14 changes: 14 additions & 0 deletions ot/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ type IO interface {
ReceiveUint32() (int, error)
}

// SendString sends a string value.
func SendString(io IO, str string) error {
return io.SendData([]byte(str))
}

// ReceiveString receives a string value.
func ReceiveString(io IO) (string, error) {
data, err := io.ReceiveData()
if err != nil {
return "", err
}
return string(data), nil
}

// ReceiveBigInt receives a bit.Int from the connection.
func ReceiveBigInt(io IO) (*big.Int, error) {
data, err := io.ReceiveData()
Expand Down
21 changes: 12 additions & 9 deletions ot/ot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,21 @@ func testOT(sender, receiver OT, t *testing.T) {
flags[i] = i%2 == 0
}

pipe := NewPipe()
pipe, rPipe := NewPipe()

go func() {
go func(pipe *Pipe) {
err := receiver.InitReceiver(pipe)
if err != nil {
done <- err
pipe.Close()
pipe.Drain()
done <- err
return
}
labels, err := receiver.Receive(flags)
if err != nil {
done <- err
pipe.Close()
pipe.Drain()
done <- err
return
}
for i := 0; i < len(flags); i++ {
Expand All @@ -64,14 +66,15 @@ func testOT(sender, receiver OT, t *testing.T) {
if !labels[i].Equal(expected) {
err := fmt.Errorf("label %d mismatch %v %v,%v", i,
labels[i], wires[i].L0, wires[i].L1)
done <- err
pipe.Close()
pipe.Drain()
done <- err
return
}
}

done <- nil
}()
}(rPipe)

err := sender.InitSender(pipe)
if err != nil {
Expand Down Expand Up @@ -115,11 +118,11 @@ func benchmarkOT(sender, receiver OT, batchSize int, b *testing.B) {
flags[i] = i%2 == 0
}

pipe := NewPipe()
pipe, rPipe := NewPipe()

b.ResetTimer()

go func() {
go func(pipe *Pipe) {
for i := 0; i < b.N; i++ {
err := receiver.InitReceiver(pipe)
if err != nil {
Expand Down Expand Up @@ -151,7 +154,7 @@ func benchmarkOT(sender, receiver OT, batchSize int, b *testing.B) {
}

done <- nil
}()
}(rPipe)

for i := 0; i < b.N; i++ {
err := sender.InitSender(pipe)
Expand Down
27 changes: 20 additions & 7 deletions ot/pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ type Pipe struct {
}

// NewPipe creates a new in-memory pipe.
func NewPipe() *Pipe {
r, w := io.Pipe()
func NewPipe() (*Pipe, *Pipe) {
ar, aw := io.Pipe()
br, bw := io.Pipe()

return &Pipe{
rBuf: make([]byte, 64*1024),
wBuf: make([]byte, 64*1024),
r: r,
w: w,
}
rBuf: make([]byte, 64*1024),
wBuf: make([]byte, 64*1024),
r: ar,
w: bw,
}, &Pipe{
rBuf: make([]byte, 64*1024),
wBuf: make([]byte, 64*1024),
r: br,
w: aw,
}
}

// SendData sends binary data.
Expand All @@ -59,6 +66,12 @@ func (p *Pipe) Flush() error {
return nil
}

// Drain consumes all input from the pipe.
func (p *Pipe) Drain() error {
_, err := io.Copy(io.Discard, p.r)
return err
}

// Close closes the pipe.
func (p *Pipe) Close() error {
return p.w.Close()
Expand Down
6 changes: 3 additions & 3 deletions ot/pipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ func TestPipe(t *testing.T) {
testData := []byte("Hello, world!")
testInt := 42

pipe := NewPipe()
pipe, rPipe := NewPipe()
done := make(chan error)

go func() {
go func(pipe *Pipe) {
data, err := pipe.ReceiveData()
if err != nil {
done <- err
Expand Down Expand Up @@ -51,7 +51,7 @@ func TestPipe(t *testing.T) {
done <- fmt.Errorf("expected EOF")
}
done <- nil
}()
}(rPipe)

err := pipe.SendData(testData)
if err != nil {
Expand Down

0 comments on commit 03c5614

Please sign in to comment.