Skip to content

Commit

Permalink
Adding dial options for PerRPCCredentials (grpc#1225)
Browse files Browse the repository at this point in the history
* Adding dial options for PerRPCCredentials

* Added tests for PerRPCCredentials

* Post-review updates

* post-review updates
  • Loading branch information
MakMukhi authored May 11, 2017
1 parent 07bd943 commit 88a73d3
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 9 deletions.
3 changes: 3 additions & 0 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil {
callHdr.Creds = c.creds
}

gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
Expand Down
11 changes: 11 additions & 0 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (

"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
Expand Down Expand Up @@ -141,6 +142,7 @@ type callInfo struct {
trailerMD metadata.MD
peer *peer.Peer
traceInfo traceInfo // in trace.go
creds credentials.PerRPCCredentials
}

var defaultCallInfo = callInfo{failFast: true}
Expand Down Expand Up @@ -207,6 +209,15 @@ func FailFast(failFast bool) CallOption {
})
}

// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call.
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
return beforeCall(func(c *callInfo) error {
c.creds = creds
return nil
})
}

// The format of the payload: compressed or not?
type payloadFormat uint8

Expand Down
3 changes: 3 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil {
callHdr.Creds = c.creds
}
var trInfo traceInfo
if EnableTracing {
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
Expand Down
121 changes: 121 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ type test struct {
serverInitialConnWindowSize int32
clientInitialWindowSize int32
clientInitialConnWindowSize int32
perRPCCreds credentials.PerRPCCredentials

// srv and srvAddr are set once startServer is called.
srv *grpc.Server
Expand Down Expand Up @@ -621,6 +622,9 @@ func (te *test) clientConn() *grpc.ClientConn {
if te.clientInitialConnWindowSize > 0 {
opts = append(opts, grpc.WithInitialConnWindowSize(te.clientInitialConnWindowSize))
}
if te.perRPCCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
}
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
Expand Down Expand Up @@ -3984,3 +3988,120 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) {
t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
}
}

var (
// test authdata
authdata = map[string]string{
"test-key": "test-value",
"test-key2-bin": string([]byte{1, 2, 3}),
}
)

type testPerRPCCredentials struct{}

func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return authdata, nil
}

func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
return false
}

func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ctx, fmt.Errorf("didn't find metadata in context")
}
for k, vwant := range authdata {
vgot, ok := md[k]
if !ok {
return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
}
if vgot[0] != vwant {
return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
}
}
return ctx, nil
}

func TestPerRPCCredentialsViaDialOptions(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testPerRPCCredentialsViaDialOptions(t, e)
}
}

func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
te := newTest(t, e)
te.tapHandle = authHandle
te.perRPCCreds = testPerRPCCredentials{}
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("Test failed. Reason: %v", err)
}
}

func TestPerRPCCredentialsViaCallOptions(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testPerRPCCredentialsViaCallOptions(t, e)
}
}

func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
te := newTest(t, e)
te.tapHandle = authHandle
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
t.Fatalf("Test failed. Reason: %v", err)
}
}

func TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
}
}

func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
te := newTest(t, e)
te.perRPCCreds = testPerRPCCredentials{}
// When credentials are provided via both dial options and call options,
// we apply both sets.
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ctx, fmt.Errorf("couldn't find metadata in context")
}
for k, vwant := range authdata {
vgot, ok := md[k]
if !ok {
return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
}
if len(vgot) != 2 {
return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
}
if vgot[0] != vwant || vgot[1] != vwant {
return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
}
}
return ctx, nil
}
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
t.Fatalf("Test failed. Reason: %v", err)
}
}
52 changes: 43 additions & 9 deletions transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type http2Client struct {
// The scheme used: https if TLS is on, http otherwise.
scheme string

isSecure bool

creds []credentials.PerRPCCredentials

// Boolean to keep track of reading activity on transport.
Expand Down Expand Up @@ -181,7 +183,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
conn.Close()
}
}(conn)
var authInfo credentials.AuthInfo
var (
isSecure bool
authInfo credentials.AuthInfo
)
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
Expand All @@ -191,6 +196,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: %v", err)
}
isSecure = true
}
kp := opts.KeepaliveParams
// Validate keepalive parameters.
Expand Down Expand Up @@ -230,6 +236,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
scheme: scheme,
state: reachable,
activeStreams: make(map[uint32]*Stream),
isSecure: isSecure,
creds: opts.PerRPCCredentials,
maxStreams: defaultMaxStreamsClient,
streamsQuota: newQuotaPool(defaultMaxStreamsClient),
Expand Down Expand Up @@ -335,8 +342,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
pr.AuthInfo = t.authInfo
}
ctx = peer.NewContext(ctx, pr)
authData := make(map[string]string)
for _, c := range t.creds {
var (
authData = make(map[string]string)
audience string
)
// Create an audience string only if needed.
if len(t.creds) > 0 || callHdr.Creds != nil {
// Construct URI required to get auth request metadata.
var port string
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
Expand All @@ -347,17 +358,39 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
pos := strings.LastIndex(callHdr.Method, "/")
if pos == -1 {
return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
pos = len(callHdr.Method)
}
audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
audience = "https://" + callHdr.Host + port + callHdr.Method[:pos]
}
for _, c := range t.creds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2.
k = strings.ToLower(k)
authData[k] = v
}
}
callAuthData := make(map[string]string)
// Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure conneciton")
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2
k = strings.ToLower(k)
callAuthData[k] = v
}
}
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
Expand Down Expand Up @@ -435,9 +468,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}

for k, v := range authData {
// Capital header names are illegal in HTTP/2.
k = strings.ToLower(k)
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
for k, v := range callAuthData {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
var (
hasMD bool
Expand Down
3 changes: 3 additions & 0 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ type CallHdr struct {
// outbound message.
SendCompress string

// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials

// Flush indicates whether a new stream command should be sent
// to the peer without waiting for the first data. This is
// only a hint. The transport may modify the flush decision
Expand Down

0 comments on commit 88a73d3

Please sign in to comment.