Skip to content
This repository has been archived by the owner on May 24, 2023. It is now read-only.

Commit

Permalink
add custom KeyFunc to middleware configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
svs-valiton committed Oct 19, 2021
1 parent a628715 commit 3bc48d2
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 10 deletions.
51 changes: 51 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jwtware.New(config ...jwtware.Config) func(*fiber.Ctx) error
| KeyRefreshRateLimit | `*time.Duration` |KeyRefreshRateLimit limits the rate at which refresh requests are granted. | `nil` |
| KeyRefreshTimeout | `*time.Duration` |KeyRefreshTimeout is the duration for the context used to create the HTTP request for a refresh of the JWKs. | `1min` |
| KeyRefreshUnknownKID | `bool` |KeyRefreshUnknownKID indicates that the JWKs refresh request will occur every time a kid that isn't cached is seen. | `false` |
| KeyFunc | `func() jwt.Keyfunc` |KeyFunc defines a user-defined function that supplies the public key for a token validation. | `jwtKeyFunc` |


### HS256 Example
Expand Down Expand Up @@ -243,3 +244,53 @@ The RS256 is actually identical to the HS256 test above.

### JWKs Test
The tests are identical to basic `JWT` tests above, with exception that `KeySetURL` to valid public keys collection in JSON format should be supplied.

### Custom KeyFunc example

KeyFunc defines a user-defined function that supplies the public key for a token validation.
The function shall take care of verifying the signing algorithm and selecting the proper key.
A user-defined KeyFunc can be useful if tokens are issued by an external party.

When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
This is one of the three options to provide a token validation key.
The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
Required if neither SigningKeys nor SigningKey is provided.
Default to an internal implementation verifying the signing algorithm and selecting the proper key.

```go
package main

import (
"fmt"
"github.com/gofiber/fiber/v2"

jwtware "github.com/gofiber/jwt/v3"
"github.com/golang-jwt/jwt/v4"
)

func main() {
app := fiber.New()

app.Use(jwtware.New(jwtware.Config{
KeyFunc: customKeyFunc(),
}))

app.Get("/ok", func(c *fiber.Ctx) error {
return c.SendString("OK")
})
}

func customKeyFunc() jwt.Keyfunc {
return func(t *jwt.Token) (interface{}, error) {
// Always check the signing method
if t.Method.Alg() != jwtware.HS256 {
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
}

// TODO custom implementation of loading signing key like from a database
signingKey := "secret"

return []byte(signingKey), nil
}
}
```
27 changes: 19 additions & 8 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ type Config struct {
// Optional. Default: "Bearer".
AuthScheme string

keyFunc jwt.Keyfunc
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
}

// makeCfg function will check correctness of supplied configuration
Expand All @@ -123,7 +132,7 @@ func makeCfg(config []Config) (cfg Config) {
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired JWT")
}
}
if cfg.SigningKey == nil && len(cfg.SigningKeys) == 0 && cfg.KeySetURL == "" {
if cfg.SigningKey == nil && len(cfg.SigningKeys) == 0 && cfg.KeySetURL == "" && cfg.KeyFunc == nil {
panic("Fiber: JWT middleware requires signing key or url where to download one")
}
if cfg.SigningMethod == "" && cfg.KeySetURL == "" {
Expand All @@ -144,13 +153,15 @@ func makeCfg(config []Config) (cfg Config) {
if cfg.KeyRefreshTimeout == nil {
cfg.KeyRefreshTimeout = &defaultKeyRefreshTimeout
}
if cfg.KeySetURL != "" {
jwks := &KeySet{
Config: &cfg,
if cfg.KeyFunc == nil {
if cfg.KeySetURL != "" {
jwks := &KeySet{
Config: &cfg,
}
cfg.KeyFunc = jwks.keyFunc()
} else {
cfg.KeyFunc = jwtKeyFunc(cfg)
}
cfg.keyFunc = jwks.keyFunc()
} else {
cfg.keyFunc = jwtKeyFunc(cfg)
}
return cfg
}
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ func New(config ...Config) fiber.Handler {
var token *jwt.Token

if _, ok := cfg.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, cfg.keyFunc)
token, err = jwt.Parse(auth, cfg.KeyFunc)
} else {
t := reflect.ValueOf(cfg.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, cfg.keyFunc)
token, err = jwt.ParseWithClaims(auth, claims, cfg.KeyFunc)
}
if err == nil && token.Valid {
// Store user information from token into context.
Expand Down
46 changes: 46 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package jwtware_test

import (
"fmt"
"github.com/golang-jwt/jwt/v4"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -231,3 +233,47 @@ func TestJwkFromServer(t *testing.T) {
utils.AssertEqual(t, 200, resp.StatusCode)
}
}

func TestCustomKeyFunc(t *testing.T) {
t.Parallel()

defer func() {
// Assert
if err := recover(); err != nil {
t.Fatalf("Middleware should not panic")
}
}()

test := hamac[0]
// Arrange
app := fiber.New()

app.Use(jwtware.New(jwtware.Config{
KeyFunc: customKeyFunc(),
}))

app.Get("/ok", func(c *fiber.Ctx) error {
return c.SendString("OK")
})

req := httptest.NewRequest("GET", "/ok", nil)
req.Header.Add("Authorization", "Bearer "+test.Token)

// Act
resp, err := app.Test(req)

// Assert
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}

func customKeyFunc() jwt.Keyfunc {
return func(t *jwt.Token) (interface{}, error) {
// Always check the signing method
if t.Method.Alg() != jwtware.HS256 {
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
}

return []byte(defaultSigningKey), nil
}
}

0 comments on commit 3bc48d2

Please sign in to comment.