Skip to content

Commit

Permalink
aws/service/rds: Don't auto-sign (preSignUrl) for same region operati…
Browse files Browse the repository at this point in the history
…ons (aws#2631)

Fixes the RDS API client's customization to not generate a presigned URL for same region calls for CreateDBInstanceReadReplica and CopyDBSnapshot.

Update of aws#1847
  • Loading branch information
jasdel authored May 30, 2019
1 parent 93181a0 commit 0172821
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 66 deletions.
21 changes: 21 additions & 0 deletions service/rds/customizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ func copyDBSnapshotPresign(r *request.Request) {
}

originParams.DestinationRegion = r.Config.Region

// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}

newParams := awsutil.CopyOf(r.Params).(*CopyDBSnapshotInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand All @@ -60,6 +66,11 @@ func createDBInstanceReadReplicaPresign(r *request.Request) {
}

originParams.DestinationRegion = r.Config.Region
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}

newParams := awsutil.CopyOf(r.Params).(*CreateDBInstanceReadReplicaInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand All @@ -72,6 +83,11 @@ func copyDBClusterSnapshotPresign(r *request.Request) {
}

originParams.DestinationRegion = r.Config.Region
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}

newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand All @@ -84,6 +100,11 @@ func createDBClusterPresign(r *request.Request) {
}

originParams.DestinationRegion = r.Config.Region
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}

newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand Down
263 changes: 197 additions & 66 deletions service/rds/customizations_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// +build go1.9

package rds

import (
"fmt"
"io/ioutil"
"net/url"
"regexp"
"strings"
"testing"
"time"

Expand All @@ -15,8 +16,7 @@ import (
"github.com/aws/aws-sdk-go/awstesting/unit"
)

func TestPresignWithPresignNotSet(t *testing.T) {
reqs := map[string]*request.Request{}
func TestCopyDBSnapshotNoPanic(t *testing.T) {
svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})

f := func() {
Expand All @@ -27,95 +27,226 @@ func TestPresignWithPresignNotSet(t *testing.T) {
if paniced, p := awstesting.DidPanic(f); paniced {
t.Errorf("expect no panic, got %v", p)
}
}

reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
})

reqs[opCreateDBInstanceReadReplica], _ = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
})
func TestPresignCrossRegionRequest(t *testing.T) {
const targetRegion = "us-west-2"

svc := New(unit.Session, &aws.Config{Region: aws.String(targetRegion)})

const regexPattern = `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+`

cases := map[string]struct {
Req *request.Request
Assert func(*testing.T, string)
}{
opCopyDBSnapshot: {
Req: func() *request.Request {
req, _ := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
})
return req
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCopyDBSnapshot, targetRegion)),
},
opCreateDBInstanceReadReplica: {
Req: func() *request.Request {
req, _ := svc.CreateDBInstanceReadReplicaRequest(
&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
})
return req
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCreateDBInstanceReadReplica, targetRegion)),
},
opCopyDBClusterSnapshot: {
Req: func() *request.Request {
req, _ := svc.CopyDBClusterSnapshotRequest(
&CopyDBClusterSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
})
return req
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCopyDBClusterSnapshot, targetRegion)),
},
opCreateDBCluster: {
Req: func() *request.Request {
req, _ := svc.CreateDBClusterRequest(
&CreateDBClusterInput{
SourceRegion: aws.String("us-west-1"),
DBClusterIdentifier: aws.String("foo"),
Engine: aws.String("bar"),
})
return req
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCreateDBCluster, targetRegion)),
},
opCopyDBSnapshot + " same region": {
Req: func() *request.Request {
req, _ := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-2"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
})
return req
}(),
Assert: assertAsEmpty(),
},
opCreateDBInstanceReadReplica + " same region": {
Req: func() *request.Request {
req, _ := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-2"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
})
return req
}(),
Assert: assertAsEmpty(),
},
opCopyDBClusterSnapshot + " same region": {
Req: func() *request.Request {
req, _ := svc.CopyDBClusterSnapshotRequest(
&CopyDBClusterSnapshotInput{
SourceRegion: aws.String("us-west-2"),
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
})
return req
}(),
Assert: assertAsEmpty(),
},
opCreateDBCluster + " same region": {
Req: func() *request.Request {
req, _ := svc.CreateDBClusterRequest(
&CreateDBClusterInput{
SourceRegion: aws.String("us-west-2"),
DBClusterIdentifier: aws.String("foo"),
Engine: aws.String("bar"),
})
return req
}(),
Assert: assertAsEmpty(),
},
opCopyDBSnapshot + " presignURL set": {
Req: func() *request.Request {
req, _ := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
opCreateDBInstanceReadReplica + " presignURL set": {
Req: func() *request.Request {
req, _ := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
opCopyDBClusterSnapshot + " presignURL set": {
Req: func() *request.Request {
req, _ := svc.CopyDBClusterSnapshotRequest(
&CopyDBClusterSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
opCreateDBCluster + " presignURL set": {
Req: func() *request.Request {
req, _ := svc.CreateDBClusterRequest(
&CreateDBClusterInput{
SourceRegion: aws.String("us-west-1"),
DBClusterIdentifier: aws.String("foo"),
Engine: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
}

for op, req := range reqs {
req.Sign()
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))
for name, c := range cases {
t.Run(name, func(t *testing.T) {
if err := c.Req.Sign(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))

u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))

exp := fmt.Sprintf(`^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=us-west-2.+`, op)
if re, a := regexp.MustCompile(exp), u; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
c.Assert(t, u)
})
}
}

func TestPresignWithPresignSet(t *testing.T) {
func TestPresignWithSourceNotSet(t *testing.T) {
reqs := map[string]*request.Request{}
svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})

f := func() {
// Doesn't panic on nil input
req, _ := svc.CopyDBSnapshotRequest(nil)
req.Sign()
}
if paniced, p := awstesting.DidPanic(f); paniced {
t.Errorf("expect no panic, got %v", p)
}

reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("presignedURL"),
})

reqs[opCreateDBInstanceReadReplica], _ = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("presignedURL"),
})

for _, req := range reqs {
req.Sign()
_, err := req.Presign(5 * time.Minute)
if err != nil {
t.Fatal(err)
}
}
}

b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))
func assertAsRegexMatch(exp string) func(*testing.T, string) {
return func(t *testing.T, v string) {
t.Helper()

u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
if e, a := "presignedURL", u; !strings.Contains(a, e) {
t.Errorf("expect %s to be in %s", e, a)
if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
}
}

func TestPresignWithSourceNotSet(t *testing.T) {
reqs := map[string]*request.Request{}
svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
func assertAsEmpty() func(*testing.T, string) {
return func(t *testing.T, v string) {
t.Helper()

f := func() {
// Doesn't panic on nil input
req, _ := svc.CopyDBSnapshotRequest(nil)
req.Sign()
}
if paniced, p := awstesting.DidPanic(f); paniced {
t.Errorf("expect no panic, got %v", p)
if len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
}
}

reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
})
func assertAsEqual(expect string) func(*testing.T, string) {
return func(t *testing.T, v string) {
t.Helper()

for _, req := range reqs {
_, err := req.Presign(5 * time.Minute)
if err != nil {
t.Fatal(err)
if e, a := expect, v; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
}

0 comments on commit 0172821

Please sign in to comment.