Skip to content

Commit

Permalink
added special case behavior for MapClaims so they aren't all weird
Browse files Browse the repository at this point in the history
  • Loading branch information
dgrijalva committed Apr 12, 2016
1 parent 4ec621a commit fb4ca74
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
14 changes: 10 additions & 4 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Parser struct {
// keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, keyFunc, &MapClaims{})
return p.ParseWithClaims(tokenString, keyFunc, MapClaims{})
}

func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token, error) {
Expand All @@ -42,6 +42,7 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla

// parse Claims
var claimBytes []byte
token.Claims = claims

if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
Expand All @@ -50,12 +51,17 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla
if p.UseJSONNumber {
dec.UseNumber()
}
if err = dec.Decode(&claims); err != nil {
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
}

token.Claims = claims

// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
Expand Down
45 changes: 33 additions & 12 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var jwtTestData = []struct {
name string
tokenString string
keyfunc jwt.Keyfunc
claims jwt.MapClaims
claims jwt.Claims
valid bool
errors uint32
parser *jwt.Parser
Expand Down Expand Up @@ -106,7 +106,7 @@ var jwtTestData = []struct {
"invalid signing method",
"",
defaultKeyFunc,
map[string]interface{}{"foo": "bar"},
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
&jwt.Parser{ValidMethods: []string{"HS256"}},
Expand All @@ -115,7 +115,7 @@ var jwtTestData = []struct {
"valid signing method",
"",
defaultKeyFunc,
map[string]interface{}{"foo": "bar"},
jwt.MapClaims{"foo": "bar"},
true,
0,
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
Expand All @@ -124,7 +124,18 @@ var jwtTestData = []struct {
"JSON Number",
"",
defaultKeyFunc,
map[string]interface{}{"foo": json.Number("123.4")},
jwt.MapClaims{"foo": json.Number("123.4")},
true,
0,
&jwt.Parser{UseJSONNumber: true},
},
{
"Standard Claims",
"",
defaultKeyFunc,
&jwt.StandardClaims{
ExpiresAt: time.Now().Add(time.Second * 10).Unix(),
},
true,
0,
&jwt.Parser{UseJSONNumber: true},
Expand All @@ -141,7 +152,7 @@ func init() {
}
}

func makeSample(c jwt.MapClaims) string {
func makeSample(c jwt.Claims) string {
keyData, e := ioutil.ReadFile("test/sample_key")
if e != nil {
panic(e.Error())
Expand All @@ -162,20 +173,30 @@ func makeSample(c jwt.MapClaims) string {
}

func TestParser_Parse(t *testing.T) {
// Iterate over test data set and run tests
for _, data := range jwtTestData {
// If the token string is blank, use helper function to generate string
if data.tokenString == "" {
data.tokenString = makeSample(data.claims)
}

// Parse the token
var token *jwt.Token
var err error
if data.parser != nil {
token, err = data.parser.Parse(data.tokenString, data.keyfunc)
} else {
token, err = jwt.Parse(data.tokenString, data.keyfunc)
var parser = data.parser
if parser == nil {
parser = new(jwt.Parser)
}
// Figure out correct claims type
switch data.claims.(type) {
case jwt.MapClaims:
token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, jwt.MapClaims{})
case *jwt.StandardClaims:
token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, &jwt.StandardClaims{})
}

if !reflect.DeepEqual(&data.claims, token.Claims) {
// Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}

Expand Down Expand Up @@ -218,13 +239,13 @@ func TestParseRequest(t *testing.T) {

r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
token, err := jwt.ParseFromRequestWithClaims(r, data.keyfunc, &jwt.MapClaims{})
token, err := jwt.ParseFromRequestWithClaims(r, data.keyfunc, jwt.MapClaims{})

if token == nil {
t.Errorf("[%v] Token was not found: %v", data.name, err)
continue
}
if !reflect.DeepEqual(&data.claims, token.Claims) {
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
if data.valid && err != nil {
Expand Down

0 comments on commit fb4ca74

Please sign in to comment.