Skip to content

Commit

Permalink
Make p2p.IOStats goroutine-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
rozag committed Jul 7, 2023
1 parent 5976b81 commit 8de6182
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion circuit/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func Evaluator(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int,
ioStats = conn.Stats
timing.Sample("Result", []string{FileSize(xfer.Sum()).String()})
if verbose {
timing.Print(conn.Stats.Sent, conn.Stats.Recvd)
timing.Print(conn.Stats.Sent.Load(), conn.Stats.Recvd.Load())
}

return circ.Outputs.Split(raw), nil
Expand Down
2 changes: 1 addition & 1 deletion circuit/garbler.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func Garbler(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int,
ioStats = conn.Stats
timing.Sample("Result", []string{FileSize(xfer.Sum()).String()})
if verbose {
timing.Print(conn.Stats.Sent, conn.Stats.Recvd)
timing.Print(conn.Stats.Sent.Load(), conn.Stats.Recvd.Load())
}

return circ.Outputs.Split(result), nil
Expand Down
2 changes: 1 addition & 1 deletion circuit/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func Player(nw *p2p.Network, circ *Circuit, inputs *big.Int, verbose bool) (
ioStats = nw.Stats().Sub(ioStats)
timing.Sample("Result", []string{FileSize(ioStats.Sum()).String()})
if verbose {
timing.Print(nw.Stats().Sent, nw.Stats().Recvd)
timing.Print(nw.Stats().Sent.Load(), nw.Stats().Recvd.Load())
}

fmt.Printf("player not implemented yet\n")
Expand Down
2 changes: 1 addition & 1 deletion circuit/stream_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ loop:
timing.Sample("Result", []string{FileSize(xfer.Sum()).String()})

if verbose {
timing.Print(conn.Stats.Sent, conn.Stats.Recvd)
timing.Print(conn.Stats.Sent.Load(), conn.Stats.Recvd.Load())
}

return outputs, outputs.Split(rawResult), nil
Expand Down
2 changes: 1 addition & 1 deletion compiler/ssa/streamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT,
timing.Sample("Result", []string{circuit.FileSize(xfer.Sum()).String()})

if params.Verbose {
timing.Print(conn.Stats.Sent, conn.Stats.Recvd)
timing.Print(conn.Stats.Sent.Load(), conn.Stats.Recvd.Load())
}

fmt.Printf("Max permanent wires: %d, cached circuits: %d\n",
Expand Down
39 changes: 30 additions & 9 deletions p2p/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package p2p

import (
"io"
"sync/atomic"

"github.com/markkurossi/mpc/ot"
)
Expand Down Expand Up @@ -39,30 +40,49 @@ type Conn struct {

// IOStats implements I/O statistics.
type IOStats struct {
Sent uint64
Recvd uint64
Sent *atomic.Uint64
Recvd *atomic.Uint64
}

func NewIOStats() IOStats {
return IOStats{
Sent: new(atomic.Uint64),
Recvd: new(atomic.Uint64),
}
}

// Add adds the argument stats to this IOStats and returns the sum.
func (stats IOStats) Add(o IOStats) IOStats {
sent := new(atomic.Uint64)
sent.Store(stats.Sent.Load() + o.Sent.Load())

recvd := new(atomic.Uint64)
recvd.Store(stats.Recvd.Load() + o.Recvd.Load())

return IOStats{
Sent: stats.Sent + o.Sent,
Recvd: stats.Recvd + o.Recvd,
Sent: sent,
Recvd: recvd,
}
}

// Sub subtracts the argument stats from this IOStats and returns the
// difference.
func (stats IOStats) Sub(o IOStats) IOStats {
sent := new(atomic.Uint64)
sent.Store(stats.Sent.Load() - o.Sent.Load())

recvd := new(atomic.Uint64)
recvd.Store(stats.Recvd.Load() - o.Recvd.Load())

return IOStats{
Sent: stats.Sent - o.Sent,
Recvd: stats.Recvd - o.Recvd,
Sent: sent,
Recvd: recvd,
}
}

// Sum returns sum of sent and received bytes.
func (stats IOStats) Sum() uint64 {
return stats.Sent + stats.Recvd
return stats.Sent.Load() + stats.Recvd.Load()
}

// NewConn creates a new connection around the argument connection.
Expand All @@ -72,6 +92,7 @@ func NewConn(conn io.ReadWriter) *Conn {
ReadBuf: make([]byte, readBufSize),
fromWriter: make(chan []byte, numBuffers),
toWriter: make(chan []byte, numBuffers),
Stats: NewIOStats(),
}

go c.writer()
Expand All @@ -88,7 +109,7 @@ func (c *Conn) writer() {

for buf := range c.toWriter {
n, err := c.conn.Write(buf)
c.Stats.Sent += uint64(n)
c.Stats.Sent.Add(uint64(n))
if err != nil {
c.writerErr = err
}
Expand Down Expand Up @@ -138,7 +159,7 @@ func (c *Conn) Fill(n int) error {
if err != nil {
return err
}
c.Stats.Recvd += uint64(got)
c.Stats.Recvd.Add(uint64(got))
c.ReadEnd += got
}
return nil
Expand Down

0 comments on commit 8de6182

Please sign in to comment.