Skip to content

Commit

Permalink
feat(gate): Add Gate.PermissionMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
TwiN committed Sep 14, 2022
1 parent 71e2f63 commit 2b43cb5
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 35 deletions.
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,27 @@ have the `backup` permission:
router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"}))
```

If you're using an HTTP library that supports middlewares like [mux](https://github.com/gorilla/mux), you can protect
an entire group of handlers instead using `gate.Protect` or `gate.PermissionMiddleware()`:
```go
router := mux.NewRouter()

userRouter := router.PathPrefix("/").Subrouter()
userRouter.Use(gate.Protect)
userRouter.HandleFunc("/api/v1/users/me", getUserProfile).Methods("GET")
userRouter.HandleFunc("/api/v1/users/me/friends", getUserFriends).Methods("GET")
userRouter.HandleFunc("/api/v1/users/me/email", updateUserEmail).Methods("PATCH")

adminRouter := router.PathPrefix("/").Subrouter()
adminRouter.Use(gate.PermissionMiddleware("admin"))
adminRouter.HandleFunc("/api/v1/users/{id}/ban", banUserByID).Methods("POST")
adminRouter.HandleFunc("/api/v1/users/{id}/delete", deleteUserByID).Methods("DELETE")
```


## Rate limiting
To add a rate limit of 100 requests per second:
```
```go
gate := g8.New().WithRateLimit(100)
```

Expand Down
89 changes: 55 additions & 34 deletions gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody []
// If a custom token extractor is not specified, the token will be extracted from the Authorization header.
//
// For instance, if you're using a session cookie, you can extract the token from the cookie like so:
// authorizationService := g8.NewAuthorizationService()
// customTokenExtractorFunc := func(request *http.Request) string {
// sessionCookie, err := request.Cookie("session")
// if err != nil {
// return ""
// }
// return sessionCookie.Value
// }
// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
//
// authorizationService := g8.NewAuthorizationService()
// customTokenExtractorFunc := func(request *http.Request) string {
// sessionCookie, err := request.Cookie("session")
// if err != nil {
// return ""
// }
// return sessionCookie.Value
// }
// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
//
// You would normally use this with a client provider that matches whatever need you have.
// For example, if you're using a session cookie, your client provider would retrieve the user from the session ID
Expand All @@ -90,8 +91,8 @@ func (gate *Gate) WithCustomTokenExtractor(customTokenExtractorFunc func(request
// WithRateLimit adds rate limiting to the Gate
//
// If you just want to use a gate for rate limiting purposes:
// gate := g8.New().WithRateLimit(50)
//
// gate := g8.New().WithRateLimit(50)
func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate {
gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond)
return gate
Expand All @@ -102,12 +103,13 @@ func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate {
// or lack thereof.
//
// Example:
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
// router := http.NewServeMux()
// // Without protection
// router.Handle("/handle", yourHandler)
// // With protection
// router.Handle("/handle", gate.Protect(yourHandler))
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
// router := http.NewServeMux()
// // Without protection
// router.Handle("/handle", yourHandler)
// // With protection
// router.Handle("/handle", gate.Protect(yourHandler))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) Protect(handler http.Handler) http.Handler {
Expand All @@ -118,12 +120,13 @@ func (gate *Gate) Protect(handler http.Handler) http.Handler {
// as well as a slice of permissions that must be met.
//
// Example:
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin")))
// router := http.NewServeMux()
// // Without protection
// router.Handle("/handle", yourHandler)
// // With protection
// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"}))
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("ADMIN")))
// router := http.NewServeMux()
// // Without protection
// router.Handle("/handle", yourHandler)
// // With protection
// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"}))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) ProtectWithPermissions(handler http.Handler, permissions []string) http.Handler {
Expand All @@ -147,12 +150,13 @@ func (gate *Gate) ProtectWithPermission(handler http.Handler, permission string)
// permissions or lack thereof.
//
// Example:
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
// router := http.NewServeMux()
// // Without protection
// router.HandleFunc("/handle", yourHandlerFunc)
// // With protection
// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc))
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
// router := http.NewServeMux()
// // Without protection
// router.HandleFunc("/handle", yourHandlerFunc)
// // With protection
// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc {
Expand All @@ -163,12 +167,13 @@ func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc {
// token as well as a slice of permissions that must be met.
//
// Example:
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin")))
// router := http.NewServeMux()
// // Without protection
// router.HandleFunc("/handle", yourHandlerFunc)
// // With protection
// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"}))
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin")))
// router := http.NewServeMux()
// // Without protection
// router.HandleFunc("/handle", yourHandlerFunc)
// // With protection
// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"}))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc {
Expand Down Expand Up @@ -215,3 +220,19 @@ func (gate *Gate) ExtractTokenFromRequest(request *http.Request) string {
}
return strings.TrimPrefix(request.Header.Get(AuthorizationHeader), "Bearer ")
}

// PermissionMiddleware is a middleware that behaves like ProtectWithPermission, but it is meant to be used
// as a middleware for libraries that support such a feature.
//
// For instance, if you are using github.com/gorilla/mux, you can use PermissionMiddleware like so:
//
// router := mux.NewRouter()
// router.Use(gate.PermissionMiddleware("admin"))
// router.Handle("/admin/handle", adminHandler)
//
// If you do not want to protect a router with a specific permission, you can use Gate.Protect instead.
func (gate *Gate) PermissionMiddleware(permissions ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return gate.ProtectWithPermissions(next, permissions)
}
}
34 changes: 34 additions & 0 deletions gate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,40 @@ func TestGate_ProtectWithPermissionWhenClientHasInsufficientPermissions(t *testi
}
}

func TestGate_PermissionMiddlewareWhenClientHasSufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()

router := http.NewServeMux()
router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{}))
router.ServeHTTP(responseRecorder, request)

// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
// is protected by the permission "admin", the request should be authorized
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}

func TestGate_PermissionMiddlewareWhenClientHasInsufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()

router := http.NewServeMux()
router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{}))
router.ServeHTTP(responseRecorder, request)

// Since the client registered directly in the AuthorizationService has the permission "mod" and the
// testHandler is protected by the permission "admin", the request should be not be authorized
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}

func TestGate_ProtectFuncWithInvalidToken(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
Expand Down

0 comments on commit 2b43cb5

Please sign in to comment.