Skip to content

Commit

Permalink
Disallow creating a key with the same certificate (TykTechnologies#2637)
Browse files Browse the repository at this point in the history
  • Loading branch information
buger authored Oct 25, 2019
1 parent 4355f89 commit 31432a5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 46 deletions.
5 changes: 5 additions & 0 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,11 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {

if newSession.Certificate != "" {
newKey = generateToken(newSession.OrgID, newSession.Certificate)
_, ok := FallbackKeySesionManager.SessionDetail(newKey, false)
if ok {
doJSONWrite(w, http.StatusInternalServerError, apiError("Failed to create key - Key with given certificate already found:"+newKey))
return
}
}

newSession.LastUpdated = strconv.Itoa(int(time.Now().Unix()))
Expand Down
2 changes: 1 addition & 1 deletion gateway/cert_go1.10_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestProxyTransport(t *testing.T) {
spec.Proxy.Transport.ProxyURL = proxy.URL
})

client := getTLSClient(nil, nil)
client := GetTLSClient(nil, nil)
client.Transport = &http.Transport{
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
}
Expand Down
65 changes: 29 additions & 36 deletions gateway/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@ import (
"github.com/TykTechnologies/tyk/user"
)

func getTLSClient(cert *tls.Certificate, caCert []byte) *http.Client {
// Setup HTTPS client
tlsConfig := &tls.Config{}

if cert != nil {
tlsConfig.Certificates = []tls.Certificate{*cert}
}

if len(caCert) > 0 {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = caCertPool
tlsConfig.BuildNameToCertificate()
} else {
tlsConfig.InsecureSkipVerify = true
}

transport := &http.Transport{TLSClientConfig: tlsConfig}

return &http.Client{Transport: transport}
}

func genCertificate(template *x509.Certificate) ([]byte, []byte, []byte, tls.Certificate) {
priv, _ := rsa.GenerateKey(rand.Reader, 512)

Expand Down Expand Up @@ -91,7 +69,7 @@ func TestGatewayTLS(t *testing.T) {
dir, _ := ioutil.TempDir("", "certs")
defer os.RemoveAll(dir)

client := getTLSClient(nil, nil)
client := GetTLSClient(nil, nil)

t.Run("Without certificates", func(t *testing.T) {
globalConf := config.Global()
Expand Down Expand Up @@ -204,9 +182,9 @@ func TestGatewayControlAPIMutualTLS(t *testing.T) {
}()

clientCertPem, _, _, clientCert := genCertificate(&x509.Certificate{})
clientWithCert := getTLSClient(&clientCert, serverCertPem)
clientWithCert := GetTLSClient(&clientCert, serverCertPem)

clientWithoutCert := getTLSClient(nil, nil)
clientWithoutCert := GetTLSClient(nil, nil)

t.Run("Separate domain", func(t *testing.T) {
certID, _ := CertificateManager.Add(combinedPEM, "")
Expand Down Expand Up @@ -276,7 +254,7 @@ func TestAPIMutualTLS(t *testing.T) {

t.Run("SNI and domain per API", func(t *testing.T) {
t.Run("API without mutual TLS", func(t *testing.T) {
client := getTLSClient(&clientCert, serverCertPem)
client := GetTLSClient(&clientCert, serverCertPem)

BuildAndLoadAPI(func(spec *APISpec) {
spec.Domain = "localhost"
Expand All @@ -287,7 +265,7 @@ func TestAPIMutualTLS(t *testing.T) {
})

t.Run("MutualTLSCertificate not set", func(t *testing.T) {
client := getTLSClient(nil, nil)
client := GetTLSClient(nil, nil)

BuildAndLoadAPI(func(spec *APISpec) {
spec.Domain = "localhost"
Expand All @@ -303,7 +281,7 @@ func TestAPIMutualTLS(t *testing.T) {
})

t.Run("Client certificate match", func(t *testing.T) {
client := getTLSClient(&clientCert, serverCertPem)
client := GetTLSClient(&clientCert, serverCertPem)
clientCertID, _ := CertificateManager.Add(clientCertPem, "")

BuildAndLoadAPI(func(spec *APISpec) {
Expand All @@ -320,14 +298,14 @@ func TestAPIMutualTLS(t *testing.T) {
CertificateManager.Delete(clientCertID)
CertificateManager.FlushCache()

client = getTLSClient(&clientCert, serverCertPem)
client = GetTLSClient(&clientCert, serverCertPem)
ts.Run(t, test.TestCase{
Client: client, Domain: "localhost", ErrorMatch: badcertErr,
})
})

t.Run("Client certificate differ", func(t *testing.T) {
client := getTLSClient(&clientCert, serverCertPem)
client := GetTLSClient(&clientCert, serverCertPem)

clientCertPem2, _, _, _ := genCertificate(&x509.Certificate{})
clientCertID2, _ := CertificateManager.Add(clientCertPem2, "")
Expand Down Expand Up @@ -364,7 +342,7 @@ func TestAPIMutualTLS(t *testing.T) {
}

t.Run("Without certificate", func(t *testing.T) {
clientWithoutCert := getTLSClient(nil, nil)
clientWithoutCert := GetTLSClient(nil, nil)

loadAPIS()

Expand All @@ -385,7 +363,7 @@ func TestAPIMutualTLS(t *testing.T) {
})

t.Run("Client certificate not match", func(t *testing.T) {
client := getTLSClient(&clientCert, serverCertPem)
client := GetTLSClient(&clientCert, serverCertPem)

loadAPIS()

Expand All @@ -401,7 +379,7 @@ func TestAPIMutualTLS(t *testing.T) {

t.Run("Client certificate match", func(t *testing.T) {
loadAPIS(clientCertID)
client := getTLSClient(&clientCert, serverCertPem)
client := GetTLSClient(&clientCert, serverCertPem)

ts.Run(t, test.TestCase{
Path: "/with_mutual",
Expand Down Expand Up @@ -431,7 +409,7 @@ func TestUpstreamMutualTLS(t *testing.T) {
defer upstream.Close()

t.Run("Without API", func(t *testing.T) {
client := getTLSClient(&clientCert, nil)
client := GetTLSClient(&clientCert, nil)

if _, err := client.Get(upstream.URL); err == nil {
t.Error("Should reject without certificate")
Expand Down Expand Up @@ -496,20 +474,35 @@ func TestKeyWithCertificateTLS(t *testing.T) {
spec.OrgID = "default"
})

client := getTLSClient(&clientCert, nil)
client := GetTLSClient(&clientCert, nil)

t.Run("Cert unknown", func(t *testing.T) {
ts.Run(t, test.TestCase{Code: 403, Client: client})
})

t.Run("Cert known", func(t *testing.T) {
CreateSession(func(s *user.SessionState) {
_, key := ts.CreateSession(func(s *user.SessionState) {
s.Certificate = clientCertID
s.AccessRights = map[string]user.AccessDefinition{"test": {
APIID: "test", Versions: []string{"v1"},
}}
})

if key == "" {
t.Fatal("Should create key based on certificate")
}

_, key = ts.CreateSession(func(s *user.SessionState) {
s.Certificate = clientCertID
s.AccessRights = map[string]user.AccessDefinition{"test": {
APIID: "test", Versions: []string{"v1"},
}}
})

if key != "" {
t.Fatal("Should not allow create key based on the same certificate")
}

ts.Run(t, test.TestCase{Path: "/", Code: 200, Client: client})
})
}
Expand Down
6 changes: 3 additions & 3 deletions gateway/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestHTTP2_TLS(t *testing.T) {
})

// HTTP/2 client
http2Client := getTLSClient(&clientCert, serverCertPem)
http2Client := GetTLSClient(&clientCert, serverCertPem)
http2.ConfigureTransport(http2Client.Transport.(*http.Transport))

ts.Run(t, test.TestCase{Client: http2Client, Path: "", Code: 200, Proto: "HTTP/2.0", BodyMatch: "Hello, I am an HTTP/2 Server"})
Expand Down Expand Up @@ -215,7 +215,7 @@ func TestGRPC_BasicAuthentication(t *testing.T) {

address := strings.TrimPrefix(ts.URL, "https://")
name := "Furkan"
client := getTLSClient(nil, nil)
client := GetTLSClient(nil, nil)

// To create key
ts.Run(t, []test.TestCase{
Expand Down Expand Up @@ -270,7 +270,7 @@ func TestGRPC_TokenBasedAuthentication(t *testing.T) {

address := strings.TrimPrefix(ts.URL, "https://")
name := "Furkan"
client := getTLSClient(nil, nil)
client := GetTLSClient(nil, nil)

// To create key
resp, _ := ts.Run(t, []test.TestCase{
Expand Down
35 changes: 29 additions & 6 deletions gateway/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ import (
"bytes"
"compress/gzip"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/json"
"fmt"
"net/http/httptest"

"golang.org/x/net/context"

"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
Expand All @@ -26,15 +24,15 @@ import (
"testing"
"time"

"github.com/TykTechnologies/tyk/cli"

jwt "github.com/dgrijalva/jwt-go"
"github.com/garyburd/redigo/redis"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
uuid "github.com/satori/go.uuid"
"golang.org/x/net/context"

"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/cli"
"github.com/TykTechnologies/tyk/config"
"github.com/TykTechnologies/tyk/storage"
"github.com/TykTechnologies/tyk/test"
Expand Down Expand Up @@ -683,16 +681,41 @@ func (s *Test) RunExt(t testing.TB, testCases ...test.TestCase) {
}
}

func GetTLSClient(cert *tls.Certificate, caCert []byte) *http.Client {
// Setup HTTPS client
tlsConfig := &tls.Config{}

if cert != nil {
tlsConfig.Certificates = []tls.Certificate{*cert}
}

if len(caCert) > 0 {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = caCertPool
tlsConfig.BuildNameToCertificate()
} else {
tlsConfig.InsecureSkipVerify = true
}

transport := &http.Transport{TLSClientConfig: tlsConfig}

return &http.Client{Transport: transport}
}

func (s *Test) CreateSession(sGen ...func(s *user.SessionState)) (*user.SessionState, string) {
session := CreateStandardSession()
if len(sGen) > 0 {
sGen[0](session)
}

client := GetTLSClient(nil, nil)

resp, err := s.Do(test.TestCase{
Method: http.MethodPost,
Path: "/tyk/keys/create",
Data: session,
Client: client,
AdminAuth: true,
})

Expand Down

0 comments on commit 31432a5

Please sign in to comment.