Skip to content

Commit

Permalink
[TT-7818] Fix strip listen path case for oas router (#4715)
Browse files Browse the repository at this point in the history
  • Loading branch information
furkansenharputlu authored Jan 31, 2023
1 parent b3c80f3 commit 75beff2
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 58 deletions.
1 change: 1 addition & 0 deletions ci/tests/plugin-compiler/testplugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
var logger = log.Get()

// AddFooBarHeader adds custom "Foo: Bar" header to the request
//
//nolint:deadcode
func AddFooBarHeader(rw http.ResponseWriter, r *http.Request) {
r.Header.Add("Foo", "Bar")
Expand Down
1 change: 1 addition & 0 deletions ctx/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
RequestStatus
GraphQLRequest
GraphQLIsWebSocketUpgrade
OASOperation

// CacheOptions holds cache options required for cache writer middleware.
CacheOptions
Expand Down
11 changes: 11 additions & 0 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3374,6 +3374,17 @@ func ctxGetRequestStatus(r *http.Request) (stat RequestStatus) {
return
}

func ctxSetOperation(r *http.Request, op *Operation) {
setCtxValue(r, ctx.OASOperation, op)
}

func ctxGetOperation(r *http.Request) (op *Operation) {
if v := r.Context().Value(ctx.OASOperation); v != nil {
op = v.(*Operation)
}
return
}

var createOauthClientSecret = func() string {
secret := uuid.NewV4()
return base64.StdEncoding.EncodeToString([]byte(secret.String()))
Expand Down
26 changes: 9 additions & 17 deletions gateway/api_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,9 @@ type APISpec struct {
Schema *graphql.Schema
} `json:"-"`

hasMock bool
OASRouter routers.Router
HasMock bool
HasValidateRequest bool
OASRouter routers.Router
}

// GetSessionLifetimeRespectsKeyExpiration returns a boolean to tell whether session lifetime should respect to key expiration or not.
Expand Down Expand Up @@ -422,14 +423,9 @@ func (a APIDefinitionLoader) MakeSpec(def *nestedApiDefinition, logger *logrus.E
spec.OAS = *def.OAS
}

serverURL := spec.Proxy.ListenPath
if spec.Proxy.StripListenPath {
serverURL = "/"
}

oasSpec := spec.OAS.T
oasSpec.Servers = openapi3.Servers{
{URL: serverURL},
{URL: spec.Proxy.ListenPath},
}

spec.OASRouter, err = gorillamux.NewRouter(&oasSpec)
Expand Down Expand Up @@ -1723,24 +1719,20 @@ func (a *APISpec) SanitizeProxyPaths(r *http.Request) {
log.Debug("Upstream path is: ", r.URL.Path)
}

func (a *APISpec) HasMock() bool {
return a.hasMock
}

func (a *APISpec) setHasMock() {
if !a.IsOAS {
a.hasMock = false
a.HasMock = false
return
}

middleware := a.OAS.GetTykExtension().Middleware
if middleware == nil {
a.hasMock = false
a.HasMock = false
return
}

if len(middleware.Operations) == 0 {
a.hasMock = false
a.HasMock = false
return
}

Expand All @@ -1750,12 +1742,12 @@ func (a *APISpec) setHasMock() {
}

if operation.MockResponse.Enabled {
a.hasMock = true
a.HasMock = true
return
}
}

a.hasMock = false
a.HasMock = false
}

type RoundRobin struct {
Expand Down
17 changes: 8 additions & 9 deletions gateway/mock_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ const acceptCode = "X-Tyk-Accept-Example-Code"
const acceptExampleName = "X-Tyk-Accept-Example-Name"

func (p *ReverseProxy) mockResponse(r *http.Request) (*http.Response, error) {
route, _, err := p.TykAPISpec.OASRouter.FindRoute(r)
if route == nil || err != nil {
operation := ctxGetOperation(r)
if operation == nil {
return nil, nil
}

operation := p.TykAPISpec.OAS.GetTykExtension().Middleware.Operations[route.Operation.OperationID]
if operation == nil || !operation.MockResponse.Enabled {
mockResponse := operation.MockResponse
if mockResponse == nil || !mockResponse.Enabled {
return nil, nil
}

Expand All @@ -36,18 +36,17 @@ func (p *ReverseProxy) mockResponse(r *http.Request) (*http.Response, error) {
var contentType string
var body []byte
var headers map[string]string
var err error

tykExampleRespOp := p.TykAPISpec.OAS.GetTykExtension().Middleware.Operations[route.Operation.OperationID].MockResponse

if tykExampleRespOp.FromOASExamples != nil && tykExampleRespOp.FromOASExamples.Enabled {
code, contentType, body, headers, err = mockFromOAS(r, route.Operation, tykExampleRespOp.FromOASExamples)
if mockResponse.FromOASExamples != nil && mockResponse.FromOASExamples.Enabled {
code, contentType, body, headers, err = mockFromOAS(r, operation.route.Operation, mockResponse.FromOASExamples)
res.StatusCode = code
if err != nil {
err = fmt.Errorf("mock: %s", err)
return res, err
}
} else {
code, body, headers = mockFromConfig(tykExampleRespOp)
code, body, headers = mockFromConfig(mockResponse)
}

for key, val := range headers {
Expand Down
16 changes: 6 additions & 10 deletions gateway/mw_oas_validate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func (k *ValidateRequest) EnabledForSpec() bool {
}

if operation.ValidateRequest.Enabled {
k.Spec.HasValidateRequest = true
return true
}
}
Expand All @@ -44,13 +45,8 @@ func (k *ValidateRequest) EnabledForSpec() bool {

// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
func (k *ValidateRequest) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
route, pathParams, err := k.Spec.OASRouter.FindRoute(r)
if err != nil {
return nil, http.StatusOK
}

operation, ok := k.Spec.OAS.GetTykExtension().Middleware.Operations[route.Operation.OperationID]
if !ok {
operation := ctxGetOperation(r)
if operation == nil {
return nil, http.StatusOK
}

Expand All @@ -67,11 +63,11 @@ func (k *ValidateRequest) ProcessRequest(w http.ResponseWriter, r *http.Request,
// Validate request
requestValidationInput := &openapi3filter.RequestValidationInput{
Request: r,
PathParams: pathParams,
Route: route,
PathParams: operation.pathParams,
Route: operation.route,
}

err = openapi3filter.ValidateRequestBody(r.Context(), requestValidationInput, route.Operation.RequestBody.Value)
err := openapi3filter.ValidateRequestBody(r.Context(), requestValidationInput, operation.route.Operation.RequestBody.Value)
if err != nil {
return fmt.Errorf("request validation error: %v", err), errResponseCode
}
Expand Down
55 changes: 34 additions & 21 deletions gateway/mw_oas_validate_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestValidateRequest(t *testing.T) {
err = oasAPI.Validate(context.Background())
assert.NoError(t, err)

ts.Gw.BuildAndLoadAPI(
apis := ts.Gw.BuildAndLoadAPI(
func(spec *APISpec) {
spec.VersionData = def.VersionData
spec.Name = "without regexp"
Expand All @@ -134,25 +134,39 @@ func TestValidateRequest(t *testing.T) {
headers := map[string]string{"Content-Type": "application/json"}

t.Run("default error response code", func(t *testing.T) {
_, _ = ts.Run(t, []test.TestCase{
{Data: `{"name": 123}`, Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product/push"},
{Data: `{"name": 123}`, Code: http.StatusUnprocessableEntity, Method: http.MethodPost, Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product"}`, Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": 123}}`, Code: http.StatusUnprocessableEntity, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan"}}`, Code: http.StatusOK, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": 123}}}`, Code: http.StatusUnprocessableEntity, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": "Türkiye"}}}`, Code: http.StatusOK, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": 123}}}`, Domain: "custom-domain",
Code: http.StatusUnprocessableEntity, Method: http.MethodPost, Headers: headers, Path: "/product-regexp1/something/post", Client: test.NewClientLocal()},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": "Türkiye"}}}`, Domain: "custom-domain",
Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product-regexp1/something/post", Client: test.NewClientLocal()},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": "Türkiye"}}}`, Domain: "custom-domain",
Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product-regexp2/something/suffix/post", Client: test.NewClientLocal()},
}...)
check := func(t *testing.T) {
_, _ = ts.Run(t, []test.TestCase{
{Data: `{"name": 123}`, Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product/push"},
{Data: `{"name": 123}`, Code: http.StatusUnprocessableEntity, Method: http.MethodPost, Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product"}`, Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": 123}}`, Code: http.StatusUnprocessableEntity, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan"}}`, Code: http.StatusOK, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": 123}}}`, Code: http.StatusUnprocessableEntity, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": "Türkiye"}}}`, Code: http.StatusOK, Method: http.MethodPost,
Headers: headers, Path: "/product/post"},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": 123}}}`, Domain: "custom-domain",
Code: http.StatusUnprocessableEntity, Method: http.MethodPost, Headers: headers, Path: "/product-regexp1/something/post", Client: test.NewClientLocal()},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": "Türkiye"}}}`, Domain: "custom-domain",
Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product-regexp1/something/post", Client: test.NewClientLocal()},
{Data: `{"name": "my-product", "owner": {"name": "Furkan", "country": {"name": "Türkiye"}}}`, Domain: "custom-domain",
Code: http.StatusOK, Method: http.MethodPost, Headers: headers, Path: "/product-regexp2/something/suffix/post", Client: test.NewClientLocal()},
}...)
}

t.Run("stripListenPath=false", func(t *testing.T) {
check(t)
})

t.Run("stripListenPath=true", func(t *testing.T) {
apis[0].Proxy.StripListenPath = true
apis[1].Proxy.StripListenPath = true
apis[2].Proxy.StripListenPath = true
ts.Gw.LoadAPI(apis...)
check(t)
})
})

t.Run("custom error response code", func(t *testing.T) {
Expand All @@ -170,5 +184,4 @@ func TestValidateRequest(t *testing.T) {
{Data: `{"name": 123}`, Code: http.StatusTeapot, Method: http.MethodPost, Headers: headers, Path: "/product/post"},
}...)
})

}
28 changes: 28 additions & 0 deletions gateway/mw_version_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"net/http"
"time"

"github.com/getkin/kin-openapi/routers"

"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/apidef/oas"
"github.com/TykTechnologies/tyk/request"
)

Expand Down Expand Up @@ -35,6 +38,26 @@ func (v *VersionCheck) DoMockReply(w http.ResponseWriter, meta apidef.MockRespon
w.Write(responseMessage)
}

type Operation struct {
*oas.Operation
route *routers.Route
pathParams map[string]string
}

func findRouteAndOperation(spec *APISpec, r *http.Request) {
route, pathParams, err := spec.OASRouter.FindRoute(r)
if err != nil {
return
}

operation, ok := spec.OAS.GetTykExtension().Middleware.Operations[route.Operation.OperationID]
if !ok {
return
}

ctxSetOperation(r, &Operation{Operation: operation, route: route, pathParams: pathParams})
}

// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
func (v *VersionCheck) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
targetVersion := v.Spec.getVersionFromRequest(r)
Expand All @@ -59,6 +82,11 @@ func (v *VersionCheck) ProcessRequest(w http.ResponseWriter, r *http.Request, _
return nil, mwStatusRespond
}

// For OAS route matching
if v.Spec.HasMock || v.Spec.HasValidateRequest {
findRouteAndOperation(v.Spec, r)
}

// Check versioning, blacklist, whitelist and ignored status
requestValid, stat := v.Spec.RequestValid(r)
if !requestValid {
Expand Down
2 changes: 1 addition & 1 deletion gateway/reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ func (p *ReverseProxy) handleOutboundRequest(roundTripper *TykRoundTripper, outr
latency = time.Since(begin)
}()

if p.TykAPISpec.HasMock() {
if p.TykAPISpec.HasMock {
if res, err = p.mockResponse(outreq); res != nil {
return
}
Expand Down
1 change: 1 addition & 0 deletions smoke-tests/plugin-aliasing/foobar-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
var logger = log.Get()

// AddFooBarHeader adds custom "Foo: Bar" header to the request
//
//nolint:deadcode
func AddFooBarHeader(rw http.ResponseWriter, r *http.Request) {
r.Header.Add("Foo", "Bar")
Expand Down
1 change: 1 addition & 0 deletions smoke-tests/plugin-aliasing/helloworld-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
var logger = log.Get()

// AddHelloWorldHeader adds custom "Foo: Bar" header to the request
//
//nolint:deadcode
func AddHelloWorldHeader(rw http.ResponseWriter, r *http.Request) {
r.Header.Add("Hello", "World")
Expand Down

0 comments on commit 75beff2

Please sign in to comment.