Skip to content

Commit

Permalink
Implement Access-Control-Allow-Origin *
Browse files Browse the repository at this point in the history
Use Access-Control-Allow-Origin: * when all origins are allowed and
AllowCredentials is false.

Fix rs#30
  • Loading branch information
rs committed Jun 8, 2017
1 parent 3d4811e commit 8dd4211
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
18 changes: 14 additions & 4 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ func New(options Options) *Cors {

// Allowed Origins
if len(options.AllowedOrigins) == 0 {
// Default is all origins
c.allowedOriginsAll = true
if options.AllowOriginFunc == nil {
// Default is all origins
c.allowedOriginsAll = true
}
} else {
c.allowedOrigins = []string{}
c.allowedWOrigins = []wildcard{}
Expand Down Expand Up @@ -267,7 +269,11 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
return
}
headers.Set("Access-Control-Allow-Origin", origin)
if c.allowedOriginsAll && !c.allowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Access-Control-Allow-Origin", origin)
}
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
// by Access-Control-Request-Method (if supported) can be enough
headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
Expand Down Expand Up @@ -315,7 +321,11 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {

return
}
headers.Set("Access-Control-Allow-Origin", origin)
if c.allowedOriginsAll && !c.allowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Access-Control-Allow-Origin", origin)
}
if len(c.exposedHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
}
Expand Down
42 changes: 25 additions & 17 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestMatchAllOrigin(t *testing.T) {

assertHeaders(t, res.Header(), map[string]string{
"Vary": "Origin",
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
Expand All @@ -69,6 +69,29 @@ func TestMatchAllOrigin(t *testing.T) {
})
}

func TestMatchAllOriginWithCredentials(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")

s.Handler(testHandler).ServeHTTP(res, req)

assertHeaders(t, res.Header(), map[string]string{
"Vary": "Origin",
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}

func TestAllowedOrigin(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
Expand Down Expand Up @@ -405,24 +428,9 @@ func TestDebug(t *testing.T) {
Debug: true,
})

if s.logf == nil {
if s.Log == nil {
t.Error("Logger not created when debug=true")
}

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)

s.Handler(testHandler).ServeHTTP(res, req)

assertHeaders(t, res.Header(), map[string]string{
"Vary": "Origin",
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}

func TestOptionsPassthrough(t *testing.T) {
Expand Down

0 comments on commit 8dd4211

Please sign in to comment.