Grpclb: Support server list expiration (grpc#962)
iamqizhao authored and menghanl committed Nov 16, 2016
1 parent 941cc89 commit 8551858
Showing 2 changed files with 174 additions and 30 deletions.
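
What the change does: each ServerList sent by the remote balancer may carry an ExpirationInterval; when one does, the balancer arms a time.AfterFunc timer, and if no newer list arrives first (tracked via the seq counter), the cached addresses are cleared and nil is sent on addrCh so gRPC tears down the corresponding connections. Below is a minimal, self-contained sketch of that timer-plus-sequence-number pattern; the names lister, set, and expire are hypothetical stand-ins for the balancer internals, not the gRPC API:

package main

import (
	"fmt"
	"sync"
	"time"
)

type lister struct {
	mu    sync.Mutex
	seq   int // bumped whenever a fresher list arrives
	addrs []string
	timer *time.Timer
}

// set installs a new list and arms an expiration timer for it.
func (l *lister) set(addrs []string, ttl time.Duration) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.seq++
	seq := l.seq // captured by the timer callback below
	l.addrs = addrs
	if l.timer != nil {
		l.timer.Stop() // a newer list supersedes any pending expiration
	}
	l.timer = nil
	if ttl > 0 {
		l.timer = time.AfterFunc(ttl, func() { l.expire(seq) })
	}
}

// expire clears the list only if no fresher list has arrived since
// the timer was armed.
func (l *lister) expire(seq int) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if seq < l.seq {
		return // stale timer: a newer list is already installed
	}
	l.addrs = nil
}

func main() {
	l := &lister{}
	l.set([]string{"10.0.0.1:443"}, 50*time.Millisecond)
	time.Sleep(100 * time.Millisecond)
	l.mu.Lock()
	fmt.Println(len(l.addrs)) // prints 0: the list expired
	l.mu.Unlock()
}
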
80 changes: 65 additions & 15 deletions grpclb/grpclb.go
@@ -40,6 +40,7 @@ import (
"errors"
"fmt"
"sync"
"time"

"golang.org/x/net/context"
"google.golang.org/grpc"
@@ -93,16 +94,17 @@ type addrInfo struct {
}

type balancer struct {
r naming.Resolver
mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
w naming.Watcher
addrCh chan []grpc.Address
rbs []remoteBalancerInfo
addrs []*addrInfo
next int
waitCh chan struct{}
done bool
r naming.Resolver
mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
w naming.Watcher
addrCh chan []grpc.Address
rbs []remoteBalancerInfo
addrs []*addrInfo
next int
waitCh chan struct{}
done bool
expTimer *time.Timer
}

func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
@@ -180,14 +182,39 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
return nil
}

func (b *balancer) serverListExpire(seq int) {
b.mu.Lock()
defer b.mu.Unlock()
// TODO: gRPC internals do not clear the connections when the server list is stale.
// This means RPCs will keep using the existing server list until b receives a new
// server list, even though the list is expired. Revisit this behavior later.
if b.done || seq < b.seq {
return
}
b.next = 0
b.addrs = nil
// Ask grpc internals to close all the corresponding connections.
b.addrCh <- nil
}

func convertDuration(d *lbpb.Duration) time.Duration {
if d == nil {
return 0
}
return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}

func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
if l == nil {
return
}
servers := l.GetServers()
expiration := convertDuration(l.GetExpirationInterval())
var (
sl []*addrInfo
addrs []grpc.Address
)
for _, s := range servers {
// TODO: Support ExpirationInterval
md := metadata.Pairs("lb-token", s.LoadBalanceToken)
addr := grpc.Address{
Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
@@ -209,11 +236,20 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
b.next = 0
b.addrs = sl
b.addrCh <- addrs
if b.expTimer != nil {
b.expTimer.Stop()
b.expTimer = nil
}
if expiration > 0 {
b.expTimer = time.AfterFunc(expiration, func() {
b.serverListExpire(seq)
})
}
}
return
}

func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
@@ -226,8 +262,6 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
b.mu.Unlock()
return
}
b.seq++
seq := b.seq
b.mu.Unlock()
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
@@ -260,6 +294,14 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
if err != nil {
break
}
b.mu.Lock()
if b.done || seq < b.seq {
b.mu.Unlock()
return
}
b.seq++ // tick when receiving a new list of servers.
seq = b.seq
b.mu.Unlock()
if serverList := reply.GetServerList(); serverList != nil {
b.processServerList(serverList, seq)
}
@@ -326,10 +368,15 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
return
}
b.mu.Lock()
b.seq++ // tick when getting a new balancer address
seq := b.seq
b.next = 0
b.mu.Unlock()
go func(cc *grpc.ClientConn) {
lbc := lbpb.NewLoadBalancerClient(cc)
for {
if retry := b.callRemoteBalancer(lbc); !retry {
if retry := b.callRemoteBalancer(lbc, seq); !retry {
cc.Close()
return
}
@@ -497,6 +544,9 @@ func (b *balancer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
b.done = true
if b.expTimer != nil {
b.expTimer.Stop()
}
if b.waitCh != nil {
close(b.waitCh)
}
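
A quick check of the convertDuration helper added above, as a standalone sketch: the local Duration struct merely mirrors the Seconds/Nanos shape of the generated lbpb.Duration message (it is not the real type), and the values match those used by TestServerExpiration below:

package main

import (
	"fmt"
	"time"
)

// Duration mirrors the Seconds/Nanos shape of the generated
// lbpb.Duration message; it is a stand-in, not the real type.
type Duration struct {
	Seconds int64
	Nanos   int32
}

func convertDuration(d *Duration) time.Duration {
	if d == nil {
		return 0 // no interval set: the server list never expires
	}
	return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}

func main() {
	// The same values TestServerExpiration uses below.
	fmt.Println(convertDuration(&Duration{Seconds: 0, Nanos: 100000000})) // 100ms
	fmt.Println(convertDuration(nil))                                     // 0s
}
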
124 changes: 109 additions & 15 deletions grpclb/grpclb_test.go
@@ -162,14 +162,16 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
}

type remoteBalancer struct {
servers *lbpb.ServerList
done chan struct{}
sls []*lbpb.ServerList
intervals []time.Duration
done chan struct{}
}

func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer {
func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {
return &remoteBalancer{
servers: servers,
done: make(chan struct{}),
sls: sls,
intervals: intervals,
done: make(chan struct{}),
}
}

@@ -186,13 +188,16 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer)
if err := stream.Send(resp); err != nil {
return err
}
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: b.servers,
},
}
if err := stream.Send(resp); err != nil {
return err
for k, v := range b.sls {
time.Sleep(b.intervals[k])
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: v,
},
}
if err := stream.Send(resp); err != nil {
return err
}
}
<-b.done
return nil
@@ -268,7 +273,9 @@ func TestGRPCLB(t *testing.T) {
sl := &lbpb.ServerList{
Servers: bes,
}
ls := newRemoteBalancer(sl)
sls := []*lbpb.ServerList{sl}
intervals := []time.Duration{0}
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -343,7 +350,9 @@ func TestDropRequest(t *testing.T) {
sl := &lbpb.ServerList{
Servers: bes,
}
ls := newRemoteBalancer(sl)
sls := []*lbpb.ServerList{sl}
intervals := []time.Duration{0}
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -413,7 +422,9 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
sl := &lbpb.ServerList{
Servers: bes,
}
ls := newRemoteBalancer(sl)
sls := []*lbpb.ServerList{sl}
intervals := []time.Duration{0}
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -439,3 +450,86 @@ }
}
cc.Close()
}

func TestServerExpiration(t *testing.T) {
// Start a backend.
beLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen %v", err)
}
beAddr := strings.Split(beLis.Addr().String(), ":")
bePort, err := strconv.Atoi(beAddr[1])
if err != nil {
t.Fatalf("Failed to generate the port number %v", err)
}
backends := startBackends(t, besn, beLis)
defer stopBackends(backends)

// Start a load balancer.
lbLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to create the listener for the load balancer %v", err)
}
lbCreds := &serverNameCheckCreds{
sn: lbsn,
}
lb := grpc.NewServer(grpc.Creds(lbCreds))
be := &lbpb.Server{
IpAddress: []byte(beAddr[0]),
Port: int32(bePort),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
exp := &lbpb.Duration{
Seconds: 0,
Nanos: 100000000, // 100ms
}
var sls []*lbpb.ServerList
sl := &lbpb.ServerList{
Servers: bes,
ExpirationInterval: exp,
}
sls = append(sls, sl)
sl = &lbpb.ServerList{
Servers: bes,
}
sls = append(sls, sl)
var intervals []time.Duration
intervals = append(intervals, 0)
intervals = append(intervals, 500*time.Millisecond)
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
}()
defer func() {
ls.stop()
lb.Stop()
}()
creds := serverNameCheckCreds{
expected: besn,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
addr: lbLis.Addr().String(),
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
}
// Sleep until the first server list has expired.
time.Sleep(150 * time.Millisecond)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
}
// A non-failfast RPC should succeed after the second server list is received from
// the remote load balancer.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
}
cc.Close()
}
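
For reference, the timeline TestServerExpiration sets up: the first list arrives immediately and expires after 100ms, while the second list is delayed by 500ms. The 150ms sleep therefore lands in the window where the first list has expired but the second has not yet arrived, so a fail-fast RPC fails with codes.Unavailable, while the final FailFast(false) call blocks until the second, non-expiring list arrives and then succeeds.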
