diff --git a/call.go b/call.go index 797190f1471c..438758fc3ed1 100644 --- a/call.go +++ b/call.go @@ -99,17 +99,17 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, Client: true, } } - outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload) + hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload) if err != nil { return err } if c.maxSendMessageSize == nil { return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } - if len(outBuf) > *c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize) + if len(data) > *c.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) } - err = t.Write(stream, outBuf, opts) + err = t.Write(stream, hdr, data, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() dopts.copts.StatsHandler.HandleRPC(ctx, outPayload) diff --git a/call_test.go b/call_test.go index deb3cb6eed39..f3113092948d 100644 --- a/call_test.go +++ b/call_test.go @@ -104,12 +104,12 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) + hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) if err != nil { t.Errorf("Failed to encode the response: %v", err) return } - h.t.Write(s, reply, &transport.Options{}) + h.t.Write(s, hdr, data, &transport.Options{}) h.t.WriteStatus(s, status.New(codes.OK, "")) } diff --git a/rpc_util.go b/rpc_util.go index be8444a1a6be..caded6522098 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -288,19 +288,20 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt return pf, msg, nil } -// encode serializes msg and prepends the message header. If msg is nil, it -// generates the message header of 0 message length. -func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) { - var ( - b []byte - length uint +// encode serializes msg and returns a buffer of message header and a buffer of msg. +// If msg is nil, it generates the message header and an empty msg buffer. +func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) { + var b []byte + const ( + payloadLen = 1 + sizeLen = 4 ) + if msg != nil { var err error - // TODO(zhaoq): optimize to reduce memory alloc and copying. b, err = c.Marshal(msg) if err != nil { - return nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if outPayload != nil { outPayload.Payload = msg @@ -310,39 +311,28 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl } if cp != nil { if err := cp.Do(cbuf, b); err != nil { - return nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } b = cbuf.Bytes() } - length = uint(len(b)) - } - if length > math.MaxUint32 { - return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length) } - const ( - payloadLen = 1 - sizeLen = 4 - ) - - var buf = make([]byte, payloadLen+sizeLen+len(b)) + if len(b) > math.MaxUint32 { + return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + } - // Write payload format + bufHeader := make([]byte, payloadLen+sizeLen) if cp == nil { - buf[0] = byte(compressionNone) + bufHeader[0] = byte(compressionNone) } else { - buf[0] = byte(compressionMade) + bufHeader[0] = byte(compressionMade) } // Write length of b into buf - binary.BigEndian.PutUint32(buf[1:], uint32(length)) - // Copy encoded msg to buf - copy(buf[5:], b) - + binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) if outPayload != nil { - outPayload.WireLength = len(buf) + outPayload.WireLength = payloadLen + sizeLen + len(b) } - - return buf, nil + return bufHeader, b, nil } func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error { diff --git a/rpc_util_test.go b/rpc_util_test.go index 7cbad491a591..23c471e2e407 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -104,14 +104,15 @@ func TestEncode(t *testing.T) { msg proto.Message cp Compressor // outputs - b []byte - err error + hdr []byte + data []byte + err error }{ - {nil, nil, []byte{0, 0, 0, 0, 0}, nil}, + {nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, } { - b, err := encode(protoCodec{}, test.msg, nil, nil, nil) - if err != test.err || !bytes.Equal(b, test.b) { - t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err) + hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil) + if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { + t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) } } } @@ -164,8 +165,8 @@ func TestToRPCErr(t *testing.T) { // bytes. func bmEncode(b *testing.B, mSize int) { msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encoded, _ := encode(protoCodec{}, msg, nil, nil, nil) - encodedSz := int64(len(encoded)) + encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil) + encodedSz := int64(len(encodeHdr) + len(encodeData)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/server.go b/server.go index 5885c6ca69ec..86fe20a53a19 100644 --- a/server.go +++ b/server.go @@ -677,15 +677,15 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) + hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err } - if len(p) > s.opts.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize) + if len(data) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) } - err = t.Write(stream, p, opts) + err = t.Write(stream, hdr, data, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) diff --git a/stream.go b/stream.go index c155d3d4edec..2fcf36873746 100644 --- a/stream.go +++ b/stream.go @@ -362,7 +362,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Client: true, } } - out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) + hdr, data, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) defer func() { if cs.cbuf != nil { cs.cbuf.Reset() @@ -374,10 +374,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if cs.c.maxSendMessageSize == nil { return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } - if len(out) > *cs.c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize) + if len(data) > *cs.c.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) } - err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) + err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() cs.statsHandler.HandleRPC(cs.statsCtx, outPayload) @@ -449,7 +449,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } func (cs *clientStream) CloseSend() (err error) { - err = cs.t.Write(cs.s, nil, &transport.Options{Last: true}) + err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true}) defer func() { if err != nil { cs.finish(err) @@ -608,7 +608,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) + hdr, data, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) defer func() { if ss.cbuf != nil { ss.cbuf.Reset() @@ -617,10 +617,10 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if err != nil { return err } - if len(out) > ss.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize) + if len(data) > ss.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) } - if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { + if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } if outPayload != nil { diff --git a/transport/handler_server.go b/transport/handler_server.go index 85b8ee0ac2fb..0489fada52e1 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -255,9 +255,10 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { } } -func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error { +func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { return ht.do(func() { ht.writeCommonHeaders(s) + ht.rw.Write(hdr) ht.rw.Write(data) if !opts.Delay { ht.rw.(http.Flusher).Flush() diff --git a/transport/http2_client.go b/transport/http2_client.go index 8546d094c61b..5f229138b948 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -683,8 +683,15 @@ func (t *http2Client) GracefulClose() error { // should proceed only if Write returns nil. // TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later // if it improves the performance. -func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { - r := bytes.NewBuffer(data) +func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen + if len(data) < secondStart { + secondStart = len(data) + } + hdr = append(hdr, data[:secondStart]...) + data = data[secondStart:] + isLastSlice := (len(data) == 0) + r := bytes.NewBuffer(hdr) var ( p []byte oqv uint32 @@ -726,9 +733,6 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { endStream bool forceFlush bool ) - if opts.Last && r.Len() == 0 { - endStream = true - } // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. @@ -768,10 +772,22 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { t.writableChan <- 0 continue } - if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 { - // Do a force flush iff this is last frame for the entire gRPC message - // and the caller is the only writer at this moment. - forceFlush = true + if r.Len() == 0 { + if isLastSlice { + if opts.Last { + endStream = true + } + if t.framer.adjustNumWriters(0) == 1 { + // Do a force flush iff this is last frame for the entire gRPC message + // and the caller is the only writer at this moment. + forceFlush = true + } + } else { + isLastSlice = true + if len(data) != 0 { + r = bytes.NewBuffer(data) + } + } } // If WriteData fails, all the pending streams will be handled // by http2Client.Close(). No explicit CloseStream() needs to be diff --git a/transport/http2_server.go b/transport/http2_server.go index 6ee6f40f7523..302651b5ab4e 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -827,8 +827,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). -func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { +func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) { // TODO(zhaoq): Support multi-writers for a single stream. + secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen + if len(data) < secondStart { + secondStart = len(data) + } + hdr = append(hdr, data[:secondStart]...) + data = data[secondStart:] + isLastSlice := (len(data) == 0) var writeHeaderFrame bool s.mu.Lock() if s.state == streamDone { @@ -842,7 +849,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { if writeHeaderFrame { t.WriteHeader(s, nil) } - r := bytes.NewBuffer(data) + r := bytes.NewBuffer(hdr) var ( p []byte oqv uint32 @@ -921,8 +928,15 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { continue } var forceFlush bool - if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last { - forceFlush = true + if r.Len() == 0 { + if isLastSlice { + if t.framer.adjustNumWriters(0) == 1 && !opts.Last { + forceFlush = true + } + } else { + r = bytes.NewBuffer(data) + isLastSlice = true + } } // Reset ping strikes when sending data since this might cause // the peer to send ping. diff --git a/transport/transport.go b/transport/transport.go index ec0fe678dbf4..c5732beec052 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -564,7 +564,7 @@ type ClientTransport interface { // Write sends the data for the given stream. A nil stream indicates // the write is to be performed on the transport as a whole. - Write(s *Stream, data []byte, opts *Options) error + Write(s *Stream, hdr []byte, data []byte, opts *Options) error // NewStream creates a Stream for an RPC. NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) @@ -606,7 +606,7 @@ type ServerTransport interface { // Write sends the data for the given stream. // Write may not be called on all streams. - Write(s *Stream, data []byte, opts *Options) error + Write(s *Stream, hdr []byte, data []byte, opts *Options) error // WriteStatus sends the status of a stream to the client. WriteStatus is // the final call made on a stream and always occurs. diff --git a/transport/transport_test.go b/transport/transport_test.go index 8610478890bc..be2d8dad4136 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -92,7 +92,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { t.Fatalf("handleStream got %v, want %v", p, req) } // send a response back to the client. - h.t.Write(s, resp, &Options{}) + h.t.Write(s, resp, nil, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } @@ -112,7 +112,7 @@ func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { buf[0] = byte(0) binary.BigEndian.PutUint32(buf[1:], uint32(sz)) copy(buf[5:], msg) - h.t.Write(s, buf, &Options{}) + h.t.Write(s, buf, nil, &Options{}) } } @@ -190,7 +190,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { t.Fatalf("handleStream got %v, want %v", p, req) } // send a response back to the client. - h.t.Write(s, resp, &Options{}) + h.t.Write(s, resp, nil, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } @@ -215,7 +215,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { // Wait before sending. Give time to client to start reading // before server starts sending. time.Sleep(2 * time.Second) - h.t.Write(s, resp, &Options{}) + h.t.Write(s, resp, nil, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } @@ -808,7 +808,7 @@ func TestClientSendAndReceive(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s1, expectedRequest, nil, &opts); err != nil && err != io.EOF { t.Fatalf("failed to send data: %v", err) } p := make([]byte, len(expectedResponse)) @@ -845,7 +845,7 @@ func performOneRPC(ct ClientTransport) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF { + if err := ct.Write(s, expectedRequest, nil, &opts); err == nil || err == io.EOF { time.Sleep(5 * time.Millisecond) // The following s.Recv()'s could error out because the // underlying transport is gone. @@ -889,7 +889,7 @@ func TestLargeMessage(t *testing.T) { if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) } - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -921,7 +921,7 @@ func TestLargeMessageWithDelayRead(t *testing.T) { if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) } - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -959,7 +959,7 @@ func TestLargeMessageDelayWrite(t *testing.T) { // Give time to server to start reading before client starts sending. time.Sleep(2 * time.Second) - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -1005,7 +1005,7 @@ func TestGracefulClose(t *testing.T) { Delay: false, } // The stream which was created before graceful close can still proceed. - if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF { t.Fatalf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponse)) @@ -1034,7 +1034,7 @@ func TestLargeMessageSuspension(t *testing.T) { } // Write should not be done successfully due to flow control. msg := make([]byte, initialWindowSize*8) - err = ct.Write(s, msg, &Options{Last: true, Delay: false}) + err = ct.Write(s, msg, nil, &Options{Last: true, Delay: false}) expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) if err != expectedErr { t.Fatalf("Write got %v, want %v", err, expectedErr) @@ -1311,7 +1311,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { t.Fatalf("Failed to create 1st stream. Err: %v", err) } // Exhaust server's connection window. - if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { + if err := client.Write(cstream1, make([]byte, defaultWindowSize), nil, &Options{Last: true}); err != nil { t.Fatalf("Client failed to write data. Err: %v", err) } //Client should be able to create another stream and send data on it. @@ -1319,7 +1319,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { if err != nil { t.Fatalf("Failed to create 2nd stream. Err: %v", err) } - if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil { + if err := client.Write(cstream2, make([]byte, defaultWindowSize), nil, &Options{}); err != nil { t.Fatalf("Client failed to write data. Err: %v", err) } // Get the streams on server. @@ -1474,7 +1474,7 @@ func TestClientWithMisbehavedServer(t *testing.T) { t.Fatalf("Failed to open stream: %v", err) } d := make([]byte, 1) - if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, d, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Fatalf("Failed to write: %v", err) } // Read without window update. @@ -1516,7 +1516,7 @@ func TestEncodingRequiredStatus(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF { t.Fatalf("Failed to write the request: %v", err) } p := make([]byte, http2MaxFrameLen) @@ -1544,7 +1544,7 @@ func TestInvalidHeaderField(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF { t.Fatalf("Failed to write the request: %v", err) } p := make([]byte, http2MaxFrameLen) @@ -1787,7 +1787,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { opts := Options{} header := make([]byte, 5) for i := 1; i <= 10; i++ { - if err := ct.Write(cstream, buf, &opts); err != nil { + if err := ct.Write(cstream, buf, nil, &opts); err != nil { t.Fatalf("Error on client while writing message: %v", err) } if _, err := cstream.Read(header); err != nil {