Skip to content

Commit

Permalink
Handle allowed CORS requests origins
Browse files Browse the repository at this point in the history
  • Loading branch information
rafmagana committed Sep 17, 2014
1 parent 11bcc03 commit 3f53779
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 1 deletion.
19 changes: 18 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type api struct {
Password string
Cert string
Key string
cors *CORS
handlers []func(http.ResponseWriter, *http.Request)
start func(*api)
}
Expand Down Expand Up @@ -53,7 +54,6 @@ func NewAPI(g *gobot.Gobot) *api {

func (a *api) ServeHTTP(res http.ResponseWriter, req *http.Request) {
for _, handler := range a.handlers {
res.Header().Set("Access-Control-Allow-Origin", "*")
handler(res, req)
}
a.router.ServeHTTP(res, req)
Expand Down Expand Up @@ -93,6 +93,23 @@ func (a *api) SetBasicAuth(user, password string) {
a.AddHandler(a.basicAuth)
}

func (a *api) AllowRequestsFrom(allowedOrigins ...string) {
a.SetCORS(NewCORS(allowedOrigins))
}

func (a *api) SetCORS(cors *CORS) {
a.cors = cors
a.AddHandler(a.CORSHandler)
}

func (a *api) CORSHandler(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if a.cors.isOriginAllowed(origin) {
w.Header().Set("Access-Control-Allow-Origin", origin)
}

}

func (a *api) SetDebug() {
a.AddHandler(func(res http.ResponseWriter, req *http.Request) {
log.Println(req)
Expand Down
30 changes: 30 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func initTestAPI() *api {

return a
}

func TestBasicAuth(t *testing.T) {
a := initTestAPI()

Expand All @@ -48,6 +49,35 @@ func TestBasicAuth(t *testing.T) {
gobot.Assert(t, response.Code, 401)
}

func TestAPIAllowRequestsFrom(t *testing.T) {
api := initTestAPI()
api.AllowRequestsFrom("http://server.com")
gobot.Assert(t, api.cors.AllowOrigins, []string{"http://server.com"})
}

func TestCORS(t *testing.T) {
api := initTestAPI()

// Accepted origin
allowedOrigin := []string{"http://server.com"}
api.AllowRequestsFrom(allowedOrigin[0])

request, _ := http.NewRequest("GET", "/api/", nil)
request.Header.Set("Origin", allowedOrigin[0])
response := httptest.NewRecorder()
api.ServeHTTP(response, request)
gobot.Assert(t, response.Header()["Access-Control-Allow-Origin"], allowedOrigin)

// Not accepted Origin
disallowedOrigin := []string{"http://disallowed.com"}
request, _ = http.NewRequest("GET", "/api/", nil)
request.Header.Set("Origin", disallowedOrigin[0])
response = httptest.NewRecorder()
api.ServeHTTP(response, request)
gobot.Refute(t, response.Header()["Access-Control-Allow-Origin"], disallowedOrigin)
gobot.Refute(t, response.Header()["Access-Control-Allow-Origin"], allowedOrigin)
}

func TestRobeaux(t *testing.T) {
a := initTestAPI()
// html assets
Expand Down
24 changes: 24 additions & 0 deletions api/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package api

type CORS struct {
AllowOrigins []string
AllowHeaders []string // Not yet implemented
AllowMethods []string // ditto
}

func NewCORS(allowedOrigins []string) *CORS {
return &CORS{
AllowOrigins: allowedOrigins,
AllowMethods: []string{"GET", "POST"},
AllowHeaders: []string{"Origin", "Content-Type"},
}
}

func (c *CORS) isOriginAllowed(currentOrigin string) bool {
for _, allowedOrigin := range c.AllowOrigins {
if "*" == allowedOrigin || currentOrigin == allowedOrigin {
return true
}
}
return false
}
47 changes: 47 additions & 0 deletions api/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package api

import (
"github.com/hybridgroup/gobot"
"testing"
)

func TestNewCORS(t *testing.T) {
var cors interface{} = NewCORS([]string{})

// Does it return a pointer to an instance of CORS?
_, ok := cors.(*CORS)
if !ok {
t.Errorf("NewCORS() should have returned a *CORS")
}
}

func TestNewCorsSetsProperties(t *testing.T) {
allowedOrigins := []string{"http://server:port"}

cors := NewCORS(allowedOrigins)

gobot.Assert(t, cors.AllowOrigins, allowedOrigins)
}

func TestCORSIsOriginAllowed(t *testing.T) {
cors := NewCORS([]string{"*"})

// When all the origins are accepted
gobot.Assert(t, cors.isOriginAllowed("http://localhost:8000"), true)
gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), true)
gobot.Assert(t, cors.isOriginAllowed("http://server.com"), true)

// When one origin is accepted
cors.AllowOrigins = []string{"http://localhost:8000"}

gobot.Assert(t, cors.isOriginAllowed("http://localhost:8000"), true)
gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), false)
gobot.Assert(t, cors.isOriginAllowed("http://server.com"), false)

// When several origins are accepted
cors.AllowOrigins = []string{"http://localhost:8000", "http://server.com"}

gobot.Assert(t, cors.isOriginAllowed("http://localhost:8000"), true)
gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), false)
gobot.Assert(t, cors.isOriginAllowed("http://server.com"), true)
}

0 comments on commit 3f53779

Please sign in to comment.