From 2b43cb5c344478891b3d3d8d71deedbf147f0e71 Mon Sep 17 00:00:00 2001 From: TwiN Date: Wed, 14 Sep 2022 19:55:34 -0400 Subject: [PATCH] feat(gate): Add Gate.PermissionMiddleware --- README.md | 19 ++++++++++- gate.go | 89 ++++++++++++++++++++++++++++++++-------------------- gate_test.go | 34 ++++++++++++++++++++ 3 files changed, 107 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index f1461e3..056dc60 100644 --- a/README.md +++ b/README.md @@ -177,10 +177,27 @@ have the `backup` permission: router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"})) ``` +If you're using an HTTP library that supports middlewares like [mux](https://github.com/gorilla/mux), you can protect +an entire group of handlers instead using `gate.Protect` or `gate.PermissionMiddleware()`: +```go +router := mux.NewRouter() + +userRouter := router.PathPrefix("/").Subrouter() +userRouter.Use(gate.Protect) +userRouter.HandleFunc("/api/v1/users/me", getUserProfile).Methods("GET") +userRouter.HandleFunc("/api/v1/users/me/friends", getUserFriends).Methods("GET") +userRouter.HandleFunc("/api/v1/users/me/email", updateUserEmail).Methods("PATCH") + +adminRouter := router.PathPrefix("/").Subrouter() +adminRouter.Use(gate.PermissionMiddleware("admin")) +adminRouter.HandleFunc("/api/v1/users/{id}/ban", banUserByID).Methods("POST") +adminRouter.HandleFunc("/api/v1/users/{id}/delete", deleteUserByID).Methods("DELETE") +``` + ## Rate limiting To add a rate limit of 100 requests per second: -``` +```go gate := g8.New().WithRateLimit(100) ``` diff --git a/gate.go b/gate.go index 08082c8..2dad0b8 100644 --- a/gate.go +++ b/gate.go @@ -66,15 +66,16 @@ func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody [] // If a custom token extractor is not specified, the token will be extracted from the Authorization header. // // For instance, if you're using a session cookie, you can extract the token from the cookie like so: -// authorizationService := g8.NewAuthorizationService() -// customTokenExtractorFunc := func(request *http.Request) string { -// sessionCookie, err := request.Cookie("session") -// if err != nil { -// return "" -// } -// return sessionCookie.Value -// } -// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) +// +// authorizationService := g8.NewAuthorizationService() +// customTokenExtractorFunc := func(request *http.Request) string { +// sessionCookie, err := request.Cookie("session") +// if err != nil { +// return "" +// } +// return sessionCookie.Value +// } +// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) // // You would normally use this with a client provider that matches whatever need you have. // For example, if you're using a session cookie, your client provider would retrieve the user from the session ID @@ -90,8 +91,8 @@ func (gate *Gate) WithCustomTokenExtractor(customTokenExtractorFunc func(request // WithRateLimit adds rate limiting to the Gate // // If you just want to use a gate for rate limiting purposes: -// gate := g8.New().WithRateLimit(50) // +// gate := g8.New().WithRateLimit(50) func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate { gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond) return gate @@ -102,12 +103,13 @@ func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate { // or lack thereof. // // Example: -// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) -// router := http.NewServeMux() -// // Without protection -// router.Handle("/handle", yourHandler) -// // With protection -// router.Handle("/handle", gate.Protect(yourHandler)) +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) +// router := http.NewServeMux() +// // Without protection +// router.Handle("/handle", yourHandler) +// // With protection +// router.Handle("/handle", gate.Protect(yourHandler)) // // The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey func (gate *Gate) Protect(handler http.Handler) http.Handler { @@ -118,12 +120,13 @@ func (gate *Gate) Protect(handler http.Handler) http.Handler { // as well as a slice of permissions that must be met. // // Example: -// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin"))) -// router := http.NewServeMux() -// // Without protection -// router.Handle("/handle", yourHandler) -// // With protection -// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"})) +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("ADMIN"))) +// router := http.NewServeMux() +// // Without protection +// router.Handle("/handle", yourHandler) +// // With protection +// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"})) // // The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey func (gate *Gate) ProtectWithPermissions(handler http.Handler, permissions []string) http.Handler { @@ -147,12 +150,13 @@ func (gate *Gate) ProtectWithPermission(handler http.Handler, permission string) // permissions or lack thereof. // // Example: -// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) -// router := http.NewServeMux() -// // Without protection -// router.HandleFunc("/handle", yourHandlerFunc) -// // With protection -// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc)) +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) +// router := http.NewServeMux() +// // Without protection +// router.HandleFunc("/handle", yourHandlerFunc) +// // With protection +// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc)) // // The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc { @@ -163,12 +167,13 @@ func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc { // token as well as a slice of permissions that must be met. // // Example: -// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin"))) -// router := http.NewServeMux() -// // Without protection -// router.HandleFunc("/handle", yourHandlerFunc) -// // With protection -// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"})) +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin"))) +// router := http.NewServeMux() +// // Without protection +// router.HandleFunc("/handle", yourHandlerFunc) +// // With protection +// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"})) // // The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc { @@ -215,3 +220,19 @@ func (gate *Gate) ExtractTokenFromRequest(request *http.Request) string { } return strings.TrimPrefix(request.Header.Get(AuthorizationHeader), "Bearer ") } + +// PermissionMiddleware is a middleware that behaves like ProtectWithPermission, but it is meant to be used +// as a middleware for libraries that support such a feature. +// +// For instance, if you are using github.com/gorilla/mux, you can use PermissionMiddleware like so: +// +// router := mux.NewRouter() +// router.Use(gate.PermissionMiddleware("admin")) +// router.Handle("/admin/handle", adminHandler) +// +// If you do not want to protect a router with a specific permission, you can use Gate.Protect instead. +func (gate *Gate) PermissionMiddleware(permissions ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return gate.ProtectWithPermissions(next, permissions) + } +} diff --git a/gate_test.go b/gate_test.go index 0e3d7bd..81e6887 100644 --- a/gate_test.go +++ b/gate_test.go @@ -311,6 +311,40 @@ func TestGate_ProtectWithPermissionWhenClientHasInsufficientPermissions(t *testi } } +func TestGate_PermissionMiddlewareWhenClientHasSufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler + // is protected by the permission "admin", the request should be authorized + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_PermissionMiddlewareWhenClientHasInsufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "mod" and the + // testHandler is protected by the permission "admin", the request should be not be authorized + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + func TestGate_ProtectFuncWithInvalidToken(t *testing.T) { gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) request, _ := http.NewRequest("GET", "/handle", http.NoBody)