diff --git a/eth/protocol.go b/eth/protocol.go index b67e5aaeabb9..723ab5502e4c 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -140,7 +140,7 @@ func (self *ethProtocol) handle() error { return self.protoError(ErrDecode, "->msg %v: %v", msg, err) } hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount) - return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...) + return p2p.EncodeMsg(self.rw, BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...) case BlockHashesMsg: // TODO: redo using lazy decode , this way very inefficient on known chains @@ -185,7 +185,7 @@ func (self *ethProtocol) handle() error { break } } - return self.rw.EncodeMsg(BlocksMsg, blocks...) + return p2p.EncodeMsg(self.rw, BlocksMsg, blocks...) case BlocksMsg: msgStream := rlp.NewStream(msg.Payload) @@ -298,12 +298,12 @@ func (self *ethProtocol) handleStatus() error { func (self *ethProtocol) requestBlockHashes(from []byte) error { self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4]) - return self.rw.EncodeMsg(GetBlockHashesMsg, interface{}(from), uint64(blockHashesBatchSize)) + return p2p.EncodeMsg(self.rw, GetBlockHashesMsg, interface{}(from), uint64(blockHashesBatchSize)) } func (self *ethProtocol) requestBlocks(hashes [][]byte) error { self.peer.Debugf("fetching %v blocks", len(hashes)) - return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)...) + return p2p.EncodeMsg(self.rw, GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)...) } func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) { diff --git a/eth/protocol_test.go b/eth/protocol_test.go index ab2aa289f087..224b59abd323 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -41,10 +41,6 @@ func (self *testMsgReadWriter) WriteMsg(msg p2p.Msg) error { return nil } -func (self *testMsgReadWriter) EncodeMsg(code uint64, data ...interface{}) error { - return self.WriteMsg(p2p.NewMsg(code, data...)) -} - func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) { msg, ok := <-self.in if !ok { diff --git a/p2p/message.go b/p2p/message.go index a6f62ec4c861..daf2bf05c1c5 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -71,14 +71,11 @@ type MsgReader interface { } type MsgWriter interface { - // WriteMsg sends an existing message. - // The Payload reader of the message is consumed. + // WriteMsg sends a message. It will block until the message's + // Payload has been consumed by the other end. + // // Note that messages can be sent only once. WriteMsg(Msg) error - - // EncodeMsg writes an RLP-encoded message with the given - // code and data elements. - EncodeMsg(code uint64, data ...interface{}) error } // MsgReadWriter provides reading and writing of encoded messages. @@ -87,6 +84,12 @@ type MsgReadWriter interface { MsgWriter } +// EncodeMsg writes an RLP-encoded message with the given code and +// data elements. +func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { + return w.WriteMsg(NewMsg(code, data...)) +} + var magicToken = []byte{34, 64, 8, 145} func writeMsg(w io.Writer, msg Msg) error { @@ -209,11 +212,6 @@ func (p *MsgPipeRW) WriteMsg(msg Msg) error { return ErrPipeClosed } -// EncodeMsg is a convenient shorthand for sending an RLP-encoded message. -func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error { - return p.WriteMsg(NewMsg(code, data...)) -} - // ReadMsg returns a message sent on the other end of the pipe. func (p *MsgPipeRW) ReadMsg() (Msg, error) { if atomic.LoadInt32(p.closed) == 0 { diff --git a/p2p/message_test.go b/p2p/message_test.go index 066d2516d4e0..5cde9abf5b21 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -75,8 +75,8 @@ func TestDecodeRealMsg(t *testing.T) { func ExampleMsgPipe() { rw1, rw2 := MsgPipe() go func() { - rw1.EncodeMsg(8, []byte{0, 0}) - rw1.EncodeMsg(5, []byte{1, 1}) + EncodeMsg(rw1, 8, []byte{0, 0}) + EncodeMsg(rw1, 5, []byte{1, 1}) rw1.Close() }() @@ -100,7 +100,7 @@ loop: rw1, rw2 := MsgPipe() done := make(chan struct{}) go func() { - if err := rw1.EncodeMsg(1); err == nil { + if err := EncodeMsg(rw1, 1); err == nil { t.Error("EncodeMsg returned nil error") } else if err != ErrPipeClosed { t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 5b9e9e784758..4ee88f112b0f 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -126,10 +126,10 @@ func TestPeerProtoEncodeMsg(t *testing.T) { Name: "a", Length: 2, Run: func(peer *Peer, rw MsgReadWriter) error { - if err := rw.EncodeMsg(2); err == nil { + if err := EncodeMsg(rw, 2); err == nil { t.Error("expected error for out-of-range msg code, got nil") } - if err := rw.EncodeMsg(1, "foo", "bar"); err != nil { + if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil { t.Errorf("write error: %v", err) } return nil diff --git a/p2p/protocol.go b/p2p/protocol.go index dd8cbc4ecdbc..969937076bd6 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -119,14 +119,14 @@ func (bp *baseProtocol) loop(quit <-chan error) error { getPeersTick := time.NewTicker(10 * time.Second) defer getPeersTick.Stop() - err := bp.rw.EncodeMsg(getPeersMsg) + err := EncodeMsg(bp.rw, getPeersMsg) for err == nil { select { case err = <-quit: return err case <-getPeersTick.C: - err = bp.rw.EncodeMsg(getPeersMsg) + err = EncodeMsg(bp.rw, getPeersMsg) case event := <-activity.Chan(): ping.Reset(pingTimeout) lastActive = event.(time.Time) @@ -134,7 +134,7 @@ func (bp *baseProtocol) loop(quit <-chan error) error { if lastActive.Add(pingTimeout * 2).Before(t) { err = newPeerError(errPingTimeout, "") } else if lastActive.Add(pingTimeout).Before(t) { - err = bp.rw.EncodeMsg(pingMsg) + err = EncodeMsg(bp.rw, pingMsg) } } } @@ -164,7 +164,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { return discRequestedError(reason[0]) case pingMsg: - return bp.rw.EncodeMsg(pongMsg) + return EncodeMsg(bp.rw, pongMsg) case pongMsg: @@ -177,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { // // TODO: add event mechanism to notify baseProtocol for new peers if len(peers) > 0 { - return bp.rw.EncodeMsg(peersMsg, peers...) + return EncodeMsg(bp.rw, peersMsg, peers...) } case peersMsg: diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go index ce25b3e1b541..ba5e95c0219d 100644 --- a/p2p/protocol_test.go +++ b/p2p/protocol_test.go @@ -93,7 +93,7 @@ func TestBaseProtocolDisconnect(t *testing.T) { if err := expectMsg(rw2, handshakeMsg); err != nil { t.Error(err) } - err := rw2.EncodeMsg(handshakeMsg, + err := EncodeMsg(rw2, handshakeMsg, baseProtocolVersion, "", []interface{}{}, @@ -106,7 +106,7 @@ func TestBaseProtocolDisconnect(t *testing.T) { if err := expectMsg(rw2, getPeersMsg); err != nil { t.Error(err) } - if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { + if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil { t.Error(err) }