Skip to content

Commit

Permalink
Split Request and Handlers from service package
Browse files Browse the repository at this point in the history
Also create corehandlers package to house core handler functions.
  • Loading branch information
lsegal authored and jasdel committed Aug 20, 2015
1 parent 77b2bb7 commit 0d59abf
Show file tree
Hide file tree
Showing 222 changed files with 6,725 additions and 6,207 deletions.
12 changes: 12 additions & 0 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"net/http"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
)
Expand Down Expand Up @@ -73,6 +74,8 @@ type Config struct {
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool

SleepDelay func(time.Duration)
}

// NewConfig returns a new Config pointer that can be chained with builder methods to
Expand Down Expand Up @@ -161,6 +164,11 @@ func (c *Config) WithS3ForcePathStyle(force bool) *Config {
return c
}

func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
c.SleepDelay = fn
return c
}

// Merge returns a new Config with the other Config's attribute values merged into
// this Config. If the other Config's attribute is nil it will not be merged into
// the new Config to be returned.
Expand Down Expand Up @@ -215,6 +223,10 @@ func (c Config) Merge(other *Config) *Config {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}

if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}

return &dst
}

Expand Down
59 changes: 19 additions & 40 deletions aws/service/handler_functions.go → aws/corehandlers/handlers.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package service
package corehandlers

import (
"bytes"
Expand All @@ -9,16 +9,12 @@ import (
"net/url"
"regexp"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)

var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}

// Interface for matching types which also have a Len method.
type lener interface {
Len() int
Expand All @@ -27,7 +23,7 @@ type lener interface {
// BuildContentLength builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
func BuildContentLength(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
Expand All @@ -41,10 +37,10 @@ func BuildContentLength(r *Request) {
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
Expand All @@ -54,14 +50,14 @@ func BuildContentLength(r *Request) {
}

// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
func UserAgentHandler(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", aws.SDKName+"/"+aws.SDKVersion)
}

var reStatusCode = regexp.MustCompile(`^(\d{3})`)

// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
func SendHandler(r *request.Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
Expand Down Expand Up @@ -96,7 +92,7 @@ func SendHandler(r *Request) {
}

// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
func ValidateResponseHandler(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
Expand All @@ -105,54 +101,37 @@ func ValidateResponseHandler(r *Request) {

// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
func AfterRetryHandler(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
r.Retryable = aws.Bool(r.Service.ShouldRetry(r))
r.Retryable = aws.Bool(r.ShouldRetry(r))
}

if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
r.RetryDelay = r.RetryRules(r)
fmt.Println(r.Service.Config.SleepDelay)
r.Service.Config.SleepDelay(r.RetryDelay)

// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
}
}
if r.IsErrorExpired() {
r.Service.Config.Credentials.Expire()
}

r.RetryCount++
r.Error = nil
}
}

var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)

// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)

// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
func ValidateEndpointHandler(r *request.Request) {
if r.Service.SigningRegion == "" && aws.StringValue(r.Service.Config.Region) == "" {
r.Error = ErrMissingRegion
r.Error = aws.ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
r.Error = aws.ErrMissingEndpoint
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package service
package corehandlers_test

import (
"fmt"
Expand All @@ -10,32 +10,35 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
)

func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(aws.NewConfig().WithRegion("us-west-2"))
svc := service.New(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
svc.Handlers.Validate.PushBack(corehandlers.ValidateEndpointHandler)

req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()

assert.NoError(t, err)
}

func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc := service.New(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
svc.Handlers.Validate.PushBack(corehandlers.ValidateEndpointHandler)

req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()

assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
assert.Equal(t, aws.ErrMissingRegion, err)
}

type mockCredsProvider struct {
Expand All @@ -55,24 +58,22 @@ func (m *mockCredsProvider) IsExpired() bool {
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&aws.Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: aws.Int(1)})
svc := service.New(&aws.Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: aws.Int(1)})

svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
svc.Handlers.AfterRetry.PushBack(corehandlers.AfterRetryHandler)

assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)

req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()

assert.True(t, svc.Config.Credentials.IsExpired())
Expand All @@ -90,14 +91,14 @@ func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, e
}

func TestSendHandlerError(t *testing.T) {
svc := NewService(&aws.Config{
svc := service.New(&aws.Config{
HTTPClient: &http.Client{
Transport: &testSendHandlerTransport{},
},
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBack(SendHandler)
r := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
svc.Handlers.Send.PushBack(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)

r.Send()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package service
package corehandlers

import (
"fmt"
"reflect"
"strings"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)

// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
func ValidateParameters(r *request.Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package service
package corehandlers_test

import (
"testing"
Expand All @@ -7,13 +7,19 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
)

var testSvc = func() *Service {
s := &Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
var testSvc = func() *service.Service {
s := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
},
}
return s
}()
Expand Down Expand Up @@ -49,15 +55,15 @@ func TestNoErrors(t *testing.T) {
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}

req := NewRequest(testSvc, &Operation{}, input, nil)
ValidateParameters(req)
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParameters(req)
assert.NoError(t, req.Error)
}

func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := NewRequest(testSvc, &Operation{}, input, nil)
ValidateParameters(req)
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParameters(req)

assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
Expand All @@ -75,8 +81,8 @@ func TestNestedMissingRequiredParameters(t *testing.T) {
OptionalStruct: &ConditionalStructShape{},
}

req := NewRequest(testSvc, &Operation{}, input, nil)
ValidateParameters(req)
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParameters(req)

assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
Expand Down
3 changes: 2 additions & 1 deletion aws/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ var DefaultConfig = aws.NewConfig().
WithHTTPClient(http.DefaultClient).
WithMaxRetries(aws.DefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff)
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep)
6 changes: 3 additions & 3 deletions aws/ec2metadata/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ package ec2metadata
import (
"path"

"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/request"
)

// GetMetadata uses the path provided to request
func (c *Client) GetMetadata(p string) (string, error) {
op := &service.Operation{
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
}

output := &metadataOutput{}
req := service.NewRequest(c.Service, op, nil, output)
req := request.New(c.Service.ServiceInfo, c.Service.Handlers, c.Service.Retryer, op, nil, output)

return output.Content, req.Send()
}
Expand Down
Loading

0 comments on commit 0d59abf

Please sign in to comment.