Skip to content

Commit

Permalink
Merge branch 'master' into service_config_pr
Browse files Browse the repository at this point in the history
  • Loading branch information
lyuxuan authored May 7, 2017
2 parents ea230c7 + ffa4ec7 commit 3ea2870
Show file tree
Hide file tree
Showing 35 changed files with 1,654 additions and 614 deletions.
12 changes: 6 additions & 6 deletions Documentation/gomock-example.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Mocking Service for gRPC

[Example code](https://github.com/grpc/grpc-go/tree/master/examples/helloworld/mock)
[Example code](https://github.com/grpc/grpc-go/tree/master/examples/helloworld/mock_helloworld)

## Why?

To test client-side logic without the overhead of connecting to a real server. Mocking enables users to write light-weight unit tests to check functionalities on client-side without invoking RPC calls to a server.
To test client-side logic without the overhead of connecting to a real server. Mocking enables users to write light-weight unit tests to check functionalities on client-side without invoking RPC calls to a server.

## Idea: Mock the client stub that connects to the server.

Expand Down Expand Up @@ -61,18 +61,18 @@ Along with this the generated code has a method to create an instance of this st
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient
```

The user code uses this function to create an instance of the struct greeterClient which then can be used to make rpc calls to the server.
The user code uses this function to create an instance of the struct greeterClient which then can be used to make rpc calls to the server.
We will mock this interface GreeterClient and use an instance of that mock to make rpc calls. These calls instead of going to server will return pre-determined values.

To create a mock we’ll use [mockgen](https://github.com/golang/mock#running-mockgen).
From the directory ``` examples/helloworld/mock/ ``` run ``` mockgen google.golang.org/grpc/examples/helloworld/helloworld GreeterClient > mock_helloworld/hw_mock.go ```
From the directory ``` examples/helloworld/ ``` run ``` mockgen google.golang.org/grpc/examples/helloworld/helloworld GreeterClient > mock_helloworld/hw_mock.go ```

Notice that in the above command we specify GreeterClient as the interface to be mocked.

The user test code can import the package generated by mockgen along with library package gomock to write unit tests around client-side logic.
```Go
import "github.com/golang/mock/gomock"
import hwmock "google.golang.org/grpc/examples/helloworld/mock/mock_helloworld"
import hwmock "google.golang.org/grpc/examples/helloworld/mock_helloworld"
```

An instance of the mocked interface can be created as:
Expand All @@ -85,7 +85,7 @@ This mocked object can be programmed to expect calls to its methods and return p
mockGreeterClient.EXPECT().SayHello(
gomock.Any(), // expect any value for first parameter
gomock.Any(), // expect any value for second parameter
).Return(&helloworld.HelloReply{Message: “Mocked RPC”}, nil)
).Return(&helloworld.HelloReply{Message: “Mocked RPC”}, nil)
```

gomock.Any() indicates that the parameter can have any value or type. We can indicate specific values for built-in types with gomock.Eq().
Expand Down
3 changes: 3 additions & 0 deletions balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ func (rr *roundRobin) Notify() <-chan []Address {
func (rr *roundRobin) Close() error {
rr.mu.Lock()
defer rr.mu.Unlock()
if rr.done {
return errBalancerClosed
}
rr.done = true
if rr.w != nil {
rr.w.Close()
Expand Down
64 changes: 39 additions & 25 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
}

// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
}
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) {
defer func() {
if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
Expand All @@ -120,7 +116,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
}
outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err)
return Errorf(codes.Internal, "grpc: %v", err)
}
if len(outBuf) > *c.maxSendMessageSize {
return nil, Errorf(codes.ResourceExhausted, "Sent message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize)
Expand All @@ -134,10 +130,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err
return err
}
// Sent successfully.
return stream, nil
return nil
}

// Invoke sends the RPC request on the wire and returns after response is received.
Expand Down Expand Up @@ -194,6 +190,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
}()
}
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
Expand All @@ -203,17 +200,15 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
FailFast: c.failFast,
}
sh.HandleRPC(ctx, begin)
}
defer func() {
if sh != nil {
defer func() {
end := &stats.End{
Client: true,
EndTime: time.Now(),
Error: e,
}
sh.HandleRPC(ctx, end)
}
}()
}()
}
topts := &transport.Options{
Last: true,
Delay: false,
Expand Down Expand Up @@ -257,33 +252,49 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, t, args, topts)
stream, err = t.NewStream(ctx, callHdr)
if err != nil {
if put != nil {
if _, ok := err.(transport.ConnectionError); ok {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
put()
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
}
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, stream, t, args, topts)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
// Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
}
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
Expand All @@ -293,8 +304,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
t.CloseStream(stream, nil)
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
return stream.Status().Err()
}
Expand Down
30 changes: 24 additions & 6 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ var (
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
// DEPRECATED: Please use context.DeadlineExceeded instead. This error will be
// removed in Q1 2017.
// DEPRECATED: Please use context.DeadlineExceeded instead.
ErrClientConnTimeout = errors.New("grpc: timed out when dialing")

// errNoTransportSecurity indicates that there is no transport security
Expand All @@ -79,6 +78,8 @@ var (
errConnClosing = errors.New("grpc: the connection is closing")
// errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable")
// errBalancerClosed indicates that the balancer is closed.
errBalancerClosed = errors.New("grpc: balancer is closed")
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
)
Expand Down Expand Up @@ -109,6 +110,22 @@ const (
// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)

// WithInitialWindowSize returns a DialOption which sets the value for initial window size on a stream.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func WithInitialWindowSize(s int32) DialOption {
return func(o *dialOptions) {
o.copts.InitialWindowSize = s
}
}

// WithInitialConnWindowSize returns a DialOption which sets the value for initial window size on a connection.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func WithInitialConnWindowSize(s int32) DialOption {
return func(o *dialOptions) {
o.copts.InitialConnWindowSize = s
}
}

// WithMaxMsgSize Deprecated: use WithMaxReceiveMessageSize instead.
func WithMaxMsgSize(s int) DialOption {
return WithDefaultCallOptions(WithMaxReceiveMessageSize(s))
Expand Down Expand Up @@ -212,7 +229,7 @@ func WithTransportCredentials(creds credentials.TransportCredentials) DialOption
}

// WithPerRPCCredentials returns a DialOption which sets
// credentials which will place auth state on each outbound RPC.
// credentials and places auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
return func(o *dialOptions) {
o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
Expand Down Expand Up @@ -249,7 +266,7 @@ func WithStatsHandler(h stats.Handler) DialOption {
}
}

// FailOnNonTempDialError returns a DialOption that specified if gRPC fails on non-temporary dial errors.
// FailOnNonTempDialError returns a DialOption that specifies if gRPC fails on non-temporary dial errors.
// If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network
// address and won't try to reconnect.
// The default value of FailOnNonTempDialError is false.
Expand Down Expand Up @@ -303,10 +320,9 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
}

// DialContext creates a client connection to the given target. ctx can be used to
// cancel or expire the pending connecting. Once this function returns, the
// cancel or expire the pending connection. Once this function returns, the
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
// to terminate all the pending operations after this function returns.
// This is the EXPERIMENTAL API.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
Expand Down Expand Up @@ -681,13 +697,15 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
}
if !ok {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put()
}
return nil, nil, errConnClosing
}
t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put()
}
return nil, nil, err
Expand Down
1 change: 1 addition & 0 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func (p protoCodec) Marshal(v interface{}) ([]byte, error) {
func (p protoCodec) Unmarshal(data []byte, v interface{}) error {
cb := protoBufferPool.Get().(*cachedProtoBuffer)
cb.SetBuf(data)
v.(proto.Message).Reset()
err := cb.Unmarshal(v.(proto.Message))
cb.SetBuf(nil)
protoBufferPool.Put(cb)
Expand Down
8 changes: 4 additions & 4 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ func NewTLS(c *tls.Config) TransportCredentials {
return tc
}

// NewClientTLSFromCert constructs a TLS from the input certificate for client.
// NewClientTLSFromCert constructs TLS credentials from the input certificate for client.
// serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
}

// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
// NewClientTLSFromFile constructs TLS credentials from the input certificate file for client.
// serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
Expand All @@ -218,12 +218,12 @@ func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredent
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
}

// NewServerTLSFromCert constructs a TLS from the input certificate for server.
// NewServerTLSFromCert constructs TLS credentials from the input certificate for server.
func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
}

// NewServerTLSFromFile constructs a TLS from the input certificate file and key
// NewServerTLSFromFile constructs TLS credentials from the input certificate file and key
// file for server.
func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package mock_test
package mock_helloworld_test

import (
"fmt"
Expand All @@ -8,7 +8,7 @@ import (
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
helloworld "google.golang.org/grpc/examples/helloworld/helloworld"
hwmock "google.golang.org/grpc/examples/helloworld/mock/mock_helloworld"
hwmock "google.golang.org/grpc/examples/helloworld/mock_helloworld"
)

// rpcMsg implements the gomock.Matcher interface
Expand Down
Loading

0 comments on commit 3ea2870

Please sign in to comment.