Skip to content

Commit

Permalink
Config CORS (micro#2270)
Browse files Browse the repository at this point in the history
* Added cors.config for CORS

* Added cors.config for CORS

* Added cors.config for CORS

Co-authored-by: 于海洋 <[email protected]>
  • Loading branch information
helloworldyuhaiyang and 于海洋 authored Sep 17, 2021
1 parent 4c7d2e2 commit ad53252
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 23 deletions.
57 changes: 35 additions & 22 deletions api/server/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,54 @@ import (
"net/http"
)

// CombinedCORSHandler wraps a server and provides CORS headers
func CombinedCORSHandler(h http.Handler) http.Handler {
return corsHandler{h}
}

type corsHandler struct {
handler http.Handler
type Config struct {
AllowOrigin string
AllowCredentials bool
AllowMethods string
AllowHeaders string
}

func (c corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
SetHeaders(w, r)

if r.Method == "OPTIONS" {
return
}
// CombinedCORSHandler wraps a server and provides CORS headers
func CombinedCORSHandler(h http.Handler, config *Config) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if config != nil {
SetHeaders(w, r, config)
}
if r.Method == "OPTIONS" {
return
}

c.handler.ServeHTTP(w, r)
h.ServeHTTP(w, r)
})
}

// SetHeaders sets the CORS headers
func SetHeaders(w http.ResponseWriter, r *http.Request) {
func SetHeaders(w http.ResponseWriter, _ *http.Request, config *Config) {
set := func(w http.ResponseWriter, k, v string) {
if v := w.Header().Get(k); len(v) > 0 {
return
}
w.Header().Set(k, v)
}

if origin := r.Header.Get("Origin"); len(origin) > 0 {
set(w, "Access-Control-Allow-Origin", origin)
//For forward-compatible code, default values may not be provided in the future
if config.AllowCredentials {
set(w, "Access-Control-Allow-Credentials", "true")
} else {
set(w, "Access-Control-Allow-Credentials", "false")
}
if config.AllowOrigin == "" {
set(w, "Access-Control-Allow-Origin", "*")
} else {
set(w, "Access-Control-Allow-Origin", config.AllowOrigin)
}
if config.AllowMethods == "" {
set(w, "Access-Control-Allow-Methods", "POST, PATCH, GET, OPTIONS, PUT, DELETE")
} else {
set(w, "Access-Control-Allow-Methods", config.AllowMethods)
}
if config.AllowHeaders == "" {
set(w, "Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
} else {
set(w, "Access-Control-Allow-Headers", config.AllowHeaders)
}

set(w, "Access-Control-Allow-Credentials", "true")
set(w, "Access-Control-Allow-Methods", "POST, PATCH, GET, OPTIONS, PUT, DELETE")
set(w, "Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
}
2 changes: 1 addition & 1 deletion api/server/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *httpServer) Handle(path string, handler http.Handler) {

// wrap with cors
if s.opts.EnableCORS {
handler = cors.CombinedCORSHandler(handler)
handler = cors.CombinedCORSHandler(handler, s.opts.CORSConfig)
}

// wrap with logger
Expand Down
73 changes: 73 additions & 0 deletions api/server/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package http

import (
"fmt"
"github.com/asim/go-micro/v3/api/server"
"github.com/asim/go-micro/v3/api/server/cors"
"io/ioutil"
"net/http"
"testing"
Expand Down Expand Up @@ -39,3 +41,74 @@ func TestHTTPServer(t *testing.T) {
t.Fatal(err)
}
}

func TestCORSHTTPServer(t *testing.T) {
testResponse := "hello world"
testAllowOrigin := "*"
testAllowCredentials := true
testAllowMethods := "GET"
testAllowHeaders := "Accept, Content-Type, Content-Length"

s := NewServer("localhost:0",
server.EnableCORS(true),
server.CORSConfig(&cors.Config{
AllowCredentials: testAllowCredentials,
AllowOrigin: testAllowOrigin,
AllowMethods: testAllowMethods,
AllowHeaders: testAllowHeaders,
}),
)

s.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, testResponse)
}))

if err := s.Start(); err != nil {
t.Fatal(err)
}

rsp, err := http.Get(fmt.Sprintf("http://%s/", s.Address()))
if err != nil {
t.Fatal(err)
}
defer rsp.Body.Close()

b, err := ioutil.ReadAll(rsp.Body)
if err != nil {
t.Fatal(err)
}

if string(b) != testResponse {
t.Fatalf("Unexpected response, got %s, expected %s", string(b), testResponse)
}

allowCredentials := rsp.Header.Get("Access-Control-Allow-Credentials")
getTestCredentialsStr := func() string {
if testAllowCredentials == true {
return "true"
} else {
return "false"
}
}
if getTestCredentialsStr() != allowCredentials {
t.Fatalf("Unexpected Access-Control-Allow-Credentials, got %s, expected %s", allowCredentials, getTestCredentialsStr())
}

allowOrigin := rsp.Header.Get("Access-Control-Allow-Origin")
if testAllowOrigin != allowOrigin {
t.Fatalf("Unexpected Access-Control-Allow-Origins, got %s, expected %s", allowOrigin, testAllowOrigin)
}

allowMethods := rsp.Header.Get("Access-Control-Allow-Methods")
if testAllowMethods != allowMethods {
t.Fatalf("Unexpected Access-Control-Allow-Methods, got %s, expected %s", allowMethods, testAllowMethods)
}
allowHeaders := rsp.Header.Get("Access-Control-Allow-Headers")
if testAllowHeaders != allowHeaders {
t.Fatalf("Unexpected Access-Control-Allow-Headers, got %s, expected %s", allowHeaders, testAllowHeaders)
}

if err := s.Stop(); err != nil {
t.Fatal(err)
}
}
8 changes: 8 additions & 0 deletions api/server/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"crypto/tls"
"github.com/asim/go-micro/v3/api/server/cors"
"net/http"

"github.com/asim/go-micro/v3/api/resolver"
Expand All @@ -13,6 +14,7 @@ type Option func(o *Options)
type Options struct {
EnableACME bool
EnableCORS bool
CORSConfig *cors.Config
ACMEProvider acme.Provider
EnableTLS bool
ACMEHosts []string
Expand All @@ -35,6 +37,12 @@ func EnableCORS(b bool) Option {
}
}

func CORSConfig(c *cors.Config) Option {
return func(o *Options) {
o.CORSConfig = c
}
}

func EnableACME(b bool) Option {
return func(o *Options) {
o.EnableACME = b
Expand Down

0 comments on commit ad53252

Please sign in to comment.