Skip to content

Commit

Permalink
server: abstract proxy into separate package (openshift#1074)
Browse files Browse the repository at this point in the history
* server: abstract proxy into separate package

This commit moves all of the proxy logic into a new package located at
`pkg/proxy`. The motivation for this is to separate the concerns and to
allow us to leverage this package in future tools that require proxying.
  • Loading branch information
squat authored Oct 20, 2016
1 parent aeaa60d commit 69a904d
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 91 deletions.
9 changes: 9 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ type Authenticator struct {
successURL string
}

// The trivial token "extractor" always extracts a constant string.
// s should not be the empty string, or else the Authorization header
// may end up as "Bearer ".
func ConstantTokenExtractor(s string) func(*http.Request) (string, error) {
return func(_ *http.Request) (string, error) {
return s, nil
}
}

func NewAuthenticator(ccfg oidc.ClientConfig, issuerURL *url.URL, errorURL, successURL string) (*Authenticator, error) {
client, err := oidc.NewClient(ccfg)
if err != nil {
Expand Down
23 changes: 12 additions & 11 deletions cmd/bridge/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"k8s.io/kubernetes/pkg/client/restclient"

"github.com/coreos-inc/bridge/auth"
"github.com/coreos-inc/bridge/pkg/proxy"
"github.com/coreos-inc/bridge/server"
"github.com/coreos-inc/bridge/stats"
"github.com/coreos-inc/bridge/verify"
Expand Down Expand Up @@ -178,7 +179,7 @@ func main() {
log.Fatalf("Kubernetes config provided invalid URL: %v", err)
}

srv.K8sProxyConfig = &server.ProxyConfig{
srv.K8sProxyConfig = &proxy.Config{
TLSClientConfig: inClusterTLSCfg,
HeaderBlacklist: []string{"Cookie"},
Endpoint: k8sURL,
Expand All @@ -188,7 +189,7 @@ func main() {
case "off-cluster":
k8sModeOffClusterEndpointURL := validateFlagIsURL("k8s-mode-off-cluster-endpoint", *fK8sModeOffClusterEndpoint)

srv.K8sProxyConfig = &server.ProxyConfig{
srv.K8sProxyConfig = &proxy.Config{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: *fK8sModeOffClusterSkipVerifyTLS,
},
Expand Down Expand Up @@ -223,27 +224,27 @@ func main() {
ID: *fUserAuthOIDCClientID,
Secret: *fUserAuthOIDCClientSecret,
},
RedirectURL: server.SingleJoiningSlash(srv.BaseURL.String(), server.AuthLoginCallbackEndpoint),
RedirectURL: proxy.SingleJoiningSlash(srv.BaseURL.String(), server.AuthLoginCallbackEndpoint),
Scope: []string{"openid", "email", "profile"},
}

var (
err error
authLoginErrorEndpoint = server.SingleJoiningSlash(srv.BaseURL.String(), server.AuthLoginErrorEndpoint)
authLoginSuccessEndpoint = server.SingleJoiningSlash(srv.BaseURL.String(), server.AuthLoginSuccessEndpoint)
authLoginErrorEndpoint = proxy.SingleJoiningSlash(srv.BaseURL.String(), server.AuthLoginErrorEndpoint)
authLoginSuccessEndpoint = proxy.SingleJoiningSlash(srv.BaseURL.String(), server.AuthLoginSuccessEndpoint)
)

dexProxyConfigEndpoint := validateFlagIsURL("user-auth-oidc-issuer-url", *fUserAuthOIDCIssuerURL)
dexProxyConfigEndpoint.Path = server.SingleJoiningSlash(dexProxyConfigEndpoint.Path, "/api")
dexProxyConfigEndpoint.Path = proxy.SingleJoiningSlash(dexProxyConfigEndpoint.Path, "/api")

srv.DexProxyConfig = &server.ProxyConfig{
srv.DexProxyConfig = &proxy.Config{
Endpoint: dexProxyConfigEndpoint,
TLSClientConfig: &tls.Config{
RootCAs: certPool,
},
HeaderBlacklist: []string{"Cookie"},
TokenExtractor: auth.ExtractTokenFromCookie,
}
srv.DexProxyConfig.Director = server.DirectorFromTokenExtractor(srv.DexProxyConfig, auth.ExtractTokenFromCookie)

if *fKubectlClientID != "" {
srv.KubectlClientID = *fKubectlClientID
Expand Down Expand Up @@ -304,13 +305,13 @@ func main() {
switch *fK8sAuth {
case "service-account":
validateFlagIs("k8s-mode", *fK8sMode, "in-cluster")
srv.K8sProxyConfig.TokenExtractor = server.ConstantTokenExtractor(k8sAuthServiceAccountBearerToken)
srv.K8sProxyConfig.Director = server.DirectorFromTokenExtractor(srv.K8sProxyConfig, auth.ConstantTokenExtractor(k8sAuthServiceAccountBearerToken))
case "bearer-token":
validateFlagNotEmpty("k8s-auth-bearer-token", *fK8sAuthBearerToken)
srv.K8sProxyConfig.TokenExtractor = server.ConstantTokenExtractor(*fK8sAuthBearerToken)
srv.K8sProxyConfig.Director = server.DirectorFromTokenExtractor(srv.K8sProxyConfig, auth.ConstantTokenExtractor(*fK8sAuthBearerToken))
case "oidc":
validateFlagIs("user-auth", *fUserAuth, "oidc")
srv.K8sProxyConfig.TokenExtractor = auth.ExtractTokenFromCookie
srv.K8sProxyConfig.Director = server.DirectorFromTokenExtractor(srv.K8sProxyConfig, auth.ExtractTokenFromCookie)
default:
flagFatalf("k8s-mode", "must be one of: service-account, bearer-token, oidc")
}
Expand Down
76 changes: 28 additions & 48 deletions server/proxy.go → pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package server
package proxy

import (
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
Expand All @@ -12,40 +12,21 @@ import (
"time"

"golang.org/x/net/websocket"

"github.com/coreos/go-oidc/oidc"
)

type ProxyConfig struct {
type Config struct {
HeaderBlacklist []string
Endpoint *url.URL
TokenExtractor oidc.RequestTokenExtractor
TLSClientConfig *tls.Config
Director func(*http.Request)
}

// The trivial token "extractor" always extracts a constant string
func ConstantTokenExtractor(s string) func(*http.Request) (string, error) {
var err error = nil
if s == "" {
err = errors.New("no token present")
}

return func(_ *http.Request) (string, error) {
return s, err
}
}

const (
proxyWriteDeadline = time.Second * 10
proxyReadDeadline = time.Second * 10
)

type proxy struct {
type Proxy struct {
reverseProxy *httputil.ReverseProxy
config *ProxyConfig
config *Config
}

func newProxy(cfg *ProxyConfig) *proxy {
func NewProxy(cfg *Config) *Proxy {
// Copy of http.DefaultTransport with TLSClientConfig added
insecureTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Expand All @@ -62,14 +43,15 @@ func newProxy(cfg *ProxyConfig) *proxy {
Transport: insecureTransport,
}

proxy := &proxy{
proxy := &Proxy{
reverseProxy: reverseProxy,
config: cfg,
}

reverseProxy.Director = func(r *http.Request) {
proxy.rewriteRequest(r)
if cfg.Director == nil {
cfg.Director = proxy.director
}
reverseProxy.Director = cfg.Director

return proxy
}
Expand All @@ -86,13 +68,10 @@ func SingleJoiningSlash(a, b string) string {
return a + b
}

func (p *proxy) rewriteRequest(r *http.Request) {
// At this writing, the only errors we can get from TokenExtractor
// are benign and correct variations on "no token found"
if token, err := p.config.TokenExtractor(r); err == nil {
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

// director is a default function to rewrite the request being
// proxied. If the user does not supply a custom director function,
// then this will be used.
func (p *Proxy) director(r *http.Request) {
for _, h := range p.config.HeaderBlacklist {
r.Header.Del(h)
}
Expand All @@ -103,9 +82,9 @@ func (p *proxy) rewriteRequest(r *http.Request) {
r.URL.Scheme = p.config.Endpoint.Scheme
}

func (p *proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
isWebsocket := false
upgrades := req.Header["Upgrade"]
upgrades := r.Header["Upgrade"]

for _, upgrade := range upgrades {
if strings.ToLower(upgrade) == "websocket" {
Expand All @@ -115,23 +94,23 @@ func (p *proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
}

if !isWebsocket {
p.reverseProxy.ServeHTTP(res, req)
p.reverseProxy.ServeHTTP(w, r)
return
}

p.rewriteRequest(req)
p.config.Director(r)

if req.URL.Scheme == "https" {
req.URL.Scheme = "wss"
if r.URL.Scheme == "https" {
r.URL.Scheme = "wss"
} else {
req.URL.Scheme = "ws"
r.URL.Scheme = "ws"
}

config := &websocket.Config{
Location: req.URL,
Location: r.URL,
Version: websocket.ProtocolVersionHybi13,
TlsConfig: p.config.TLSClientConfig,
Header: req.Header,
Header: r.Header,

// NOTE (ericchiang): K8s might not enforce this but websockets requests are
// required to supply an origin.
Expand All @@ -140,8 +119,8 @@ func (p *proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {

backend, err := websocket.DialConfig(config)
if err != nil {
plog.Errorf("Failed to dial backend: %v", err)
http.Error(res, "bad gateway", http.StatusBadGateway)
log.Printf("Failed to dial backend: %v", err)
http.Error(w, "bad gateway", http.StatusBadGateway)
return
}
defer backend.Close()
Expand All @@ -166,7 +145,7 @@ func (p *proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
// Only wait for a single error and let the defers close both connections.
<-errc

}).ServeHTTP(res, req)
}).ServeHTTP(w, r)
}

func copyFrames(dest, src *websocket.Conn) error {
Expand All @@ -184,6 +163,7 @@ func copyFrames(dest, src *websocket.Conn) error {
}

// frameCodec is a websocket.Codec which preserves frame types for copying.

//
// This differs from websocket.Message which presents a different frame type for "[]byte" and "string".
var frameCodec = websocket.Codec{Marshal: marshalFrame, Unmarshal: unmarshalFrame}
Expand Down
41 changes: 28 additions & 13 deletions server/proxy_test.go → pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server
package proxy

import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -96,10 +97,9 @@ func TestProxyDirector(t *testing.T) {
}

for i, tt := range tests {
p := newProxy(&ProxyConfig{
p := NewProxy(&Config{
Endpoint: &tt.target,
HeaderBlacklist: tt.blacklist,
TokenExtractor: ConstantTokenExtractor(""),
})
p.reverseProxy.Director(tt.in)

Expand Down Expand Up @@ -174,12 +174,11 @@ func startProxyServer(t *testing.T) (string, func(), error) {
return "", nil, err
}
targetURL.Path = ""
proxy := newProxy(&ProxyConfig{
Endpoint: targetURL,
TokenExtractor: ConstantTokenExtractor(""),
p := NewProxy(&Config{
Endpoint: targetURL,
})
proxyMux := http.NewServeMux()
proxyMux.Handle("/proxy/", http.StripPrefix("/proxy/", proxy))
proxyMux.Handle("/proxy/", http.StripPrefix("/proxy/", p))
proxyServer := httptest.NewServer(proxyMux)

return proxyServer.URL, func() {
Expand Down Expand Up @@ -261,13 +260,12 @@ func TestProxyRewriteRequestAuthorization(t *testing.T) {
}

for i, tt := range tests {
p := &proxy{
config: &ProxyConfig{
Endpoint: testurl,
TokenExtractor: ConstantTokenExtractor(tt.tok),
},
c := &Config{
Endpoint: testurl,
}
p.rewriteRequest(tt.req)
c.Director = DirectorFromToken(c, tt.tok)
p := NewProxy(c)
p.config.Director(tt.req)
got := tt.req.Header.Get("Authorization")
if tt.want != got {
t.Errorf("case %d: unexpected header: want=%q got=%q", i, tt.want, got)
Expand All @@ -283,3 +281,20 @@ func mustNewRequestWithHeader(t *testing.T, hdr http.Header) *http.Request {
req.Header = hdr
return req
}

func DirectorFromToken(config *Config, token string) func(*http.Request) {
return func(r *http.Request) {
if token != "" {
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

for _, h := range config.HeaderBlacklist {
r.Header.Del(h)
}

r.Host = config.Endpoint.Host
r.URL.Host = config.Endpoint.Host
r.URL.Path = SingleJoiningSlash(config.Endpoint.Path, r.URL.Path)
r.URL.Scheme = config.Endpoint.Scheme
}
}
Loading

0 comments on commit 69a904d

Please sign in to comment.