Skip to content

Commit

Permalink
Merge pull request grpc#863 from menghanl/setTrailer
Browse files Browse the repository at this point in the history
Allow multiple setTrailer
  • Loading branch information
iamqizhao authored Sep 29, 2016
2 parents dffd7cd + 1247834 commit c2983be
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 26 deletions.
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,8 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
}

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// It may be called at most once from a unary RPC handler. The ctx is the RPC
// handler's Context or one derived from it.
// When called more than once, all the provided metadata will be merged.
// The ctx is the RPC handler's Context or one derived from it.
func SetTrailer(ctx context.Context, md metadata.MD) error {
if md.Len() == 0 {
return nil
Expand Down
4 changes: 2 additions & 2 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ type ServerStream interface {
// after SendProto. It fails if called multiple times or if
// called after SendProto.
SendHeader(metadata.MD) error
// SetTrailer sets the trailer metadata which will be sent with the
// RPC status.
// SetTrailer sets the trailer metadata which will be sent with the RPC status.
// When called more than once, all the provided metadata will be merged.
SetTrailer(metadata.MD)
Stream
}
Expand Down
109 changes: 97 additions & 12 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ var (
"tkey1": []string{"trailerValue1"},
"tkey2": []string{"trailerValue2"},
}
testTrailerMetadata2 = metadata.MD{
"tkey1": []string{"trailerValue12"},
"tkey2": []string{"trailerValue22"},
}
// capital "Key" is illegal in HTTP/2.
malformedHTTP2Metadata = metadata.MD{
"Key": []string{"foo"},
Expand All @@ -89,8 +93,9 @@ var (
var raceMode bool // set by race_test.go in race mode

type testServer struct {
security string // indicate the authentication protocol used by this server.
earlyFail bool // whether to error out the execution of a service handler prematurely.
security string // indicate the authentication protocol used by this server.
earlyFail bool // whether to error out the execution of a service handler prematurely.
multipleSetTrailer bool // whether to call setTrailer multiple times.
}

func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
Expand Down Expand Up @@ -136,14 +141,21 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
if err := grpc.SendHeader(ctx, md); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil)
}
grpc.SetTrailer(ctx, testTrailerMetadata)
if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
}
if s.multipleSetTrailer {
if err := grpc.SetTrailer(ctx, testTrailerMetadata2); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata2, err)
}
}
}
pr, ok := peer.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("failed to get peer from ctx")
return nil, grpc.Errorf(codes.DataLoss, "failed to get peer from ctx")
}
if pr.Addr == net.Addr(nil) {
return nil, fmt.Errorf("failed to get peer address")
return nil, grpc.Errorf(codes.DataLoss, "failed to get peer address")
}
if s.security != "" {
// Check Auth info
Expand All @@ -153,13 +165,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
authType = info.AuthType()
serverName = info.State.ServerName
default:
return nil, fmt.Errorf("Unknown AuthInfo type")
return nil, grpc.Errorf(codes.Unauthenticated, "Unknown AuthInfo type")
}
if authType != s.security {
return nil, fmt.Errorf("Wrong auth type: got %q, want %q", authType, s.security)
return nil, grpc.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security)
}
if serverName != "x.test.youtube.com" {
return nil, fmt.Errorf("Unknown server name %q", serverName)
return nil, grpc.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName)
}
}
// Simulate some service delay.
Expand Down Expand Up @@ -229,9 +241,12 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
md, ok := metadata.FromContext(stream.Context())
if ok {
if err := stream.SendHeader(md); err != nil {
return fmt.Errorf("%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
}
stream.SetTrailer(testTrailerMetadata)
if s.multipleSetTrailer {
stream.SetTrailer(testTrailerMetadata2)
}
stream.SetTrailer(md)
}
for {
in, err := stream.Recv()
Expand Down Expand Up @@ -1193,6 +1208,76 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
}
}

func TestMultipleSetTrailerUnaryRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testMultipleSetTrailerUnaryRPC(t, e)
}
}

func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, multipleSetTrailer: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())

const (
argSize = 1
respSize = 1
)
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}

req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
var trailer metadata.MD
ctx := metadata.NewContext(context.Background(), testMetadata)
if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}
expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2)
if !reflect.DeepEqual(trailer, expectedTrailer) {
t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer)
}
}

func TestMultipleSetTrailerStreamingRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testMultipleSetTrailerStreamingRPC(t, e)
}
}

func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, multipleSetTrailer: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())

ctx := metadata.NewContext(context.Background(), testMetadata)
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
if err := stream.CloseSend(); err != nil {
t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
}
if _, err := stream.Recv(); err != io.EOF {
t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err)
}

trailer := stream.Trailer()
expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2)
if !reflect.DeepEqual(trailer, expectedTrailer) {
t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer)
}
}

// TestMalformedHTTP2Metedata verfies the returned error when the client
// sends an illegal metadata.
func TestMalformedHTTP2Metadata(t *testing.T) {
Expand Down Expand Up @@ -1601,8 +1686,8 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
}
}
trailerMD := stream.Trailer()
if !reflect.DeepEqual(testMetadata, trailerMD) {
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
if !reflect.DeepEqual(testTrailerMetadata, trailerMD) {
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testTrailerMetadata)
}
}

Expand Down
12 changes: 2 additions & 10 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ package transport // import "google.golang.org/grpc/transport"

import (
"bytes"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -287,19 +286,12 @@ func (s *Stream) StatusDesc() string {
return s.statusDesc
}

// ErrIllegalTrailerSet indicates that the trailer has already been set or it
// is too late to do so.
var ErrIllegalTrailerSet = errors.New("transport: trailer has been set")

// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can only be called at most once. Server side only.
// by the server. This can be called multiple times. Server side only.
func (s *Stream) SetTrailer(md metadata.MD) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.trailer != nil {
return ErrIllegalTrailerSet
}
s.trailer = md.Copy()
s.trailer = metadata.Join(s.trailer, md)
return nil
}

Expand Down

0 comments on commit c2983be

Please sign in to comment.