Skip to content

Commit

Permalink
refactor: cleanup the code for CORS handling (ory#1959)
Browse files Browse the repository at this point in the history
  • Loading branch information
harsimranmaan authored Aug 2, 2020
1 parent cf90919 commit 5a53d28
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 36 deletions.
36 changes: 9 additions & 27 deletions driver/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ func OAuth2AwareCORSMiddleware(iface string, reg Registry, conf configuration.Pr
return h
}
}

corsOptions := conf.CORSOptions(iface)

var alwaysAllow bool = len(corsOptions.AllowedOrigins) == 0
var patterns []glob.Glob
for _, o := range corsOptions.AllowedOrigins {
if o == "*" {
alwaysAllow = true
}
// if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator.
// This way g := glob.Compile("http://**") g.Match("http://google.com") returns true.
if splittedO := strings.Split(o, "://"); len(splittedO) != 1 && splittedO[1] == "*" {
Expand All @@ -54,20 +58,10 @@ func OAuth2AwareCORSMiddleware(iface string, reg Registry, conf configuration.Pr
if err != nil {
reg.Logger().WithError(err).Fatalf("Unable to parse cors origin: %s", o)
}
patterns = append(patterns, g)
}

var alwaysAllow bool
for _, o := range corsOptions.AllowedOrigins {
if o == "*" {
alwaysAllow = true
break
}
patterns = append(patterns, g)
}

if len(corsOptions.AllowedOrigins) == 0 {
alwaysAllow = true
}

options := cors.Options{
AllowedOrigins: corsOptions.AllowedOrigins,
Expand Down Expand Up @@ -111,27 +105,15 @@ func OAuth2AwareCORSMiddleware(iface string, reg Registry, conf configuration.Pr
return false
}

if alwaysAllow {
return true
}

for _, p := range cl.AllowedCORSOrigins {
if p == "*" {
for _, o := range cl.AllowedCORSOrigins {
if o == "*" {
return true
}
}

var clientPatterns []glob.Glob
for _, o := range cl.AllowedCORSOrigins {
g, err := glob.Compile(strings.ToLower(o), '.')
if err != nil {
return false
}
clientPatterns = append(patterns, g)
}

for _, p := range clientPatterns {
if p.Match(origin) {
if(g.Match(origin)){
return true
}
}
Expand Down
52 changes: 43 additions & 9 deletions driver/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
. "github.com/ory/hydra/driver"
"github.com/ory/hydra/internal"
"github.com/ory/hydra/oauth2"

"github.com/ory/hydra/x"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -68,69 +68,77 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vOmJhcg=="}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo", "bar"))}},
expectHeader: http.Header{"Vary": {"Origin"}},
},
{
d: "should reject when basic auth client exists but origin not allowed",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vLTI6YmFy"}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-2", "bar"))}},
expectHeader: http.Header{"Vary": {"Origin"}},
},
{
d: "should accept when basic auth client exists and origin allowed",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vLTM6YmFy"}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-3", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed per client",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-4", Secret: "bar", AllowedCORSOrigins: []string{"http://*.foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {"Basic Zm9vLTQ6YmFy"}},
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-4", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with full wildcard) is allowed globally",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"*"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-5", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"*"}, "Authorization": {"Basic Zm9vLTU6YmFy"}},
header: http.Header{"Origin": {"*"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-5", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"*"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed globally",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://*.foobar.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-6", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {"Basic Zm9vLTY6YmFy"}},
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-6", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with full wildcard) allowed per client",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-7", Secret: "bar", AllowedCORSOrigins: []string{"*"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vLTc6YmFy"}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-7", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should fail when token introspection fails",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
},
code: http.StatusNotImplemented,
Expand All @@ -140,6 +148,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
{
d: "should work when token introspection returns a session",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
sess := oauth2.NewSession("foo-9")
sess.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour))
Expand All @@ -160,17 +169,41 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
{
d: "should accept any allowed specified origin protocol",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-11", Secret: "bar", AllowedCORSOrigins: []string{"*"}})
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://*", "https://*"})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {"Basic Zm9vLTQ6YmFy"}},
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-11", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept client origin when basic auth client exists and origin is set at the client as well as the server",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://**.example.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-12", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://myapp.example.biz"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-12", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://myapp.example.biz"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept server origin when basic auth client exists and origin is set at the client as well as the server",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://**.example.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-13", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://client-app.example.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-13", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://client-app.example.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) {
if tc.prep != nil {
viper.Reset()
tc.prep()
}

Expand All @@ -189,4 +222,5 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
assert.EqualValues(t, tc.expectHeader, res.Header())
})
}

}

0 comments on commit 5a53d28

Please sign in to comment.