Skip to content

Commit

Permalink
Merge pull request fagongzi#205 from zuiqiangqishao/master
Browse files Browse the repository at this point in the history
给jwt插件添加验证csrf令牌机制
  • Loading branch information
zhangxu19830126 authored Mar 5, 2020
2 parents 3839931 + 4212bf3 commit f790779
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 6 deletions.
9 changes: 9 additions & 0 deletions examples/jwt.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"tokenLookup": "token lookup, [header|query|cookie:Authorization]",
"authSchema": "jwt schema, [Bearer]",
"renewTokenHeaderName": "the header name for new token in the response header",
"csrfHeaderName": "the header name for CSRFToken",
"redis": {
"addr": "127.0.0.1:6379",
"maxActive": "max connections, int",
Expand All @@ -17,6 +18,14 @@
"prefix": "the prefix of token in the redis"
}
},
{
"method": "token_and_csrf_in_redis",
"params": {
"prefix": "the prefix of token in the redis",
"csrf_white_method":"GET,OPTION",
"csrf_white_path":"/testinfo,/testdesc"
}
},
{
"method": "renew_by_redis",
"params": {
Expand Down
1 change: 1 addition & 0 deletions examples/jwt_a_simple_working_one.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"tokenLookup": "header:Authorization",
"authSchema": "Bearer",
"renewTokenHeaderName": "",
"csrfHeaderName": "_csrf",
"redis": {
"addr": "127.0.0.1:6379",
"maxActive": 1000,
Expand Down
107 changes: 101 additions & 6 deletions pkg/proxy/filter_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ import (

const (
// besides checking token is legitimate or not, it checks whether token exists in redis
actionTokenInRedis string = "token_in_redis"
actionTokenInRedis string = "token_in_redis"
// besides checking token is legitimate or not, it checks whether token exists in redis and the value is the csrfToken value
actionTokenAndCSRFInRedis string = "token_and_csrf_in_redis"
// update token's TTL
actionRenewByRaw string = "renew_by_raw"
actionRenewByRaw string = "renew_by_raw"
// update token's TTL and in the same time put new token in redis, previous token invalid
actionRenewByRedis string = "renew_by_redis"
actionRenewByRedis string = "renew_by_redis"
// fetch fields from token and put them in header which is redirected to a backend server who is unbeknownst to JWT
actionFetchToHeader string = "fetch_to_header"
actionFetchToCookie string = "fetch_to_cookie"
Expand All @@ -31,8 +33,21 @@ const (
)

var (
errJWTMissing = errors.New("missing jwt token")
errJWTInvalid = errors.New("invalid jwt token")
operationFail = errors.New("jwt filter operation fail")

errJWTMissing = errors.New("missing jwt token")
errJWTInvalid = errors.New("invalid jwt token")
errCSRFMissing = errors.New("missing csrf token")
errCSRFInvalid = errors.New("invalid csrf token")

//handler custom error to statusCode
customErrMaps = map[string]int{
errJWTMissing.Error(): fasthttp.StatusUnauthorized,
errJWTInvalid.Error(): fasthttp.StatusUnauthorized,
errCSRFMissing.Error(): fasthttp.StatusForbidden,
errCSRFInvalid.Error(): fasthttp.StatusForbidden,
redis.ErrNil.Error(): fasthttp.StatusForbidden, //redis nil err
}
)

type tokenGetter func(filter.Context) (string, error)
Expand All @@ -43,6 +58,7 @@ type JWTCfg struct {
Secret string `json:"secret"`
Method string `json:"method"`
TokenLookup string `json:"tokenLookup"`
CSRFHeaderName string `json:"csrfHeaderName"`
AuthSchema string `json:"authSchema"`
RenewTokenHeaderName string `json:"renewTokenHeaderName,omitempty"`
Redis *Redis `json:"redis,omitempty"`
Expand Down Expand Up @@ -70,6 +86,7 @@ type JWTFilter struct {
cfg *JWTCfg
secretBytes []byte
getter tokenGetter
csrfGetter tokenGetter
redisPool *redis.Pool
leaseTTLDuration time.Duration
signing *jwt.SigningMethodHMAC
Expand Down Expand Up @@ -123,12 +140,16 @@ func (f *JWTFilter) Pre(c filter.Context) (statusCode int, err error) {

for idx, act := range f.actions {
ok, err := act(f.actionArgs[idx], token, claims, c)
//Note: if return err is nil, the request will be continue,otherwise termination the request
if err != nil {
if code, custom := customErrMaps[err.Error()]; custom {
return code, err
}
return fasthttp.StatusInternalServerError, err
}

if !ok {
return fasthttp.StatusForbidden, nil
return fasthttp.StatusForbidden, operationFail
}
}

Expand Down Expand Up @@ -199,6 +220,7 @@ func (f *JWTFilter) initTokenLookup() {
case "cookie":
f.getter = jwtFromCookie(parts[1])
}
f.csrfGetter = csrfFromHeader(f.cfg.CSRFHeaderName)
}

func (f *JWTFilter) initActions() error {
Expand All @@ -208,6 +230,8 @@ func (f *JWTFilter) initActions() error {
switch c.Method {
case actionTokenInRedis:
f.actions = append(f.actions, f.tokenInRedisAction)
case actionTokenAndCSRFInRedis:
f.actions = append(f.actions, f.actionTokenAndCSRFInRedis)
case actionRenewByRaw:
f.actions = append(f.actions, f.renewByRawAction)
case actionRenewByRedis:
Expand Down Expand Up @@ -322,6 +346,67 @@ func (f *JWTFilter) tokenInRedisAction(args map[string]interface{}, token string
return value, err
}

func (f *JWTFilter) actionTokenAndCSRFInRedis(args map[string]interface{}, token string, claims jwt.MapClaims, c filter.Context) (bool, error) {
if f.cfg.Redis == nil {
return false, fmt.Errorf("redis not setting")
}

var buf bytes.Buffer
buf.WriteString(args["prefix"].(string))
buf.WriteString(token)
key := hack.SliceToString(buf.Bytes())

conn := f.getRedis()
value, err := redis.String(conn.Do("GET", key))
conn.Close()

if err != nil {
return false, err
}

//filter csrf white list
reqMethod := strings.ToUpper(string(c.OriginRequest().Method()))
reqPath := strings.ToUpper(string(c.OriginRequest().URI().Path()))
sep := ","

if args["csrf_white_method"] != nil {
noAuthMethods := strings.Split(args["csrf_white_method"].(string), sep)
for _, v := range noAuthMethods {
if v == "" {
continue
}
if strings.ToUpper(string(v)) == reqMethod {
return true, err
}
}
}

if args["csrf_white_path"] != nil {
noAuthUri := strings.Split(args["csrf_white_path"].(string), sep)
for _, v := range noAuthUri {
if v == "" {
continue
}
if strings.ToUpper(v) == reqPath {
return true, err
}
}
}

//check csrf
csrfToken, err := f.csrfGetter(c)
if err != nil {
if err == errJWTMissing {
err = errCSRFMissing
}
return false, err
}
if value != csrfToken {
return false, errCSRFInvalid
}
return true, err
}

func (f *JWTFilter) fetchToHeader(args map[string]interface{}, token string, claims jwt.MapClaims, c filter.Context) (bool, error) {
var buf bytes.Buffer
prefix := args["prefix"].(string)
Expand Down Expand Up @@ -380,3 +465,13 @@ func jwtFromHeader(header string, authScheme string) tokenGetter {
return "", errJWTMissing
}
}

func csrfFromHeader(header string) tokenGetter {
return func(c filter.Context) (string, error) {
value := string(c.OriginRequest().Request.Header.Peek(header))
if len(value) == 0 {
return "", errCSRFMissing
}
return value, nil
}
}

0 comments on commit f790779

Please sign in to comment.