Skip to content

Commit

Permalink
mw: remove IsEnabledForSpec dups in ProcessRequest
Browse files Browse the repository at this point in the history
In a middleware, ProcessRequest can only be called if IsEnabledForSpec
returned true. Thus, it makes little sense for middlewares to repeat or
do their "enabled" check in ProcessRequest.

This was the case for a few of them. Turns out that this was done
because some parts of the code, especially the tests, did not always
keep that promise.

Refactor the codebase to keep that promise and apply the
simplifications. Add a new helper func, mwList, to help us with that.
Also renamed appendMiddleware to make it more obvious that the actual
append doesn't always happen.
  • Loading branch information
mvdan authored and buger committed Sep 17, 2017
1 parent af3600f commit 1599fbb
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 130 deletions.
93 changes: 46 additions & 47 deletions api_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,33 +282,33 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "coprocess",
"api_name": spec.Name,
}).Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver)
appendMiddleware(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Pre, obj.Name, mwDriver})
mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Pre, obj.Name, mwDriver})
} else {
chainArray = append(chainArray, createDynamicMiddleware(obj.Name, true, obj.RequireSession, baseMid))
}
}

appendMiddleware(&chainArray, &RateCheckMW{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &MiddlewareContextVars{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &VersionCheck{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &RequestSizeLimitMiddleware{baseMid})
appendMiddleware(&chainArray, &TrackEndpointMiddleware{baseMid})
appendMiddleware(&chainArray, &TransformMiddleware{baseMid})
appendMiddleware(&chainArray, &TransformHeaders{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &RedisCacheMiddleware{BaseMiddleware: baseMid, CacheStore: cacheStore})
appendMiddleware(&chainArray, &VirtualEndpoint{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &URLRewriteMiddleware{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &TransformMethod{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &RateCheckMW{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &MiddlewareContextVars{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &VersionCheck{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &RequestSizeLimitMiddleware{baseMid})
mwAppendEnabled(&chainArray, &TrackEndpointMiddleware{baseMid})
mwAppendEnabled(&chainArray, &TransformMiddleware{baseMid})
mwAppendEnabled(&chainArray, &TransformHeaders{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &RedisCacheMiddleware{BaseMiddleware: baseMid, CacheStore: cacheStore})
mwAppendEnabled(&chainArray, &VirtualEndpoint{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &URLRewriteMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &TransformMethod{BaseMiddleware: baseMid})

for _, obj := range mwPostFuncs {
if mwDriver != apidef.OttoDriver {
log.WithFields(logrus.Fields{
"prefix": "coprocess",
"api_name": spec.Name,
}).Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Post", ", driver: ", mwDriver)
appendMiddleware(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Post, obj.Name, mwDriver})
mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Post, obj.Name, mwDriver})
} else {
chainArray = append(chainArray, createDynamicMiddleware(obj.Name, false, obj.RequireSession, baseMid))
}
Expand All @@ -330,19 +330,19 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "coprocess",
"api_name": spec.Name,
}).Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver)
appendMiddleware(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Pre, obj.Name, mwDriver})
mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Pre, obj.Name, mwDriver})
} else {
chainArray = append(chainArray, createDynamicMiddleware(obj.Name, true, obj.RequireSession, baseMid))
}
}

appendMiddleware(&chainArray, &RateCheckMW{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &VersionCheck{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &RequestSizeLimitMiddleware{baseMid})
appendMiddleware(&chainArray, &MiddlewareContextVars{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &TrackEndpointMiddleware{baseMid})
mwAppendEnabled(&chainArray, &RateCheckMW{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &VersionCheck{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &RequestSizeLimitMiddleware{baseMid})
mwAppendEnabled(&chainArray, &MiddlewareContextVars{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &TrackEndpointMiddleware{baseMid})

// Select the keying method to use for setting session states
var authArray []alice.Constructor
Expand All @@ -352,8 +352,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "main",
"api_name": spec.Name,
}).Info("Checking security policy: OAuth")
authArray = append(authArray, createMiddleware(&Oauth2KeyExists{baseMid}))

mwAppendEnabled(&authArray, &Oauth2KeyExists{baseMid})
}

useCoProcessAuth := EnableCoProcess && mwDriver != apidef.OttoDriver && spec.EnableCoProcessAuth
Expand All @@ -365,7 +364,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "main",
"api_name": spec.Name,
}).Info("Checking security policy: Basic")
authArray = append(authArray, createMiddleware(&BasicAuthKeyIsValid{baseMid}))
mwAppendEnabled(&authArray, &BasicAuthKeyIsValid{baseMid})
}

if spec.EnableSignatureChecking {
Expand All @@ -374,7 +373,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "main",
"api_name": spec.Name,
}).Info("Checking security policy: HMAC")
authArray = append(authArray, createMiddleware(&HMACMiddleware{BaseMiddleware: baseMid}))
mwAppendEnabled(&authArray, &HMACMiddleware{BaseMiddleware: baseMid})
}

if spec.EnableJWT {
Expand All @@ -383,7 +382,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "main",
"api_name": spec.Name,
}).Info("Checking security policy: JWT")
authArray = append(authArray, createMiddleware(&JWTMiddleware{baseMid}))
mwAppendEnabled(&authArray, &JWTMiddleware{baseMid})
}

if spec.UseOpenID {
Expand All @@ -394,7 +393,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
}).Info("Checking security policy: OpenID")

// initialise the OID configuration on this reference Spec
authArray = append(authArray, createMiddleware(&OpenIDMW{BaseMiddleware: baseMid}))
mwAppendEnabled(&authArray, &OpenIDMW{BaseMiddleware: baseMid})
}

if useCoProcessAuth {
Expand All @@ -411,7 +410,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,

if useCoProcessAuth {
newExtractor(spec, baseMid)
appendMiddleware(&authArray, &CoProcessMiddleware{baseMid, coprocess.HookType_CustomKeyCheck, mwAuthCheckFunc.Name, mwDriver})
mwAppendEnabled(&authArray, &CoProcessMiddleware{baseMid, coprocess.HookType_CustomKeyCheck, mwAuthCheckFunc.Name, mwDriver})
}
}

Expand Down Expand Up @@ -439,27 +438,27 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
"prefix": "coprocess",
"api_name": spec.Name,
}).Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver)
appendMiddleware(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_PostKeyAuth, obj.Name, mwDriver})
mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_PostKeyAuth, obj.Name, mwDriver})
}

appendMiddleware(&chainArray, &KeyExpired{baseMid})
appendMiddleware(&chainArray, &AccessRightsCheck{baseMid})
appendMiddleware(&chainArray, &RateLimitAndQuotaCheck{baseMid})
appendMiddleware(&chainArray, &GranularAccessMiddleware{baseMid})
appendMiddleware(&chainArray, &TransformMiddleware{baseMid})
appendMiddleware(&chainArray, &TransformHeaders{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &URLRewriteMiddleware{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &RedisCacheMiddleware{BaseMiddleware: baseMid, CacheStore: cacheStore})
appendMiddleware(&chainArray, &TransformMethod{BaseMiddleware: baseMid})
appendMiddleware(&chainArray, &VirtualEndpoint{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &KeyExpired{baseMid})
mwAppendEnabled(&chainArray, &AccessRightsCheck{baseMid})
mwAppendEnabled(&chainArray, &RateLimitAndQuotaCheck{baseMid})
mwAppendEnabled(&chainArray, &GranularAccessMiddleware{baseMid})
mwAppendEnabled(&chainArray, &TransformMiddleware{baseMid})
mwAppendEnabled(&chainArray, &TransformHeaders{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &URLRewriteMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &RedisCacheMiddleware{BaseMiddleware: baseMid, CacheStore: cacheStore})
mwAppendEnabled(&chainArray, &TransformMethod{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &VirtualEndpoint{BaseMiddleware: baseMid})

for _, obj := range mwPostFuncs {
if mwDriver != apidef.OttoDriver {
log.WithFields(logrus.Fields{
"prefix": "coprocess",
"api_name": spec.Name,
}).Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Post", ", driver: ", mwDriver)
appendMiddleware(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Post, obj.Name, mwDriver})
mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Post, obj.Name, mwDriver})
} else {
chainArray = append(chainArray, createDynamicMiddleware(obj.Name, false, obj.RequireSession, baseMid))
}
Expand All @@ -476,12 +475,12 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
log.Debug("Chain completed")

var simpleArray []alice.Constructor
simpleArray = append(simpleArray, createMiddleware(&IPWhiteListMiddleware{baseMid}))
simpleArray = append(simpleArray, createMiddleware(&OrganizationMonitor{BaseMiddleware: baseMid}))
simpleArray = append(simpleArray, createMiddleware(&VersionCheck{BaseMiddleware: baseMid}))
mwAppendEnabled(&simpleArray, &IPWhiteListMiddleware{baseMid})
mwAppendEnabled(&simpleArray, &OrganizationMonitor{BaseMiddleware: baseMid})
mwAppendEnabled(&simpleArray, &VersionCheck{BaseMiddleware: baseMid})
simpleArray = append(simpleArray, authArray...)
simpleArray = append(simpleArray, createMiddleware(&KeyExpired{baseMid}))
simpleArray = append(simpleArray, createMiddleware(&AccessRightsCheck{baseMid}))
mwAppendEnabled(&simpleArray, &KeyExpired{baseMid})
mwAppendEnabled(&simpleArray, &AccessRightsCheck{baseMid})

rateLimitPath := spec.Proxy.ListenPath + "tyk/rate-limits/"
log.WithFields(logrus.Fields{
Expand Down
20 changes: 10 additions & 10 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,16 @@ func getChain(spec *APISpec) http.Handler {
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
baseMid := &BaseMiddleware{spec, proxy}
chain := alice.New(
createMiddleware(&IPWhiteListMiddleware{baseMid}),
createMiddleware(&MiddlewareContextVars{BaseMiddleware: baseMid}),
createMiddleware(&AuthKey{baseMid}),
createMiddleware(&VersionCheck{BaseMiddleware: baseMid}),
createMiddleware(&KeyExpired{baseMid}),
createMiddleware(&AccessRightsCheck{baseMid}),
createMiddleware(&RateLimitAndQuotaCheck{baseMid}),
createMiddleware(&TransformHeaders{baseMid})).Then(proxyHandler)

chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&MiddlewareContextVars{BaseMiddleware: baseMid},
&AuthKey{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
&TransformHeaders{baseMid},
)...).Then(proxyHandler)
return chain
}

Expand Down
10 changes: 9 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,20 @@ func createMiddleware(mw TykMiddleware) func(http.Handler) http.Handler {
}
}

func appendMiddleware(chain *[]alice.Constructor, mw TykMiddleware) {
func mwAppendEnabled(chain *[]alice.Constructor, mw TykMiddleware) {
if mw.IsEnabledForSpec() {
*chain = append(*chain, createMiddleware(mw))
}
}

func mwList(mws ...TykMiddleware) []alice.Constructor {
var list []alice.Constructor
for _, mw := range mws {
mwAppendEnabled(&list, mw)
}
return list
}

// BaseMiddleware wraps up the ApiSpec and Proxy objects to be included in a
// middleware handler, this can probably be handled better.
type BaseMiddleware struct {
Expand Down
18 changes: 9 additions & 9 deletions multiauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ func getMultiAuthStandardAndBasicAuthChain(spec *APISpec) http.Handler {
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
baseMid := &BaseMiddleware{spec, proxy}
chain := alice.New(
createMiddleware(&IPWhiteListMiddleware{baseMid}),
createMiddleware(&BasicAuthKeyIsValid{baseMid}),
createMiddleware(&AuthKey{baseMid}),
createMiddleware(&VersionCheck{BaseMiddleware: baseMid}),
createMiddleware(&KeyExpired{baseMid}),
createMiddleware(&AccessRightsCheck{baseMid}),
createMiddleware(&RateLimitAndQuotaCheck{baseMid})).Then(proxyHandler)

chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&BasicAuthKeyIsValid{baseMid},
&AuthKey{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
)...).Then(proxyHandler)
return chain
}

Expand Down
16 changes: 8 additions & 8 deletions mw_auth_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ func getAuthKeyChain(spec *APISpec) http.Handler {
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
baseMid := &BaseMiddleware{spec, proxy}
chain := alice.New(
createMiddleware(&IPWhiteListMiddleware{baseMid}),
createMiddleware(&AuthKey{baseMid}),
createMiddleware(&VersionCheck{BaseMiddleware: baseMid}),
createMiddleware(&KeyExpired{baseMid}),
createMiddleware(&AccessRightsCheck{baseMid}),
createMiddleware(&RateLimitAndQuotaCheck{baseMid})).Then(proxyHandler)

chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&AuthKey{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
)...).Then(proxyHandler)
return chain
}

Expand Down
16 changes: 8 additions & 8 deletions mw_basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ func getBasicAuthChain(spec *APISpec) http.Handler {
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
baseMid := &BaseMiddleware{spec, proxy}
chain := alice.New(
createMiddleware(&IPWhiteListMiddleware{baseMid}),
createMiddleware(&BasicAuthKeyIsValid{baseMid}),
createMiddleware(&VersionCheck{BaseMiddleware: baseMid}),
createMiddleware(&KeyExpired{baseMid}),
createMiddleware(&AccessRightsCheck{baseMid}),
createMiddleware(&RateLimitAndQuotaCheck{baseMid})).Then(proxyHandler)

chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&BasicAuthKeyIsValid{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
)...).Then(proxyHandler)
return chain
}

Expand Down
4 changes: 0 additions & 4 deletions mw_context_vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ func (m *MiddlewareContextVars) IsEnabledForSpec() bool {
// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
func (m *MiddlewareContextVars) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {

if !m.Spec.EnableContextVars {
return nil, 200
}

copiedRequest := copyRequest(r)
contextDataObject := make(map[string]interface{})

Expand Down
16 changes: 8 additions & 8 deletions mw_hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ func getHMACAuthChain(spec *APISpec) http.Handler {
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
baseMid := &BaseMiddleware{spec, proxy}
chain := alice.New(
createMiddleware(&IPWhiteListMiddleware{baseMid}),
createMiddleware(&HMACMiddleware{BaseMiddleware: baseMid}),
createMiddleware(&VersionCheck{BaseMiddleware: baseMid}),
createMiddleware(&KeyExpired{baseMid}),
createMiddleware(&AccessRightsCheck{baseMid}),
createMiddleware(&RateLimitAndQuotaCheck{baseMid})).Then(proxyHandler)

chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&HMACMiddleware{BaseMiddleware: baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
)...).Then(proxyHandler)
return chain
}

Expand Down
5 changes: 0 additions & 5 deletions mw_ip_whitelist.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ func (i *IPWhiteListMiddleware) IsEnabledForSpec() bool {

// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
func (i *IPWhiteListMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
// Disabled, pass through
if !i.Spec.EnableIpWhiteListing {
return nil, 200
}

remoteIP := net.ParseIP(requestIP(r))

// Enabled, check incoming IP address
Expand Down
16 changes: 8 additions & 8 deletions mw_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ func getJWTChain(spec *APISpec) http.Handler {
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
baseMid := &BaseMiddleware{spec, proxy}
chain := alice.New(
createMiddleware(&IPWhiteListMiddleware{baseMid}),
createMiddleware(&JWTMiddleware{baseMid}),
createMiddleware(&VersionCheck{BaseMiddleware: baseMid}),
createMiddleware(&KeyExpired{baseMid}),
createMiddleware(&AccessRightsCheck{baseMid}),
createMiddleware(&RateLimitAndQuotaCheck{baseMid})).Then(proxyHandler)

chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&JWTMiddleware{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
)...).Then(proxyHandler)
return chain
}

Expand Down
Loading

0 comments on commit 1599fbb

Please sign in to comment.