Skip to content

Commit

Permalink
add new interface Extractor for making token extraction pluggable
Browse files Browse the repository at this point in the history
  • Loading branch information
dgrijalva committed Jun 6, 2016
1 parent b6d201f commit bb45bfc
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 24 deletions.
75 changes: 75 additions & 0 deletions request/extractor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package request

import (
"errors"
"net/http"
)

// Errors
var (
ErrNoTokenInRequest = errors.New("no token present in request")
)

// Interface for extracting a token from an HTTP request
type Extractor interface {
ExtractToken(*http.Request) (string, error)
}

// Extract token from headers
type HeaderExtractor []string

func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) {
// loop over header names and return the first one that contains data
for _, header := range e {
if ah := req.Header.Get(header); ah != "" {
return ah, nil
}
}
return "", ErrNoTokenInRequest
}

// Extract token from request args
type ArgumentExtractor []string

func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) {
// Make sure form is parsed
req.ParseMultipartForm(10e6)

// loop over arg names and return the first one that contains data
for _, arg := range e {
if ah := req.Form.Get(arg); ah != "" {
return ah, nil
}
}

return "", ErrNoTokenInRequest
}

// Tries extractors in order until one works or an error occurs
type MultiExtractor []Extractor

func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) {
// loop over header names and return the first one that contains data
for _, extractor := range e {
if tok, err := extractor.ExtractToken(req); tok != "" {
return tok, nil
} else if err != ErrNoTokenInRequest {
return "", err
}
}
return "", ErrNoTokenInRequest
}

// Wrap an Extractor in this to post-process the value before it's handed off
type PostExtractionFilter struct {
Extractor
Filter func(string) (string, error)
}

func (e *PostExtractionFilter) ExtractToken(req *http.Request) (string, error) {
if tok, err := e.Extractor.ExtractToken(req); tok != "" {
return e.Filter(tok)
} else {
return "", err
}
}
25 changes: 25 additions & 0 deletions request/oauth2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package request

import (
"strings"
)

// Extract Authorization header and strip 'Bearer ' from it
var AuthorizationHeaderExtractor = &PostExtractionFilter{
HeaderExtractor{"Authorization"},
func(tok string) (string, error) {
// Should be a bearer token
if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " {
return tok[7:], nil
}
return tok, nil
},
}

// Extractor for OAuth2 access tokens
var OAuth2Extractor = &MultiExtractor{
// Look for authorization token first
AuthorizationHeaderExtractor,
// Extract access_token from form or GET argument
&ArgumentExtractor{"access_token"},
}
31 changes: 8 additions & 23 deletions request/request.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,24 @@
package request

import (
"errors"
"github.com/dgrijalva/jwt-go"
"net/http"
"strings"
)

// Errors
var (
ErrNoTokenInRequest = errors.New("no token present in request")
)

// Try to find the token in an http.Request.
// This method will call ParseMultipartForm if there's no token in the header.
// Currently, it looks in the Authorization header as well as
// looking for an 'access_token' request parameter in req.Form.
func ParseFromRequest(req *http.Request, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
return ParseFromRequestWithClaims(req, jwt.MapClaims{}, keyFunc)
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
return ParseFromRequestWithClaims(req, extractor, jwt.MapClaims{}, keyFunc)
}

func ParseFromRequestWithClaims(req *http.Request, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
// Look for an Authorization header
if ah := req.Header.Get("Authorization"); ah != "" {
// Should be a bearer token
if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " {
return jwt.ParseWithClaims(ah[7:], claims, keyFunc)
}
}

// Look for "access_token" parameter
req.ParseMultipartForm(10e6)
if tokStr := req.Form.Get("access_token"); tokStr != "" {
return jwt.ParseWithClaims(tokStr, claims, keyFunc)
func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
// Extract token from request
tokStr, err := extractor.ExtractToken(req)
if err != nil {
return nil, err
}

return nil, ErrNoTokenInRequest
return jwt.ParseWithClaims(tokStr, claims, keyFunc)
}
2 changes: 1 addition & 1 deletion request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestParseRequest(t *testing.T) {
r.Header.Set(k, tokenString)
}
}
token, err := ParseFromRequestWithClaims(r, jwt.MapClaims{}, keyfunc)
token, err := ParseFromRequestWithClaims(r, OAuth2Extractor, jwt.MapClaims{}, keyfunc)

if token == nil {
t.Errorf("[%v] Token was not found: %v", data.name, err)
Expand Down

0 comments on commit bb45bfc

Please sign in to comment.