Skip to content

Commit

Permalink
feat: make compatible with v1 and v2 api of amazon pay
Browse files Browse the repository at this point in the history
  • Loading branch information
mukezhz committed Feb 1, 2024
1 parent 13d8c0e commit 92c3d93
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 30 deletions.
7 changes: 4 additions & 3 deletions amazonpay/checkout_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/mukezhz/aws-services-go/amazonpay/signing"
"net/http"

"github.com/mukezhz/aws-services-go/amazonpay/signing"
)

type CreateCheckoutSessionRequest struct {
Expand All @@ -31,11 +32,11 @@ func (c *CreateCheckoutSessionRequest) ToPayload() (string, error) {

// GenerateButtonSignature method
func (c *Client) GenerateButtonSignature(payload string) (string, error) {
stringToSign, err := signing.StringToSign(payload)
stringToSign, err := signing.StringToSign(payload, c.Algorithm)
if err != nil {
return "", err
}
signature, err := signing.Sign(c.PrivateKey, stringToSign)
signature, err := signing.Sign(c.PrivateKey, stringToSign, c.salt)
if err != nil {
return "", err
}
Expand Down
50 changes: 29 additions & 21 deletions amazonpay/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/mukezhz/aws-services-go/amazonpay/signing"
"io"
"net/http"
"net/url"
"runtime"
"time"

"github.com/mukezhz/aws-services-go/amazonpay/signing"

"github.com/rs/xid"
)

Expand Down Expand Up @@ -47,34 +47,42 @@ type Client struct {
Region string
Sandbox bool
HTTPClient *http.Client
Algorithm string
salt int
endpoint *url.URL
}

endpoint *url.URL
type ClientInput struct {
PublicKeyID string
PrivateKey []byte
Region string
Sandbox bool
HTTPClient *http.Client
Version string
}

// New returns a new pay client instance.
func New(publicKeyID string, privateKey []byte, region string, sandbox bool, httpClient *http.Client) (*Client, error) {
if publicKeyID == "" {
return nil, errors.New("missing publicKeyID")
}
if privateKey == nil {
return nil, errors.New("missing privateKey")
}
if region == "" {
return nil, errors.New("missing region")
}
func New(input ClientInput) (*Client, error) {
c := &Client{
PublicKeyID: publicKeyID,
PrivateKey: privateKey,
Region: region,
Sandbox: sandbox,
HTTPClient: httpClient,
PublicKeyID: input.PublicKeyID,
PrivateKey: input.PrivateKey,
Region: input.Region,
Sandbox: input.Sandbox,
HTTPClient: input.HTTPClient,
}
endpointURL := c.createEndpointURL()
u, err := url.Parse(endpointURL)
if err != nil {
return nil, err
}
c.endpoint = u
version := input.Version
if input.Version != "v1" {
version = "v2"
}
al := signing.GetAlgorithm(version)
c.Algorithm = al.Algorithm
c.salt = al.SaltLength
return c, nil
}

Expand Down Expand Up @@ -122,16 +130,16 @@ func (c *Client) NewRequest(method, path string, body interface{}) (*http.Reques
if err != nil {
return nil, err
}
stringToSign, err := signing.StringToSign(canonicalRequest)
stringToSign, err := signing.StringToSign(canonicalRequest, c.Algorithm)
if err != nil {
return nil, err
}
signature, err := signing.Sign(c.PrivateKey, stringToSign)
signature, err := signing.Sign(c.PrivateKey, stringToSign, c.salt)
if err != nil {
return nil, err
}
signedHeaders := signing.SignedHeaders(req)
authValue := signing.AuthHeaderValue(c.PublicKeyID, signedHeaders, signature)
authValue := signing.AuthHeaderValue(c.PublicKeyID, signedHeaders, signature, c.Algorithm)
req.Header.Set("Authorization", authValue)

return req, nil
Expand Down
26 changes: 20 additions & 6 deletions amazonpay/signing/signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ const (
SaltLengthForAlgorithmV2 = 32
)

type AlgorithmStuff struct {
Algorithm string
SaltLength int
}

func GetAlgorithm(version string) AlgorithmStuff {
switch version {
case "v1":
return AlgorithmStuff{Algorithm, SaltLengthForAlgorithm}
default:
return AlgorithmStuff{AlgorithmV2, SaltLengthForAlgorithmV2}
}
}

// CanonicalRequest =
//
// HTTPRequestMethod + '\n' +
Expand Down Expand Up @@ -115,15 +129,15 @@ func RequestPayload(r *http.Request) ([]byte, error) {
return b, err
}

func StringToSign(canonicalRequest string) (string, error) {
func StringToSign(canonicalRequest, algorithm string) (string, error) {
hexencode, err := HexEncodeSHA256Hash([]byte(canonicalRequest))
if err != nil {
return "", err
}
return fmt.Sprintf("%s\n%s", AlgorithmV2, hexencode), nil
return fmt.Sprintf("%s\n%s", algorithm, hexencode), nil
}

func Sign(privateKeyData []byte, stringToSign string) (string, error) {
func Sign(privateKeyData []byte, stringToSign string, salt int) (string, error) {
block, _ := pem.Decode(privateKeyData)
if block == nil {
return "", errors.New("invalid private key data")
Expand All @@ -138,7 +152,7 @@ func Sign(privateKeyData []byte, stringToSign string) (string, error) {
}
hashed := sha256.Sum256([]byte(stringToSign))
signature, err := rsa.SignPSS(rand.Reader, key, crypto.SHA256, hashed[:], &rsa.PSSOptions{
SaltLength: SaltLengthForAlgorithmV2,
SaltLength: salt,
})
if err != nil {
return "", err
Expand All @@ -155,8 +169,8 @@ func HexEncodeSHA256Hash(body []byte) (string, error) {
return fmt.Sprintf("%x", hash.Sum(nil)), err
}

func AuthHeaderValue(publicKeyID, signedHeaders, signature string) string {
return fmt.Sprintf("%s PublicKeyId=%s, SignedHeaders=%s, Signature=%s", AlgorithmV2, publicKeyID, signedHeaders, signature)
func AuthHeaderValue(publicKeyID, signedHeaders, signature, algorithm string) string {
return fmt.Sprintf("%s PublicKeyId=%s, SignedHeaders=%s, Signature=%s", algorithm, publicKeyID, signedHeaders, signature)
}

func trimString(s string) string {
Expand Down

0 comments on commit 92c3d93

Please sign in to comment.