Skip to content

Commit

Permalink
Ensure client stream can Receive() after calling CloseSend() and …
Browse files Browse the repository at this point in the history
…handle context cancellation (connectrpc#150)
  • Loading branch information
doriable authored Mar 18, 2022
1 parent d40fe85 commit 2603cca
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 82 deletions.
12 changes: 10 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package connect

import (
"context"
"errors"
"io"
"net/http"
)

Expand Down Expand Up @@ -62,7 +64,10 @@ func NewClient[Req, Res any](
unarySpec := config.newSpecification(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
sender, receiver := protocolClient.NewStream(ctx, unarySpec, request.Header())
if err := sender.Send(request.Any()); err != nil {
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
if err := sender.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) {
_ = sender.Close(err)
_ = receiver.Close()
return nil, err
Expand Down Expand Up @@ -124,7 +129,10 @@ func (c *Client[Req, Res]) CallServerStream(
) (*ServerStreamForClient[Res], error) {
sender, receiver := c.newStream(ctx, StreamTypeServer)
mergeHeaders(sender.Header(), req.header)
if err := sender.Send(req.Msg); err != nil {
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
if err := sender.Send(req.Msg); err != nil && !errors.Is(err, io.EOF) {
_ = sender.Close(err)
_ = receiver.Close()
return nil, err
Expand Down
90 changes: 73 additions & 17 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,22 @@ func TestServer(t *testing.T) {
})
t.Run("sum_error", func(t *testing.T) {
stream := client.Sum(context.Background())
assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1}))
_, err := stream.CloseAndReceive()
err := stream.Send(&pingv1.SumRequest{Number: 1})
if err != nil {
assert.ErrorIs(t, err, io.EOF)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
}
_, err = stream.CloseAndReceive()
assert.Equal(t, connect.CodeOf(err), connect.CodeInvalidArgument)
})
t.Run("sum_close_and_receive_without_send", func(t *testing.T) {
stream := client.Sum(context.Background())
stream.RequestHeader().Set(clientHeader, headerValue)
got, err := stream.CloseAndReceive()
assert.Nil(t, err)
assert.Zero(t, *got.Msg) // receive header only stream
assert.Equal(t, got.Header().Get(handlerHeader), headerValue)
})
}
testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) {
t.Run("count_up", func(t *testing.T) {
Expand Down Expand Up @@ -184,17 +196,68 @@ func TestServer(t *testing.T) {
})
t.Run("cumsum_error", func(t *testing.T) {
stream := client.CumSum(context.Background())
if !expectSuccess {
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
return
}
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 42}))
if err := stream.Send(&pingv1.CumSumRequest{Number: 42}); err != nil {
assert.ErrorIs(t, err, io.EOF)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
}
// We didn't send the headers the server expects, so we should now get an
// error.
_, err := stream.Receive()
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeInvalidArgument)
})
t.Run("cumsum_empty_stream", func(t *testing.T) {
stream := client.CumSum(context.Background())
stream.RequestHeader().Set(clientHeader, headerValue)
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
return
}
// Deliberately closing with calling Send to test the behavior of Receive.
// This test case is based on the grpc interop tests.
assert.Nil(t, stream.CloseSend())
res, err := stream.Receive()
assert.Nil(t, res)
assert.True(t, errors.Is(err, io.EOF))
assert.Nil(t, stream.CloseReceive()) // clean-up the stream
})
t.Run("cumsum_cancel_after_first_response", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
stream := client.CumSum(ctx)
stream.RequestHeader().Set(clientHeader, headerValue)
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
return
}
var got []int64
expect := []int64{42}
if err := stream.Send(&pingv1.CumSumRequest{Number: 42}); err != nil {
assert.ErrorIs(t, err, io.EOF)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
}
msg, err := stream.Receive()
assert.Nil(t, err)
got = append(got, msg.Sum)
cancel()
_, err = stream.Receive()
assert.Equal(t, connect.CodeOf(err), connect.CodeCanceled)
assert.Equal(t, got, expect)
})
t.Run("cumsum_cancel_before_send", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
stream := client.CumSum(ctx)
stream.RequestHeader().Set(clientHeader, headerValue)
// Send once first, since `makeRequest` does check for context errors
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 8}))
cancel()
// On a subsequent send, ensure that we are still catching context cancellations without
// calling makeRequest.
err := stream.Send(&pingv1.CumSumRequest{Number: 19})
assert.Equal(t, connect.CodeOf(err), connect.CodeCanceled, assert.Sprintf("%v", err))
})
}
testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) {
t.Run("errors", func(t *testing.T) {
Expand Down Expand Up @@ -311,10 +374,12 @@ func TestHeaderBasic(t *testing.T) {
}

func failNoHTTP2(t testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) {
err := stream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err) // haven't gotten response back yet
if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(t, err, io.EOF)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
}
assert.Nil(t, stream.CloseSend())
_, err = stream.Receive()
_, err := stream.Receive()
assert.NotNil(t, err) // should be 505
assert.True(
t,
Expand Down Expand Up @@ -387,9 +452,6 @@ func (p pingServer) Sum(
}
var sum int64
for stream.Receive() {
if err := ctx.Err(); err != nil {
return err
}
sum += stream.Msg().Number
}
if stream.Err() != nil {
Expand Down Expand Up @@ -418,9 +480,6 @@ func (p pingServer) CountUp(
stream.ResponseHeader().Set(handlerHeader, headerValue)
stream.ResponseTrailer().Set(handlerTrailer, trailerValue)
for i := int64(1); i <= req.Msg.Number; i++ {
if err := ctx.Err(); err != nil {
return err
}
if err := stream.Send(&pingv1.CountUpResponse{Number: i}); err != nil {
return err
}
Expand All @@ -441,9 +500,6 @@ func (p pingServer) CumSum(
stream.ResponseHeader().Set(handlerHeader, headerValue)
stream.ResponseTrailer().Set(handlerTrailer, trailerValue)
for {
if err := ctx.Err(); err != nil {
return err
}
msg, err := stream.Receive()
if errors.Is(err, io.EOF) {
return nil
Expand Down
110 changes: 47 additions & 63 deletions protocol_grpc_client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,11 @@ type duplexClientStream struct {

// send: guarded by prepareOnce because we can't initialize this state until
// the first call to Send.
prepareOnce sync.Once
writer *io.PipeWriter
marshaler marshaler
sentAtLeastOnce bool
header http.Header
trailer http.Header
prepareOnce sync.Once
writer *io.PipeWriter
marshaler marshaler
header http.Header
trailer http.Header

// receive goroutine
web bool
Expand Down Expand Up @@ -107,29 +106,31 @@ func (cs *duplexClientStream) Trailer() http.Header {
}

func (cs *duplexClientStream) Send(message any) error {
defer func() { cs.sentAtLeastOnce = true }()
// stream.makeRequest hands the read side of the pipe off to net/http and
// waits to establish the response stream. There's a small class of errors we
// can catch before writing to the request body, so we don't want to start
// writing to the stream until we're sure that we're actually waiting on the
// server. This makes user-visible behavior more predictable: for example, if
// they've configured the server's base URL as "hwws://acme.com", they'll
// always get an invalid URL error on their first attempt to send.
cs.prepareOnce.Do(func() {
requestPrepared := make(chan struct{})
go cs.makeRequest(requestPrepared)
<-requestPrepared
})
cs.prepareRequests()
// Before we receive the message, check if the context has been canceled.
if err := cs.ctx.Err(); err != nil {
cs.setResponseError(err)
return err
}
// Calling Marshal writes data to the send stream. It's safe to do this while
// makeRequest is running, because we're writing to our side of the pipe
// (which is safe to do while net/http reads from the other side).
if err := cs.marshaler.Marshal(message); err != nil {
return cs.improveMarshalerError(err)
if !errors.Is(err, io.ErrClosedPipe) {
return err
}
// In all other cases, we return a io.EOF, similar to gRPC and the user
// will get the state of the stream through Receive.
return NewError(CodeUnknown, io.EOF)
}
return nil
}

func (cs *duplexClientStream) CloseSend(_ error) error {
// Even if Send was never called, we need to make an HTTP request. This ensures
// that we've sent any headers to the server and that we got an HTTP response body
// from the server for Receive to unmarshal from if called.
cs.prepareRequests()
// The user calls CloseSend to indicate that they're done sending data. All
// we do here is write to the pipe and close it, so it's safe to do this
// while makeRequest is running. (This method takes an error to accommodate
Expand All @@ -145,8 +146,10 @@ func (cs *duplexClientStream) CloseSend(_ error) error {
// CloseSend automatically rather than requiring the user to do it.
if cs.web {
if err := cs.marshaler.MarshalWebTrailers(cs.trailer); err != nil {
_ = cs.writer.Close()
return cs.improveMarshalerError(err)
if !errors.Is(err, io.ErrClosedPipe) {
_ = cs.writer.Close()
return err
}
}
}
if err := cs.writer.Close(); err != nil {
Expand All @@ -166,6 +169,11 @@ func (cs *duplexClientStream) Receive(message any) error {
// The stream is already closed or corrupted.
return err
}
// Before we receive the message, check if the context has been canceled.
if err := cs.ctx.Err(); err != nil {
cs.setResponseError(err)
return err
}
// Consume one message from the response stream.
err := cs.unmarshaler.Unmarshal(message)
if err != nil {
Expand Down Expand Up @@ -220,6 +228,21 @@ func (cs *duplexClientStream) ResponseTrailer() http.Header {
return cs.response.Trailer
}

// stream.makeRequest hands the read side of the pipe off to net/http and
// waits to establish the response stream. There's a small class of errors we
// can catch before writing to the request body, so we don't want to start
// writing to the stream until we're sure that we're actually waiting on the
// server. This makes user-visible behavior more predictable: for example, if
// they've configured the server's base URL as "hwws://acme.com", they'll
// always get an invalid URL error on their first attempt to send.
func (cs *duplexClientStream) prepareRequests() {
cs.prepareOnce.Do(func() {
requestPrepared := make(chan struct{})
go cs.makeRequest(requestPrepared)
<-requestPrepared
})
}

func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
// This runs concurrently with Send and CloseSend. Receive and CloseReceive
// wait on cs.responseReady, so we can't race with them.
Expand All @@ -244,16 +267,10 @@ func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
req.Trailer = cs.trailer
}

// Before we send off a request, check if we're already out of time.
if err := cs.ctx.Err(); err != nil {
cs.setRequestError(err)
close(prepared)
return
}

// At this point, we've caught all the errors we can - it's time to send data
// to the server. Unblock Send.
close(prepared)

// Once we send a message to the server, they send a message back and
// establish the receive side of the stream.
res, err := cs.httpClient.Do(req)
Expand Down Expand Up @@ -301,36 +318,6 @@ func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
}
}

func (cs *duplexClientStream) improveMarshalerError(err error) error {
if !errors.Is(err, io.ErrClosedPipe) {
return err
}
// err is an io.ErrClosedPipe, which means that net/http closed the request
// body. It only does this when we can't send more data. In these cases, we
// expect a response from the server or some network error.
if !cs.sentAtLeastOnce {
// This is the first time we're marshaling data to the network. Because
// of the vagaries of goroutine scheduling, it's possible that we've
// already gotten a response from the server. However, user-visible
// behavior is more deterministic if we pretend that we're still waiting
// for the response and only return errors that were caught before we
// called HTTPClient.Do.
if requestErr := cs.getRequestError(); requestErr != nil {
return requestErr
}
return nil
}
// We've already sent at least one message to the server. Wait for a
// response so we can give the user a more informative error than "pipe
// closed".
<-cs.responseReady
if responseErr := cs.getRequestOrResponseError(); responseErr != nil {
return responseErr
}
// As a last resort, return the original error as-is. We shouldn't get here.
return err
}

func (cs *duplexClientStream) setRequestError(err error) {
cs.setError(err, true /* isRequest */)
}
Expand Down Expand Up @@ -359,10 +346,7 @@ func (cs *duplexClientStream) setError(err error, isRequest bool) {
// We've already hit an error, so we should stop writing to the request body.
// It's safe to call Close more than once and/or concurrently (calls after
// the first are no-ops), so it's okay for us to call this even though
// net/http sometimes closes the reader too. We do _not_ want to close the
// pipe with CloseWithError(err), because that will prevent errors returned
// from cs.marshaler.Marshal from going through the logic in
// improveMarshalerError and reduce determinism.
// net/http sometimes closes the reader too.
//
// It's safe to ignore the returned error here. Under the hood, Close calls
// CloseWithError, which is documented to always return nil.
Expand Down

0 comments on commit 2603cca

Please sign in to comment.