Skip to content

Commit

Permalink
Add CIDR support for IP whitelist
Browse files Browse the repository at this point in the history
  • Loading branch information
iwat committed Jan 13, 2016
1 parent 482cf88 commit a0f2b7e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
15 changes: 13 additions & 2 deletions middleware_ip_whitelist.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ func (i *IPWhiteListMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Re

// Enabled, check incoming IP address
for _, ip := range i.TykMiddleware.Spec.AllowedIPs {
allowedIP := net.ParseIP(ip)
// Might be CIDR, try this one first then fallback to IP parsing later
allowedIP, allowedNet, err := net.ParseCIDR(ip)
if err != nil {
allowedIP = net.ParseIP(ip)
}

splitIP := strings.Split(r.RemoteAddr, ":")
remoteIPString := splitIP[0]
if len(splitIP) > 2 {
Expand All @@ -41,8 +46,14 @@ func (i *IPWhiteListMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Re
}
remoteIP = net.ParseIP(remoteIPString)

// Check CIDR if possible
if allowedNet != nil && allowedNet.Contains(remoteIP) {
// matched, pass through
return nil, 200
}

// We parse the IP to manage IPv4 and IPv6 easily
if allowedIP.String() == remoteIP.String() {
if allowedIP.Equal(remoteIP) {
// matched, pass through
return nil, 200
}
Expand Down
31 changes: 30 additions & 1 deletion middleware_ip_whitelist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ var ipMiddlewareTestDefinitionEnabledPass string = `
"strip_listen_path": false
},
"enable_ip_whitelisting": true,
"allowed_ips": ["127.0.0.1"]
"allowed_ips": ["127.0.0.1", "127.0.0.1/24"]
}
`

Expand Down Expand Up @@ -237,6 +237,35 @@ func TestIpMiddlewareIPPass(t *testing.T) {
}
}

func TestIpMiddlewareIPPassCIDR(t *testing.T) {
spec := MakeIPSampleAPI(ipMiddlewareTestDefinitionEnabledPass)
redisStore := RedisStorageManager{KeyPrefix: "apikey-"}
healthStore := &RedisStorageManager{KeyPrefix: "apihealth."}
orgStore := &RedisStorageManager{KeyPrefix: "orgKey."}
spec.Init(&redisStore, &redisStore, healthStore, orgStore)
thisSession := createNonThrottledSession()
spec.SessionManager.UpdateSession("gfgg1234", thisSession, 60)
uri := "/about-lonelycoder/"
method := "GET"

recorder := httptest.NewRecorder()
param := make(url.Values)
req, err := http.NewRequest(method, uri+param.Encode(), nil)
req.RemoteAddr = "127.0.0.2"
req.Header.Add("authorization", "gfgg1234")

if err != nil {
t.Fatal(err)
}

chain := getChain(*spec)
chain.ServeHTTP(recorder, req)

if recorder.Code != 200 {
t.Error("Invalid response code, should be 200: \n", recorder.Code, recorder.Body, req.RemoteAddr)
}
}

func TestIpMiddlewareIPMissing(t *testing.T) {
spec := MakeIPSampleAPI(ipMiddlewareTestDefinitionMissing)
redisStore := RedisStorageManager{KeyPrefix: "apikey-"}
Expand Down

0 comments on commit a0f2b7e

Please sign in to comment.