Skip to content

Commit

Permalink
Added HostClient.PendingRequests(), which may be used for balancing l…
Browse files Browse the repository at this point in the history
…oad among multiple HostClient instances
  • Loading branch information
valyala committed Sep 21, 2016
1 parent ec59ce3 commit a52a42a
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
16 changes: 16 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,8 @@ type HostClient struct {

readerPool sync.Pool
writerPool sync.Pool

pendingRequests uint64
}

type clientConn struct {
Expand Down Expand Up @@ -957,16 +959,27 @@ var errorChPool sync.Pool
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) Do(req *Request, resp *Response) error {
atomic.AddUint64(&c.pendingRequests, 1)
retry, err := c.do(req, resp)
if err != nil && retry && isIdempotent(req) {
_, err = c.do(req, resp)
}
if err == io.EOF {
err = ErrConnectionClosed
}
atomic.AddUint64(&c.pendingRequests, ^uint64(0))
return err
}

// PendingRequests returns the current number of requests the client
// is executing.
//
// This function may be used for balancing load among multiple HostClient
// instances.
func (c *HostClient) PendingRequests() int {
return int(atomic.LoadUint64(&c.pendingRequests))
}

func isIdempotent(req *Request) bool {
return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut()
}
Expand Down Expand Up @@ -1961,6 +1974,9 @@ func (c *pipelineConnClient) logger() Logger {
// This number may exceed MaxPendingRequests*MaxConns by up to two times, since
// each connection to the server may keep up to MaxPendingRequests requests
// in the queue before sending them to the server.
//
// This function may be used for balancing load among multiple PipelineClient
// instances.
func (c *PipelineClient) PendingRequests() int {
c.connClientsLock.Lock()
n := 0
Expand Down
94 changes: 94 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,100 @@ func TestClientDoTimeoutDisableNormalizing(t *testing.T) {
}
}

func TestHostClientPendingRequests(t *testing.T) {
const concurrency = 10
doneCh := make(chan struct{})
readyCh := make(chan struct{}, concurrency)
s := &Server{
Handler: func(ctx *RequestCtx) {
readyCh <- struct{}{}
<-doneCh
},
}
ln := fasthttputil.NewInmemoryListener()
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Fatalf("unexpected error: %s", err)
}
close(serverStopCh)
}()

c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}

pendingRequests := c.PendingRequests()
if pendingRequests != 0 {
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
}

resultCh := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
resp := AcquireResponse()

if err := c.DoTimeout(req, resp, 10*time.Second); err != nil {
resultCh <- fmt.Errorf("unexpected error: %s", err)
return
}

if resp.StatusCode() != StatusOK {
resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
return
}
resultCh <- nil
}()
}

// wait while all the requests reach server
for i := 0; i < concurrency; i++ {
select {
case <-readyCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}

pendingRequests = c.PendingRequests()
if pendingRequests != concurrency {
t.Fatalf("unexpected pendingRequests: %d. Expecting %d", pendingRequests, concurrency)
}

// unblock request handlers on the server and wait until all the requests are finished.
close(doneCh)
for i := 0; i < concurrency; i++ {
select {
case err := <-resultCh:
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}

pendingRequests = c.PendingRequests()
if pendingRequests != 0 {
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
}

// stop the server
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}

func TestHostClientMaxConnsWithDeadline(t *testing.T) {
var (
emptyBodyCount uint8
Expand Down

0 comments on commit a52a42a

Please sign in to comment.