Skip to content

Commit

Permalink
Add support for HandlerList.PushBackNamed()/PushFrontNamed()
Browse files Browse the repository at this point in the history
Also adds HandlerList.Remove() which allows removal of NamedHandlers.
  • Loading branch information
lsegal authored and jasdel committed Aug 20, 2015
1 parent e3d8359 commit 1fd4c57
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 38 deletions.
24 changes: 12 additions & 12 deletions aws/corehandlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type lener interface {
// BuildContentLength builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *request.Request) {
var BuildContentLengthHandler = request.NamedHandler{"core.BuildContentLengthHandler", func(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
Expand All @@ -47,17 +47,17 @@ func BuildContentLength(r *request.Request) {

r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}
}}

// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *request.Request) {
var UserAgentHandler = request.NamedHandler{"core.UserAgentHandler", func(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", aws.SDKName+"/"+aws.SDKVersion)
}
}}

var reStatusCode = regexp.MustCompile(`^(\d{3})`)

// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *request.Request) {
var SendHandler = request.NamedHandler{"core.SendHandler", func(r *request.Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
Expand Down Expand Up @@ -89,19 +89,19 @@ func SendHandler(r *request.Request) {
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
}
}
}}

// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *request.Request) {
var ValidateResponseHandler = request.NamedHandler{"core.ValidateResponseHandler", func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
}}

// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *request.Request) {
var AfterRetryHandler = request.NamedHandler{"core.AfterRetryHandler", func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
Expand All @@ -123,15 +123,15 @@ func AfterRetryHandler(r *request.Request) {
r.RetryCount++
r.Error = nil
}
}
}}

// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *request.Request) {
var ValidateEndpointHandler = request.NamedHandler{"core.ValidateEndpointHandler", func(r *request.Request) {
if r.Service.SigningRegion == "" && aws.StringValue(r.Service.Config.Region) == "" {
r.Error = aws.ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = aws.ErrMissingEndpoint
}
}
}}
8 changes: 4 additions & 4 deletions aws/corehandlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := service.New(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(corehandlers.ValidateEndpointHandler)
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)

req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
Expand All @@ -32,7 +32,7 @@ func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := service.New(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(corehandlers.ValidateEndpointHandler)
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)

req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
Expand Down Expand Up @@ -68,7 +68,7 @@ func TestAfterRetryRefreshCreds(t *testing.T) {
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(corehandlers.AfterRetryHandler)
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)

assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestSendHandlerError(t *testing.T) {
},
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBack(corehandlers.SendHandler)
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)

r.Send()
Expand Down
4 changes: 2 additions & 2 deletions aws/corehandlers/param_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *request.Request) {
var ValidateParametersHandler = request.NamedHandler{"core.ValidateParametersHandler", func(r *request.Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
Expand All @@ -22,7 +22,7 @@ func ValidateParameters(r *request.Request) {
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}
}}

// A validator validates values. Collects validations errors which occurs.
type validator struct {
Expand Down
6 changes: 3 additions & 3 deletions aws/corehandlers/param_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ func TestNoErrors(t *testing.T) {
}

req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParameters(req)
corehandlers.ValidateParametersHandler.Fn(req)
assert.NoError(t, req.Error)
}

func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParameters(req)
corehandlers.ValidateParametersHandler.Fn(req)

assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
Expand All @@ -82,7 +82,7 @@ func TestNestedMissingRequiredParameters(t *testing.T) {
}

req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParameters(req)
corehandlers.ValidateParametersHandler.Fn(req)

assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
Expand Down
47 changes: 37 additions & 10 deletions aws/request/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,39 +47,66 @@ func (h *Handlers) Clear() {

// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
list []NamedHandler
}

// A NamedHandler is a struct that contains a name and function callback.
type NamedHandler struct {
Name string
Fn func(*Request)
}

// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
n.list = append([]NamedHandler{}, l.list...)
return n
}

// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
l.list = []NamedHandler{}
}

// Len returns the number of handlers in the list.
func (l *HandlerList) Len() int {
return len(l.list)
}

// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
// PushBack pushes handler f to the back of the handler list.
func (l *HandlerList) PushBack(f func(*Request)) {
l.list = append(l.list, NamedHandler{"__anonymous", f})
}

// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
}

// PushBackNamed pushes named handler f to the back of the handler list.
func (l *HandlerList) PushBackNamed(n NamedHandler) {
l.list = append(l.list, n)
}

// PushFrontNamed pushes named handler f to the front of the handler list.
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
l.list = append([]NamedHandler{n}, l.list...)
}

// Remove removes a NamedHandler n
func (l *HandlerList) Remove(n NamedHandler) {
newlist := []NamedHandler{}
for _, m := range l.list {
if m.Name != n.Name {
newlist = append(newlist, m)
}
}
l.list = newlist
}

// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
f.Fn(r)
}
}
13 changes: 13 additions & 0 deletions aws/request/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,16 @@ func TestMultipleHandlers(t *testing.T) {
t.Error("Expected handler to execute")
}
}

func TestNamedHandlers(t *testing.T) {
l := request.HandlerList{}
named := request.NamedHandler{"Name", func(r *request.Request) {}}
named2 := request.NamedHandler{"NotName", func(r *request.Request) {}}
l.PushBackNamed(named)
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBack(func(r *request.Request) {})
assert.Equal(t, 4, l.Len())
l.Remove(named)
assert.Equal(t, 2, l.Len())
}
14 changes: 7 additions & 7 deletions aws/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ func (s *Service) Initialize() {

s.Retryer = DefaultRetryer{s}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(corehandlers.ValidateEndpointHandler)
s.Handlers.Build.PushBack(corehandlers.UserAgentHandler)
s.Handlers.Sign.PushBack(corehandlers.BuildContentLength)
s.Handlers.Send.PushBack(corehandlers.SendHandler)
s.Handlers.AfterRetry.PushBack(corehandlers.AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(corehandlers.ValidateResponseHandler)
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
s.Handlers.Build.PushBackNamed(corehandlers.UserAgentHandler)
s.Handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
s.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
s.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
s.Handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)
if !aws.BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBack(corehandlers.ValidateParameters)
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateParametersHandler)
}
s.AddDebugHandlers()
s.buildEndpoint()
Expand Down

0 comments on commit 1fd4c57

Please sign in to comment.