forked from hybridgroup/gobot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcors.go
70 lines (61 loc) · 1.97 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package api
import (
"net/http"
"regexp"
"strings"
)
// CORS describes which cross-origin requests are accepted and which values
// are written to the CORS response headers.
type CORS struct {
	// AllowOrigins lists the accepted origins; generatePatterns turns "*" and
	// "?" in them into wildcards. AllowHeaders and AllowMethods are joined
	// with "," by AllowedHeaders and AllowedMethods respectively.
	AllowOrigins, AllowHeaders, AllowMethods []string

	// ContentType is the Content-Type value that the handler built by
	// AllowRequestsFrom sets on responses to allowed origins.
	ContentType string

	// allowOriginPatterns holds the anchored regular expressions built from
	// AllowOrigins by generatePatterns and consulted by isOriginAllowed.
	allowOriginPatterns []string
}
// AllowRequestsFrom returns a handler that sets CORS response headers for
// requests whose Origin header matches one of allowedOrigins. An allowed
// origin may use "*" to match any run of characters and "?" to match a single
// character, e.g. "http://*.example.com".
//
// When the origin matches, the handler reflects it in
// Access-Control-Allow-Origin and sets Access-Control-Allow-Headers
// ("Origin,Content-Type"), Access-Control-Allow-Methods ("GET,POST") and
// Content-Type ("application/json; charset=utf-8"). Every response also gets
// "Vary: Origin", because those headers depend on the request's Origin and
// caches must not reuse one origin's response for another.
//
// The handler only sets response headers; it writes no status code or body.
func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc {
	c := &CORS{
		AllowOrigins: allowedOrigins,
		AllowMethods: []string{"GET", "POST"},
		AllowHeaders: []string{"Origin", "Content-Type"},
		ContentType:  "application/json; charset=utf-8",
	}
	c.generatePatterns()

	// c is private to this function and never modified after setup, so the
	// joined header values are constant: compute them once, not per request.
	allowHeaders := c.AllowedHeaders()
	allowMethods := c.AllowedMethods()

	return func(w http.ResponseWriter, req *http.Request) {
		h := w.Header()
		// The response differs by Origin whether or not it is allowed.
		// Add (not Set) preserves any Vary values already present.
		h.Add("Vary", "Origin")

		origin := req.Header.Get("Origin")
		if !c.isOriginAllowed(origin) {
			return
		}
		h.Set("Access-Control-Allow-Origin", origin)
		h.Set("Access-Control-Allow-Headers", allowHeaders)
		h.Set("Access-Control-Allow-Methods", allowMethods)
		h.Set("Content-Type", c.ContentType)
	}
}
// isOriginAllowed reports whether origin matches at least one of the patterns
// in c.allowOriginPatterns. With no patterns, every origin is rejected.
func (c *CORS) isOriginAllowed(origin string) bool {
	for i := range c.allowOriginPatterns {
		// regexp.MatchString returns false together with any compile error,
		// so discarding the error makes an invalid pattern a non-match.
		// NOTE(review): MatchString compiles the pattern on every call;
		// storing compiled *regexp.Regexp values would avoid that per request.
		matched, _ := regexp.MatchString(c.allowOriginPatterns[i], origin)
		if matched {
			return true
		}
	}
	return false
}
// generatePatterns rebuilds c.allowOriginPatterns from c.AllowOrigins,
// turning each origin into an anchored ("^...$") regular expression.
//
// Each origin is escaped with regexp.QuoteMeta so that it matches literally,
// with two wildcards: "*" becomes ".*" (any run of characters) and "?" becomes
// "." (any single character other than newline).
//
// Previously generated patterns are replaced rather than appended to.
// Calling this again after changing AllowOrigins therefore leaves no stale,
// still-allowed origins behind.
func (c *CORS) generatePatterns() {
	patterns := make([]string, 0, len(c.AllowOrigins))
	for _, origin := range c.AllowOrigins {
		pattern := regexp.QuoteMeta(origin)
		// QuoteMeta escapes the wildcards as `\*` and `\?`; map those escaped
		// forms to their regexp equivalents.
		pattern = strings.ReplaceAll(pattern, `\*`, ".*")
		pattern = strings.ReplaceAll(pattern, `\?`, ".")
		patterns = append(patterns, "^"+pattern+"$")
	}
	c.allowOriginPatterns = patterns
}
// AllowedHeaders returns AllowHeaders joined by "," with no added spaces. The
// AllowRequestsFrom handler uses this as the Access-Control-Allow-Headers
// value. An empty AllowHeaders yields "".
func (c *CORS) AllowedHeaders() string {
	const separator = ","
	joined := strings.Join(c.AllowHeaders, separator)
	return joined
}
// AllowedMethods returns AllowMethods joined by "," with no added spaces. The
// AllowRequestsFrom handler uses this as the Access-Control-Allow-Methods
// value. An empty AllowMethods yields "".
func (c *CORS) AllowedMethods() string {
	const separator = ","
	joined := strings.Join(c.AllowMethods, separator)
	return joined
}