Skip to content

Commit

Permalink
add graphql websockets support (TT-904) (#3393)
Browse files Browse the repository at this point in the history
* refactor graphql request handling in reverse proxy

* implement handshake and start websocket

* temp: failing websocket tests

* temp: hanging websocket test

* implement test for graphql websockets

* fail when request is websocket upgrade but websockets are disabled

* refactor some code

* update vendoring from graphql-go-tools

* update vendoring after rebase

* rewrite if statement for websockets
  • Loading branch information
pvormste authored Nov 23, 2020
1 parent 84eac28 commit a53ac15
Show file tree
Hide file tree
Showing 138 changed files with 23,975 additions and 147 deletions.
1 change: 1 addition & 0 deletions ctx/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
Definition
RequestStatus
GraphQLRequest
GraphQLIsWebSocketUpgrade
)

func setContext(r *http.Request, ctx context.Context) {
Expand Down
14 changes: 14 additions & 0 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,20 @@ func ctxGetGraphQLRequest(r *http.Request) (gqlRequest *gql.Request) {
return nil
}

func ctxSetGraphQLIsWebSocketUpgrade(r *http.Request, isWebSocketUpgrade bool) {
setCtxValue(r, ctx.GraphQLIsWebSocketUpgrade, isWebSocketUpgrade)
}

func ctxGetGraphQLIsWebSocketUpgrade(r *http.Request) (isWebSocketUpgrade bool) {
if v := r.Context().Value(ctx.GraphQLIsWebSocketUpgrade); v != nil {
if isWebSocketUpgrade, ok := v.(bool); ok {
return isWebSocketUpgrade
}
}

return false
}

func ctxGetDefaultVersion(r *http.Request) bool {
return r.Context().Value(ctx.VersionDefault) != nil
}
Expand Down
24 changes: 24 additions & 0 deletions gateway/mw_graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"errors"
"net/http"

"github.com/gorilla/websocket"
"github.com/jensneuse/abstractlogger"
"github.com/jensneuse/graphql-go-tools/pkg/execution/datasource"
"github.com/sirupsen/logrus"

"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/config"
"github.com/TykTechnologies/tyk/headers"

gql "github.com/jensneuse/graphql-go-tools/pkg/graphql"
Expand All @@ -23,6 +25,10 @@ const (
TykGraphQLDataSource = "TykGraphQLDataSource"
)

const (
GraphQLWebSocketProtocol = "graphql-ws"
)

type GraphQLMiddleware struct {
BaseMiddleware
}
Expand Down Expand Up @@ -129,6 +135,19 @@ func (m *GraphQLMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reques
return errors.New("there was a problem proxying the request"), http.StatusInternalServerError
}

if websocket.IsWebSocketUpgrade(r) {
if !config.Global().HttpServerOptions.EnableWebSockets {
return errors.New("websockets are not allowed"), http.StatusUnprocessableEntity
}

if !m.websocketUpgradeUsesGraphQLProtocol(r) {
return errors.New("invalid websocket protocol for upgrading to a graphql websocket connection"), http.StatusBadRequest
}

ctxSetGraphQLIsWebSocketUpgrade(r, true)
return nil, http.StatusSwitchingProtocols
}

var gqlRequest gql.Request
err := gql.UnmarshalRequest(r.Body, &gqlRequest)
if err != nil {
Expand Down Expand Up @@ -169,6 +188,11 @@ func (m *GraphQLMiddleware) writeGraphQLError(w http.ResponseWriter, errors gql.
return errCustomBodyResponse, http.StatusBadRequest
}

func (m *GraphQLMiddleware) websocketUpgradeUsesGraphQLProtocol(r *http.Request) bool {
websocketProtocol := r.Header.Get(headers.SecWebSocketProtocol)
return websocketProtocol == GraphQLWebSocketProtocol
}

func absLoggerLevel(level logrus.Level) abstractlogger.Level {
switch level {
case logrus.ErrorLevel:
Expand Down
72 changes: 71 additions & 1 deletion gateway/mw_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ package gateway

import (
"net/http"
"strings"
"testing"

"github.com/TykTechnologies/tyk/apidef"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/config"
"github.com/TykTechnologies/tyk/headers"
"github.com/TykTechnologies/tyk/user"

Expand Down Expand Up @@ -164,6 +169,71 @@ func TestGraphQLMiddleware_EngineMode(t *testing.T) {
spec.GraphQL.ExecutionMode = apidef.GraphQLExecutionModeExecutionEngine
})

t.Run("on disabled websockets", func(t *testing.T) {
defer ResetTestConfig()
cfg := config.Global()
cfg.HttpServerOptions.EnableWebSockets = false
config.SetGlobal(cfg)

t.Run("should respond with 422 when trying to upgrade to websockets", func(t *testing.T) {
_, _ = g.Run(t, []test.TestCase{
{
Headers: map[string]string{
headers.Connection: "upgrade",
headers.Upgrade: "websocket",
headers.SecWebSocketProtocol: "graphql-ws",
headers.SecWebSocketVersion: "13",
headers.SecWebSocketKey: "123abc",
},
Code: http.StatusUnprocessableEntity,
BodyMatch: "websockets are not allowed",
},
}...)
})
})

t.Run("graphql websocket upgrade", func(t *testing.T) {
defer ResetTestConfig()
cfg := config.Global()
cfg.HttpServerOptions.EnableWebSockets = true
config.SetGlobal(cfg)

t.Run("should deny upgrade with 400 when protocol is not graphql-ws", func(t *testing.T) {
_, _ = g.Run(t, []test.TestCase{
{
Headers: map[string]string{
headers.Connection: "upgrade",
headers.Upgrade: "websocket",
headers.SecWebSocketProtocol: "invalid",
headers.SecWebSocketVersion: "13",
headers.SecWebSocketKey: "123abc",
},
Code: http.StatusBadRequest,
BodyMatch: "invalid websocket protocol for upgrading to a graphql websocket connection",
},
}...)
})

t.Run("should upgrade to websocket connection with correct protocol", func(t *testing.T) {
baseURL := strings.Replace(g.URL, "http://", "ws://", -1)
wsConn, _, err := websocket.DefaultDialer.Dial(baseURL, map[string][]string{
headers.SecWebSocketProtocol: {GraphQLWebSocketProtocol},
})
require.NoError(t, err)
defer wsConn.Close()

// Send a connection init message to gateway
err = wsConn.WriteMessage(websocket.BinaryMessage, []byte(`{"type":"connection_init","payload":{}}`))
require.NoError(t, err)

_, msg, err := wsConn.ReadMessage()

// Gateway should acknowledge the connection
assert.Equal(t, `{"id":"","type":"connection_ack","payload":null}`, string(msg))
assert.NoError(t, err)
})
})

t.Run("graphql api requests", func(t *testing.T) {
countries1 := gql.Request{
Query: "query Query { countries { name } }",
Expand Down
Loading

0 comments on commit a53ac15

Please sign in to comment.