Skip to content

Commit

Permalink
Refactored to use middleware instead of one giant function
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin committed Jul 9, 2014
1 parent a839f1a commit 9dd0fac
Show file tree
Hide file tree
Showing 10 changed files with 439 additions and 107 deletions.
1 change: 1 addition & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (b AuthorisationManager) IsKeyExpired(newSession *SessionState) bool {
// UpdateSession updates the session state in the storage engine
func (b AuthorisationManager) UpdateSession(keyName string, session SessionState) {
v, _ := json.Marshal(session)
log.Info(session)
key_exp := (session.Expires - time.Now().Unix()) + 300 // Add 5 minutes to key expiry, just in case

b.Store.SetKey(keyName, string(v), key_exp)
Expand Down
53 changes: 53 additions & 0 deletions error_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package main

import (
"net/http"
"time"
"runtime/pprof"
"fmt"
"github.com/gorilla/context"
)

type ErrorHandler struct{
TykMiddleware
}

func (e ErrorHandler) HandleError(w http.ResponseWriter, r *http.Request, err string, err_code int) {
if config.EnableAnalytics {
t := time.Now()
keyName := r.Header.Get(e.Spec.ApiDefinition.Auth.AuthHeaderName)
version := e.Spec.getVersionFromRequest(r)
if version == "" {
version = "Non Versioned"
}
thisRecord := AnalyticsRecord{
r.Method,
r.URL.Path,
r.ContentLength,
r.Header.Get("User-Agent"),
t.Day(),
t.Month(),
t.Year(),
t.Hour(),
err_code,
keyName,
t,
version,
e.Spec.ApiDefinition.Name,
e.Spec.ApiDefinition.ApiId,
e.Spec.ApiDefinition.OrgId}
analytics.RecordHit(thisRecord)
}

w.WriteHeader(err_code)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("X-Generator", "tyk.io")
thisError := ApiError{fmt.Sprintf("%s", err)}
templates.ExecuteTemplate(w, "error.json", &thisError)
if doMemoryProfile {
pprof.WriteHeapProfile(prof_file)
}

// Clean up
context.Clear(r)
}
111 changes: 5 additions & 106 deletions gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"fmt"
"github.com/Sirupsen/logrus"
"net/http"
"net/http/httputil"
"time"
Expand All @@ -17,112 +16,12 @@ type ApiError struct {
func handler(p *httputil.ReverseProxy, apiSpec ApiSpec) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {

// Check versioning, blacklist, whitelist and ignored status
requestValid, stat := apiSpec.IsRequestValid(r)
if requestValid == false {
handle_error(w, r, string(stat), 409, apiSpec)
return
}

if stat == StatusOkAndIgnore {
success_handler(w, r, p, apiSpec)
return
}

// All is ok with the request itself, now auth and validate the rest
// Check for API key existence
authHeaderValue := r.Header.Get(apiSpec.ApiDefinition.Auth.AuthHeaderName)
if authHeaderValue != "" {
// Check if API key valid
key_authorised, thisSessionState := authManager.IsKeyAuthorised(authHeaderValue)

// Check if this version is allowable!
accessingVersion := apiSpec.getVersionFromRequest(r)
apiId := apiSpec.ApiId
tm := TykMiddleware{apiSpec, p}
handler := SuccessHandler{tm}
// Skip all other execution
handler.ServeHTTP(w, r)
return

// If there's nothing in our profile, we let them through to the next phase
if len(thisSessionState.AccessRights) > 0 {
// Run auth checks
versionList, apiExists := thisSessionState.AccessRights[apiId]
if !apiExists {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access to unauthorised API.")
handle_error(w, r, "Access to this API has been disallowed", 403, apiSpec)
return

} else {
found := false
for _, vInfo := range versionList.Versions {
if vInfo == accessingVersion {
found = true
break
}
}
if !found {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access to unauthorised API version.")
handle_error(w, r, "Access to this API version has been disallowed", 403, apiSpec)
return
}
}
}

keyExpired := authManager.IsKeyExpired(&thisSessionState)
if key_authorised {
if !keyExpired {
// If valid, check if within rate limit
forwardMessage, reason := sessionLimiter.ForwardMessage(&thisSessionState)
if forwardMessage {
success_handler(w, r, p, apiSpec)
} else {
// TODO Use an Enum!
if reason == 1 {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Key rate limit exceeded.")
handle_error(w, r, "Rate limit exceeded", 409, apiSpec)
} else if reason == 2 {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Key quota limit exceeded.")
handle_error(w, r, "Quota exceeded", 409, apiSpec)
}

}
authManager.UpdateSession(authHeaderValue, thisSessionState)
} else {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access from expired key.")
handle_error(w, r, "Key has expired, please renew", 403, apiSpec)
}
} else {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access with non-existent key.")
handle_error(w, r, "Key not authorised", 403, apiSpec)
}
} else {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
}).Info("Attempted access with malformed header, no auth header found.")
handle_error(w, r, "Authorisation field missing", 400, apiSpec)
}
}
}

Expand Down
48 changes: 47 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/rcrowley/goagain"
"net"
"time"
"github.com/justinas/alice"
)

var log = logrus.New()
Expand Down Expand Up @@ -153,6 +154,41 @@ func getApiSpecs() []ApiSpec {
return ApiSpecs
}

func customHandler1(h http.Handler) http.Handler {
thisHandler := func(w http.ResponseWriter, r *http.Request) {
log.Info("Middlwware 1 called!")
h.ServeHTTP(w, r)
}

return http.HandlerFunc(thisHandler)
}

func customHandler2(h http.Handler) http.Handler {
thisHandler := func(w http.ResponseWriter, r *http.Request) {
log.Info("Middlwware 2 called!")
h.ServeHTTP(w, r)
}

return http.HandlerFunc(thisHandler)
}

type StructMiddleware struct{
spec ApiSpec
}

func (s StructMiddleware) New(spec ApiSpec) func(http.Handler) http.Handler {
aliceHandler := func(h http.Handler) http.Handler {
thisHandler := func(w http.ResponseWriter, r *http.Request) {
log.Info("Middlwware 3 called!")
log.Info(spec.ApiId)
h.ServeHTTP(w, r)
}
return http.HandlerFunc(thisHandler)
}

return aliceHandler
}

func loadApps(ApiSpecs []ApiSpec, Muxer *http.ServeMux) {
// load the APi defs
log.Info("Loading API configurations.")
Expand All @@ -166,7 +202,17 @@ func loadApps(ApiSpecs []ApiSpec, Muxer *http.ServeMux) {
}
log.Info(remote)
proxy := httputil.NewSingleHostReverseProxy(remote)
Muxer.HandleFunc(spec.Proxy.ListenPath, handler(proxy, spec))

myHandler := http.HandlerFunc(handler(proxy, spec))
tykMiddleware := TykMiddleware{spec, proxy}

chain := alice.New(
VersionCheck{tykMiddleware}.New(),
KeyExists{tykMiddleware}.New(),
AccessRightsCheck{tykMiddleware}.New(),
KeyExpired{tykMiddleware}.New(),
RateLimitAndQuotaCheck{tykMiddleware}.New()).Then(myHandler)
Muxer.Handle(spec.Proxy.ListenPath, chain)
}
}

Expand Down
66 changes: 66 additions & 0 deletions middleware_access_rights.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package main

import "net/http"

import (
"github.com/gorilla/context"
"github.com/Sirupsen/logrus"
)

type AccessRightsCheck struct{
TykMiddleware
}

func (a AccessRightsCheck) New() func(http.Handler) http.Handler {
aliceHandler := func(h http.Handler) http.Handler {
thisHandler := func(w http.ResponseWriter, r *http.Request) {

accessingVersion := a.Spec.getVersionFromRequest(r)
thisSessionState := context.Get(r, SessionData).(SessionState)
authHeaderValue := context.Get(r, AuthHeaderValue)

// If there's nothing in our profile, we let them through to the next phase
if len(thisSessionState.AccessRights) > 0 {
// Otherwise, run auth checks
versionList, apiExists := thisSessionState.AccessRights[a.Spec.ApiId]
if !apiExists {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access to unauthorised API.")
handler := ErrorHandler{a.TykMiddleware}
handler.HandleError(w, r, "Access to this API has been disallowed", 403)
return
}

// Find the version in their key access details
found := false
for _, vInfo := range versionList.Versions {
if vInfo == accessingVersion {
found = true
break
}
}
if !found {
// Not found? Bounce
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access to unauthorised API version.")
handler := ErrorHandler{a.TykMiddleware}
handler.HandleError(w, r, "Access to this API has been disallowed", 403)
return
}
}

// No gates failed, request is valid, carry on
h.ServeHTTP(w, r)
}

return http.HandlerFunc(thisHandler)
}

return aliceHandler
}
57 changes: 57 additions & 0 deletions middleware_key_exists.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import "net/http"

import (
"github.com/gorilla/context"
"github.com/Sirupsen/logrus"
)

type KeyExists struct{
TykMiddleware
}

func (k KeyExists) New() func(http.Handler) http.Handler {
aliceHandler := func(h http.Handler) http.Handler {
thisHandler := func(w http.ResponseWriter, r *http.Request) {

authHeaderValue := r.Header.Get(k.Spec.ApiDefinition.Auth.AuthHeaderName)
if authHeaderValue == "" {
// No header value, fail
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
}).Info("Attempted access with malformed header, no auth header found.")

handler := ErrorHandler{k.TykMiddleware}
handler.HandleError(w, r, "Authorisation field missing", 400)
return
}

// Check if API key valid
key_exists, thisSessionState := authManager.IsKeyAuthorised(authHeaderValue)
if !key_exists {
log.WithFields(logrus.Fields{
"path": r.URL.Path,
"origin": r.RemoteAddr,
"key": authHeaderValue,
}).Info("Attempted access with non-existent key.")

handler := ErrorHandler{k.TykMiddleware}
handler.HandleError(w, r, "Key not authorised", 403)
return
}

// Set session state on context, we will need it later
context.Set(r, SessionData, thisSessionState)
context.Set(r, AuthHeaderValue, authHeaderValue)

// Request is valid, carry on
h.ServeHTTP(w, r)
}

return http.HandlerFunc(thisHandler)
}

return aliceHandler
}
Loading

0 comments on commit 9dd0fac

Please sign in to comment.