Skip to content

Commit

Permalink
feature: api blacklisting middleware
Browse files Browse the repository at this point in the history
refs: TykTechnologies#1545

This PR adds IP Blacklisting Middleware. It works in same way as
IP Whitelisting.
  • Loading branch information
asoorm authored and buger committed Mar 20, 2018
1 parent 9809c35 commit 4e667e2
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 0 deletions.
3 changes: 3 additions & 0 deletions api_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,

mwAppendEnabled(&chainArray, &RateCheckMW{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &IPBlackListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &CertificateCheckMW{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &RateLimitForAPI{BaseMiddleware: baseMid})
Expand Down Expand Up @@ -349,6 +350,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,

mwAppendEnabled(&chainArray, &RateCheckMW{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &IPBlackListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &CertificateCheckMW{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid})
mwAppendEnabled(&chainArray, &VersionCheck{BaseMiddleware: baseMid})
Expand Down Expand Up @@ -470,6 +472,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int,

var simpleArray []alice.Constructor
mwAppendEnabled(&simpleArray, &IPWhiteListMiddleware{baseMid})
mwAppendEnabled(&chainArray, &IPBlackListMiddleware{BaseMiddleware: baseMid})
mwAppendEnabled(&simpleArray, &OrganizationMonitor{BaseMiddleware: baseMid})
mwAppendEnabled(&simpleArray, &VersionCheck{BaseMiddleware: baseMid})
simpleArray = append(simpleArray, authArray...)
Expand Down
2 changes: 2 additions & 0 deletions apidef/api_definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ type APIDefinition struct {
EnableBatchRequestSupport bool `bson:"enable_batch_request_support" json:"enable_batch_request_support"`
EnableIpWhiteListing bool `mapstructure:"enable_ip_whitelisting" bson:"enable_ip_whitelisting" json:"enable_ip_whitelisting"`
AllowedIPs []string `mapstructure:"allowed_ips" bson:"allowed_ips" json:"allowed_ips"`
EnableIpBlacklisting bool `mapstructure:"enable_ip_blacklisting" bson:"enable_ip_blacklisting" json:"enable_ip_blacklisting"`
BlacklistedIPs []string `mapstructure:"blacklisted_ips" bson:"blacklisted_ips" json:"blacklisted_ips"`
DontSetQuotasOnCreate bool `mapstructure:"dont_set_quota_on_create" bson:"dont_set_quota_on_create" json:"dont_set_quota_on_create"`
ExpireAnalyticsAfter int64 `mapstructure:"expire_analytics_after" bson:"expire_analytics_after" json:"expire_analytics_after"` // must have an expireAt TTL index set (http://docs.mongodb.org/manual/tutorial/expire-data/)
ResponseProcessors []ResponseProcessor `bson:"response_processors" json:"response_processors"`
Expand Down
1 change: 1 addition & 0 deletions multiauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func getMultiAuthStandardAndBasicAuthChain(spec *APISpec) http.Handler {
baseMid := BaseMiddleware{spec, proxy}
chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&IPBlackListMiddleware{BaseMiddleware: baseMid},
&BasicAuthKeyIsValid{baseMid},
&AuthKey{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
Expand Down
2 changes: 2 additions & 0 deletions mw_api_rate_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func getRLOpenChain(spec *APISpec) http.Handler {
baseMid := BaseMiddleware{spec, proxy}
chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&IPBlackListMiddleware{BaseMiddleware: baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&RateLimitForAPI{BaseMiddleware: baseMid},
)...).Then(proxyHandler)
Expand All @@ -48,6 +49,7 @@ func getGlobalRLAuthKeyChain(spec *APISpec) http.Handler {
baseMid := BaseMiddleware{spec, proxy}
chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&IPBlackListMiddleware{BaseMiddleware: baseMid},
&AuthKey{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
Expand Down
1 change: 1 addition & 0 deletions mw_auth_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func getAuthKeyChain(spec *APISpec) http.Handler {
baseMid := BaseMiddleware{spec, proxy}
chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&IPBlackListMiddleware{BaseMiddleware: baseMid},
&AuthKey{baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
Expand Down
1 change: 1 addition & 0 deletions mw_hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func getHMACAuthChain(spec *APISpec) http.Handler {
baseMid := BaseMiddleware{spec, proxy}
chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
&IPBlackListMiddleware{BaseMiddleware: baseMid},
&HMACMiddleware{BaseMiddleware: baseMid},
&VersionCheck{BaseMiddleware: baseMid},
&KeyExpired{baseMid},
Expand Down
57 changes: 57 additions & 0 deletions mw_ip_blacklist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"errors"
"net"
"net/http"
)

// IPBlackListMiddleware lets you define a list of IPs to block from upstream
type IPBlackListMiddleware struct {
BaseMiddleware
}

func (i *IPBlackListMiddleware) Name() string {
return "IPBlackListMiddleware"
}

func (i *IPBlackListMiddleware) EnabledForSpec() bool {
return i.Spec.EnableIpBlacklisting && len(i.Spec.BlacklistedIPs) > 0
}

// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
func (i *IPBlackListMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
remoteIP := net.ParseIP(requestIP(r))

// Enabled, check incoming IP address
for _, ip := range i.Spec.BlacklistedIPs {
// Might be CIDR, try this one first then fallback to IP parsing later
blockedIP, blockedNet, err := net.ParseCIDR(ip)
if err != nil {
blockedIP = net.ParseIP(ip)
}

// Check CIDR if possible
if blockedNet != nil && blockedNet.Contains(remoteIP) {

return i.handleError(r, remoteIP.String())
}

// We parse the IP to manage IPv4 and IPv6 easily
if blockedIP.Equal(remoteIP) {

return i.handleError(r, remoteIP.String())
}
}

return nil, http.StatusOK
}

func (i *IPBlackListMiddleware) handleError(r *http.Request, blacklistedIP string) (error, int) {

// Fire Authfailed Event
AuthFailed(i, r, blacklistedIP)
// Report in health check
reportHealthValue(i.Spec, KeyFailure, "-1")
return errors.New("access from this IP has been disallowed"), http.StatusForbidden
}
41 changes: 41 additions & 0 deletions mw_ip_blacklist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package main

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestIPBlacklistMiddleware(t *testing.T) {
spec := buildAPI(func(spec *APISpec) {
spec.EnableIpBlacklisting = true
spec.BlacklistedIPs = []string{"127.0.0.1", "127.0.0.1/24"}
})[0]

for ti, tc := range []struct {
remote, forwarded string
wantCode int
}{
{"127.0.0.1:80", "", http.StatusForbidden}, // remote exact match
{"127.0.0.2:80", "", http.StatusForbidden}, // remote CIDR match
{"10.0.0.1:80", "", http.StatusOK}, // no match
{"10.0.0.1:80", "127.0.0.1", http.StatusForbidden}, // forwarded exact match
{"10.0.0.1:80", "127.0.0.2", http.StatusForbidden}, // forwarded CIDR match
} {
rec := httptest.NewRecorder()
req := testReq(t, "GET", "/", nil)
req.RemoteAddr = tc.remote
if tc.forwarded != "" {
req.Header.Set("X-Forwarded-For", tc.forwarded)
}

mw := &IPBlackListMiddleware{}
mw.Spec = spec
_, code := mw.ProcessRequest(rec, req, nil)

if code != tc.wantCode {
t.Errorf("[%d] Response code %d should be %d\n%q %q", ti,
code, tc.wantCode, tc.remote, tc.forwarded)
}
}
}

0 comments on commit 4e667e2

Please sign in to comment.