From c0c00e6241a5950075e5c5f12b2e66a42cf0348b Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 15 Jul 2021 23:34:01 +0300 Subject: [PATCH 01/16] V5.0.0-alpha --- .github/workflows/echo.yml | 36 +- .travis.yml | 21 - LICENSE | 2 +- Makefile | 6 +- README.md | 43 +- bind.go | 104 +- bind_test.go | 240 +++-- binder.go | 4 +- binder_external_test.go | 2 +- binder_go1.15_test.go | 265 ----- binder_test.go | 238 ++++- context.go | 640 ++++++----- context_fs.go | 33 - context_fs_go1.16.go | 52 - context_fs_go1.16_test.go | 135 --- context_test.go | 725 +++++++------ echo.go | 1053 ++++++++---------- echo_fs.go | 62 -- echo_fs_go1.16.go | 145 --- echo_fs_go1.16_test.go | 265 ----- echo_test.go | 1191 +++++++-------------- go.mod | 15 +- go.sum | 34 +- group.go | 180 ++-- group_fs.go | 9 - group_fs_go1.16.go | 33 - group_fs_go1.16_test.go | 106 -- group_test.go | 613 ++++++++++- httperror.go | 74 ++ httperror_test.go | 52 + json.go | 11 +- json_test.go | 10 +- log.go | 175 ++- log_test.go | 87 ++ middleware/DEVELOPMENT.md | 11 + middleware/basic_auth.go | 97 +- middleware/basic_auth_test.go | 175 ++- middleware/body_dump.go | 66 +- middleware/body_dump_test.go | 88 +- middleware/body_limit.go | 97 +- middleware/body_limit_test.go | 118 ++- middleware/compress.go | 66 +- middleware/compress_test.go | 154 +-- middleware/cors.go | 123 +-- middleware/cors_test.go | 34 +- middleware/csrf.go | 169 ++- middleware/csrf_test.go | 74 +- middleware/decompress.go | 46 +- middleware/decompress_test.go | 96 +- middleware/extractor.go | 43 +- middleware/extractor_test.go | 48 +- middleware/jwt.go | 325 ++---- middleware/jwt_external_test.go | 76 ++ middleware/jwt_test.go | 665 +++++------- middleware/key_auth.go | 198 ++-- middleware/key_auth_test.go | 193 ++-- middleware/logger.go | 187 ++-- middleware/logger_test.go | 4 +- middleware/method_override.go | 46 +- middleware/method_override_test.go | 68 +- middleware/middleware.go | 21 +- middleware/proxy.go | 167 +-- middleware/proxy_test.go | 21 +- middleware/rate_limiter.go | 106 +- middleware/rate_limiter_test.go | 135 ++- middleware/recover.go | 111 +- middleware/recover_test.go | 178 ++-- middleware/redirect.go | 140 ++- middleware/redirect_test.go | 2 +- middleware/request_id.go | 58 +- middleware/request_id_test.go | 101 +- middleware/request_logger.go | 15 +- middleware/request_logger_test.go | 4 +- middleware/rewrite.go | 74 +- middleware/rewrite_test.go | 77 +- middleware/secure.go | 162 +-- middleware/secure_test.go | 84 +- middleware/slash.go | 82 +- middleware/slash_test.go | 6 +- middleware/static.go | 223 ++-- middleware/static_1_16_test.go | 106 -- middleware/static_test.go | 263 ++++- middleware/timeout.go | 182 ---- middleware/timeout_test.go | 443 -------- middleware/util.go | 39 + middleware/util_test.go | 40 +- response.go | 31 +- route.go | 182 ++++ route_test.go | 423 ++++++++ router.go | 845 +++++++++++---- router_test.go | 1589 +++++++++++++++++++++------- server.go | 213 ++++ server_test.go | 815 ++++++++++++++ 93 files changed, 9654 insertions(+), 7212 deletions(-) delete mode 100644 .travis.yml delete mode 100644 binder_go1.15_test.go delete mode 100644 context_fs.go delete mode 100644 context_fs_go1.16.go delete mode 100644 context_fs_go1.16_test.go delete mode 100644 echo_fs.go delete mode 100644 echo_fs_go1.16.go delete mode 100644 echo_fs_go1.16_test.go delete mode 100644 group_fs.go delete mode 100644 group_fs_go1.16.go delete mode 100644 group_fs_go1.16_test.go create mode 100644 httperror.go create mode 100644 httperror_test.go create mode 100644 log_test.go create mode 100644 middleware/DEVELOPMENT.md create mode 100644 middleware/jwt_external_test.go delete mode 100644 middleware/static_1_16_test.go delete mode 100644 middleware/timeout.go delete mode 100644 middleware/timeout_test.go create mode 100644 route.go create mode 100644 route_test.go create mode 100644 server.go create mode 100644 server_test.go diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 266406664..d1967212e 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -19,6 +19,7 @@ on: - '_fixture/**' - '.github/**' - 'codecov.yml' + workflow_dispatch: # to be able to run workflow manually jobs: test: @@ -27,28 +28,18 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.14, 1.15, 1.16, 1.17] + # except v5 starts from 1.17 until there is last four major releases after that + go: [1.17] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 + uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - - name: Set GOPATH and PATH - run: | - echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV - echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH - shell: bash - - - name: Set build variables - run: | - echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV - echo "GO111MODULE=on" >> $GITHUB_ENV - - name: Checkout Code - uses: actions/checkout@v1 + uses: actions/checkout@v2 with: ref: ${{ github.ref }} @@ -62,10 +53,10 @@ jobs: - name: Upload coverage to Codecov if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2 with: - token: fail_ci_if_error: false + benchmark: needs: test strategy: @@ -76,21 +67,10 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 + uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - - name: Set GOPATH and PATH - run: | - echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV - echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH - shell: bash - - - name: Set build variables - run: | - echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV - echo "GO111MODULE=on" >> $GITHUB_ENV - - name: Checkout Code (Previous) uses: actions/checkout@v2 with: diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 67d45ad78..000000000 --- a/.travis.yml +++ /dev/null @@ -1,21 +0,0 @@ -arch: - - amd64 - - ppc64le - -language: go -go: - - 1.14.x - - 1.15.x - - tip -env: - - GO111MODULE=on -install: - - go get -v golang.org/x/lint/golint -script: - - golint -set_exit_status ./... - - go test -race -coverprofile=coverage.txt -covermode=atomic ./... -after_success: - - bash <(curl -s https://codecov.io/bash) -matrix: - allow_failures: - - go: tip diff --git a/LICENSE b/LICENSE index c46d0105f..2f18411bd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2021 LabStack +Copyright (c) 2022 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index 48061f7e2..10f9c8f59 100644 --- a/Makefile +++ b/Makefile @@ -24,11 +24,11 @@ race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks - @go test -run="-" -bench=".*" ${PKG_LIST} + @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.15" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 +goversion ?= "1.16" +test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/README.md b/README.md index 8b2321f05..b9cb69e33 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,14 @@ ## Supported Go versions +Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that. + As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: - 1.9.7+ - 1.10.3+ -- 1.14+ +- 1.16+ Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -39,24 +41,13 @@ For older versions, please use the latest v3 tag. - Automatic TLS via Let’s Encrypt - HTTP/2 support -## Benchmarks - -Date: 2020/11/11
-Source: https://github.com/vishr/web-framework-benchmark
-Lower is better! - - - - -The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz - ## [Guide](https://echo.labstack.com/guide) ### Installation ```sh // go get github.com/labstack/echo/{version} -go get github.com/labstack/echo/v4 +go get github.com/labstack/echo/v5 ``` ### Example @@ -65,8 +56,8 @@ go get github.com/labstack/echo/v4 package main import ( - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" "net/http" ) @@ -82,7 +73,9 @@ func main() { e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":1323"); err != http.ErrServerClosed { + log.Fatal(err) + } } // Handler @@ -93,15 +86,15 @@ func hello(c echo.Context) error { # Third-party middlewares -| Repository | Description | -|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | -| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | -| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | -| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | -| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | -| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | -| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. +| Repository | Description | +|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | Please send a PR to add your own library here. diff --git a/bind.go b/bind.go index c841ca010..0a7eb8b42 100644 --- a/bind.go +++ b/bind.go @@ -11,42 +11,38 @@ import ( "strings" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(c Context, i interface{}) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} // BindPathParams binds path params to bindable object -func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { - names := c.ParamNames() - values := c.ParamValues() +func BindPathParams(c Context, i interface{}) error { params := map[string][]string{} - for i, name := range names { - params[name] = []string{values[i]} + for _, param := range c.PathParams() { + params[param.Name] = []string{param.Value} } - if err := b.bindData(i, params, "param"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(i, params, "param"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // BindQueryParams binds query params to bindable object -func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindQueryParams(c Context, i interface{}) error { + if err := bindData(i, c.QueryParams(), "query"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } @@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm -func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { +func BindBody(c Context, i interface{}) (err error) { req := c.Request() if req.ContentLength == 0 { return @@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { case *HTTPError: return err default: - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } } case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())) } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - params, err := c.FormParams() + values, err := c.FormValues() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } - if err = b.bindData(i, params, "form"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(i, values, "form"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } default: return ErrUnsupportedMediaType @@ -97,34 +93,34 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } // BindHeaders binds HTTP headers to a bindable object -func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindHeaders(c Context, i interface{}) error { + if err := bindData(i, c.Request().Header, "header"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous -// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - if err := b.BindPathParams(c, i); err != nil { +// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. +func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) { + if err := BindPathParams(c, i); err != nil { return err } // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) - method := c.Request().Method + method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { - if err = b.BindQueryParams(c, i); err != nil { + if err = BindQueryParams(c, i); err != nil { return err } } - return b.BindBody(c, i) + return BindBody(c, i) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { +func bindData(destination interface{}, data map[string][]string, tag string) error { if destination == nil || len(data) == 0 { return nil } @@ -167,10 +163,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri } if inputFieldName == "" { - // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). - // structs that implement BindUnmarshaler are binded only when they have explicit tag + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). + // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := bindData(structField.Addr().Interface(), data, tag); err != nil { return err } } @@ -180,10 +176,8 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To - // fix this we must check all of the map values in a - // case-insensitive search. + // Go json.Unmarshal supports case-insensitive binding. However, the url params are bound case-sensitive which + // is inconsistent. To fix this we must check all the map values in a case-insensitive search. for k, v := range data { if strings.EqualFold(k, inputFieldName) { inputValue = v @@ -297,7 +291,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { func setIntField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } intVal, err := strconv.ParseInt(value, 10, bitSize) if err == nil { @@ -308,7 +302,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error { func setUintField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } uintVal, err := strconv.ParseUint(value, 10, bitSize) if err == nil { @@ -319,7 +313,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error { func setBoolField(value string, field reflect.Value) error { if value == "" { - value = "false" + return nil } boolVal, err := strconv.ParseBool(value) if err == nil { @@ -330,7 +324,7 @@ func setBoolField(value string, field reflect.Value) error { func setFloatField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0.0" + return nil } floatVal, err := strconv.ParseFloat(value, bitSize) if err == nil { diff --git a/bind_test.go b/bind_test.go index 4ed8dbb50..8f711c4f8 100644 --- a/bind_test.go +++ b/bind_test.go @@ -277,7 +277,7 @@ func TestBindHeaderParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) if assert.NoError(t, err) { assert.Equal(t, 2, u.ID) assert.Equal(t, "Jon Doe", u.Name) @@ -291,7 +291,7 @@ func TestBindHeaderParamBadType(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) assert.Error(t, err) httpErr, ok := err.(*HTTPError) @@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) { } } +func TestBind_CombineQueryWithHeaderParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil) + req.Header.Set("language", "de") + req.Header.Set("length", "99") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPathParams(PathParams{{ + Name: "id", + Value: "999", + }}) + + type SearchOpts struct { + ID int `param:"id"` + Length int `query:"length"` + Page int `query:"page"` + Search string `query:"search"` + Language string `query:"language" header:"language"` + } + + opts := SearchOpts{ + Length: 100, + Page: 0, + Search: "default value", + Language: "en", + } + err := c.Bind(&opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // bind from query + assert.Equal(t, 10, opts.Page) // bind from query + assert.Equal(t, 999, opts.ID) // bind from path param + assert.Equal(t, "et", opts.Language) // bind from query + assert.Equal(t, "default value", opts.Search) // default value stays + + // make sure another bind will not mess already set values unless there are new values + err = BindHeaders(c, &opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists + assert.Equal(t, 10, opts.Page) + assert.Equal(t, 999, opts.ID) + assert.Equal(t, "de", opts.Language) // header overwrites now this value + assert.Equal(t, "default value", opts.Search) +} + func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) @@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalText(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) { func TestBindbindData(t *testing.T) { a := assert.New(t) ts := new(bindTestStruct) - b := new(DefaultBinder) - err := b.bindData(ts, values, "form") + err := bindData(ts, values, "form") a.NoError(err) a.Equal(0, ts.I) @@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.SetPath("/users/:id/:name") - c.SetParamNames("id", "name") - c.SetParamValues("1", "Jon Snow") + cc := c.(RoutableContext) + cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"}) + cc.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + {Name: "name", Value: "Jon Snow"}, + }) u := new(user) err := c.Bind(u) @@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.SetPath("/users/:id") - c2.SetParamNames("id") - c2.SetParamValues("1") + cc2 := c2.(RoutableContext) + cc2.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc2.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c2.Bind(u) @@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - req2 := httptest.NewRequest(POST, "/", body) + req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.SetPath("/users/:id") - c3.SetParamNames("id") - c3.SetParamValues("1") + cc3 := c3.(RoutableContext) + cc3.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc3.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c3.Bind(u) @@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) { assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - assert := assert.New(t) +func TestSetIntField(t *testing.T) { + ts := new(bindTestStruct) + ts.I = 100 + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 100, ts.I) + + // second set with value sets the value + err = setIntField("5", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) +} +func TestSetUintField(t *testing.T) { ts := new(bindTestStruct) + ts.UI = 100 + val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(5, ts.I) - } - if assert.NoError(setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(0, ts.I) - } - // Uint - if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(uint(10), ts.UI) - } - if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(uint(0), ts.UI) - } + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(100), ts.UI) + + // second set with value sets the value + err = setUintField("5", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) +} - // Float - if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(float32(15.5), ts.F32) - } - if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(float32(0.0), ts.F32) - } +func TestSetFloatField(t *testing.T) { + ts := new(bindTestStruct) + ts.F32 = 100 - // Bool - if assert.NoError(setBoolField("true", val.FieldByName("B"))) { - assert.Equal(true, ts.B) - } - if assert.NoError(setBoolField("", val.FieldByName("B"))) { - assert.Equal(false, ts.B) - } + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(100), ts.F32) + + // second set with value sets the value + err = setFloatField("15.5", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) +} + +func TestSetBoolField(t *testing.T) { + ts := new(bindTestStruct) + ts.B = true + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // second set with value sets the value + err = setBoolField("true", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // fourth set to false + err = setBoolField("false", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, false, ts.B) +} + +func TestUnmarshalFieldNonPtr(t *testing.T) { + ts := new(bindTestStruct) + val := reflect.ValueOf(ts).Elem() ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(err) { - assert.Equal(ok, true) - assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) + if assert.NoError(t, err) { + assert.True(t, ok) + assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) } } @@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() assert := assert.New(b) ts := new(bindTestStructWithTags) - binder := new(DefaultBinder) var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = bindData(ts, values, "form") } assert.NoError(err) assertBindTestStruct(assert, (*bindTestStruct)(ts)) @@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("node_from_path") + cc := c.(RoutableContext) + cc.SetRawPathParams(&PathParams{ + {Name: "node", Value: "node_from_path"}, + }) } var bindTarget interface{} @@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } b := new(DefaultBinder) - err := b.Bind(bindTarget, c) + err := b.Bind(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("real_node") + cc := c.(RoutableContext) + cc.SetRawPathParams(&PathParams{ + {Name: "node", Value: "real_node"}, + }) } var bindTarget interface{} @@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) { } else { bindTarget = &Node{} } - b := new(DefaultBinder) - err := b.BindBody(c, bindTarget) + err := BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/binder.go b/binder.go index 0900ce8dc..402b80bc0 100644 --- a/binder.go +++ b/binder.go @@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder { func PathParamsBinder(c Context) *ValueBinder { return &ValueBinder{ failFast: true, - ValueFunc: c.Param, + ValueFunc: c.PathParam, ValuesFunc: func(sourceParam string) []string { // path parameter should not have multiple values so getting values does not make sense but lets not error out here - value := c.Param(sourceParam) + value := c.PathParam(sourceParam) if value == "" { return nil } diff --git a/binder_external_test.go b/binder_external_test.go index f1aecb52b..585ade816 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -4,7 +4,7 @@ package echo_test import ( "encoding/base64" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "log" "net/http" "net/http/httptest" diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go deleted file mode 100644 index 018628c3a..000000000 --- a/binder_go1.15_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build go1.15 - -package echo - -/** - Since version 1.15 time.Time and time.Duration error message pattern has changed (values are wrapped now in \"\") - So pre 1.15 these tests fail with similar error: - - expected: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param" - actual : "code=400, message=failed to bind field value to Duration, internal=time: invalid duration nope, field=param" -*/ - -import ( - "errors" - "github.com/stretchr/testify/assert" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func createTestContext15(URL string, body io.Reader, pathParams map[string]string) Context { - e := New() - req := httptest.NewRequest(http.MethodGet, URL, body) - if body != nil { - req.Header.Set(HeaderContentType, MIMEApplicationJSON) - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) - for name, value := range pathParams { - names = append(names, name) - values = append(values, value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) - } - - return c -} - -func TestValueBinder_TimeError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue time.Time - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - dest := time.Time{} - var err error - if tc.whenMust { - err = b.MustTime("param", &dest, tc.whenLayout).BindError() - } else { - err = b.Time("param", &dest, tc.whenLayout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_TimesError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue []time.Time - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - layout := time.RFC3339 - if tc.whenLayout != "" { - layout = tc.whenLayout - } - - var dest []time.Time - var err error - if tc.whenMust { - err = b.MustTimes("param", &dest, layout).BindError() - } else { - err = b.Times("param", &dest, layout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue time.Duration - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - var dest time.Duration - var err error - if tc.whenMust { - err = b.MustDuration("param", &dest).BindError() - } else { - err = b.Duration("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationsError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue []time.Duration - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - var dest []time.Duration - var err error - if tc.whenMust { - err = b.MustDurations("param", &dest).BindError() - } else { - err = b.Durations("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/binder_test.go b/binder_test.go index 946906a96..f57da32d6 100644 --- a/binder_test.go +++ b/binder_test.go @@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string) c := e.NewContext(req, rec) if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) + params := make(PathParams, 0) for name, value := range pathParams { - names = append(names, name) - values = append(values, value) + params = append(params, PathParam{ + Name: name, + Value: value, + }) } - c.SetParamNames(names...) - c.SetParamValues(values...) + cc := c.(RoutableContext) + cc.SetRawPathParams(¶ms) } return c @@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) } } @@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) if dest.Int64 != 1 { b.Fatalf("int64!=1") } @@ -2755,3 +2756,224 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { } } } + +func TestValueBinder_TimeError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue time.Time + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TimesError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue []time.Time + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Duration + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationsError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []time.Duration + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/context.go b/context.go index a4ecfadfc..a397fba70 100644 --- a/context.go +++ b/context.go @@ -3,210 +3,193 @@ package echo import ( "bytes" "encoding/xml" + "errors" "fmt" "io" + "io/fs" "mime/multipart" "net" "net/http" "net/url" + "path/filepath" "strings" "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) - // SetResponse sets `*Response`. - SetResponse(r *Response) + // SetResponse sets `*Response`. + SetResponse(r *Response) - // Response returns `*Response`. - Response() *Response + // Response returns `*Response`. + Response() *Response - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - // Path returns the registered path for the handler. - Path() string + // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. + // In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. + RouteInfo() RouteInfo - // SetPath sets the registered path for the handler. - SetPath(p string) + // Path returns the registered path for the handler. + Path() string - // Param returns path parameter by name. - Param(name string) string + // PathParam returns path parameter by name. + PathParam(name string) string - // ParamNames returns path parameter names. - ParamNames() []string + // PathParams returns path parameter values. + PathParams() PathParams - // SetParamNames sets path parameter names. - SetParamNames(names ...string) + // SetPathParams sets path parameters for current request. + SetPathParams(params PathParams) - // ParamValues returns path parameter values. - ParamValues() []string + // QueryParam returns the query param for the provided name. + QueryParam(name string) string - // SetParamValues sets path parameter values. - SetParamValues(values ...string) + // QueryParamDefault returns the query param or default value for the provided name. + QueryParamDefault(name, defaultValue string) string - // QueryParam returns the query param for the provided name. - QueryParam(name string) string + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values + // QueryString returns the URL query string. + QueryString() string - // QueryString returns the URL query string. - QueryString() string + // FormValue returns the form field value for the provided name. + FormValue(name string) string - // FormValue returns the form field value for the provided name. - FormValue(name string) string + // FormValueDefault returns the form field value or default value for the provided name. + FormValueDefault(name, defaultValue string) string - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) + // FormValues returns the form field values as `url.Values`. + FormValues() (url.Values, error) - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie - // Get retrieves data from the context. - Get(key string) interface{} + // Get retrieves data from the context. + Get(key string) interface{} - // Set saves data in the context. - Set(key string, val interface{}) + // Set saves data in the context. + Set(key string, val interface{}) - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. - Bind(i interface{}) error + // Bind binds the request body into provided type `i`. The default binder + // does it based on Content-Type header. + Bind(i interface{}) error - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i interface{}) error - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data interface{}) error - // HTML sends an HTTP response with status code. - HTML(code int, html string) error + // HTML sends an HTTP response with status code. + HTML(code int, html string) error - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error - // String sends a string response with status code. - String(code int, s string) error + // String sends a string response with status code. + String(code int, s string) error - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error + // JSON sends a JSON response with status code. + JSON(code int, i interface{}) error - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i interface{}, indent string) error - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i interface{}) error - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error - // XML sends an XML response with status code. - XML(code int, i interface{}) error + // XML sends an XML response with status code. + XML(code int, i interface{}) error - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i interface{}, indent string) error - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error - // File sends a response with the content of the file. - File(file string) error + // File sends a response with the content of the file. + File(file string) error - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error + // FileFS sends a response with the content of the file from given filesystem. + FileFS(file string, filesystem fs.FS) error - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error - // NoContent sends a response with no body and a status code. - NoContent(code int) error + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error + // NoContent sends a response with no body and a status code. + NoContent(code int) error - // Error invokes the registered HTTP error handler. Generally used by middleware. - Error(err error) + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error - // Handler returns the matched handler by router. - Handler() HandlerFunc - - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) - - // Logger returns the `Logger` instance. - Logger() Logger - - // Set the logger - SetLogger(l Logger) - - // Echo returns the `Echo` instance. - Echo() *Echo + // Echo returns the `Echo` instance. + Echo() *Echo +} - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) - } +// ServableContext is interface that Echo context implementation must implement to be usable in middleware/handlers and +// be able to be routed by Router. +type ServableContext interface { + Context // minimal set of methods for middlewares and handler + RoutableContext // minimal set for routing. These methods should not be accessed in middlewares/handlers - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex - } -) + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) +} const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. @@ -221,39 +204,95 @@ const ( defaultIndent = " " ) -func (c *context) writeContentType(value string) { +// DefaultContext is default implementation of Context interface and can be embedded into structs to compose +// new Contexts with extended/modified behaviour. +type DefaultContext struct { + request *http.Request + response *Response + + route RouteInfo + path string + + // pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations. + pathParams *PathParams + // currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request. + // Lifecycle is not handle by Echo and could have excess allocations per served Request + currentParams PathParams + + query url.Values + store Map + echo *Echo + lock sync.RWMutex +} + +// NewDefaultContext creates new instance of DefaultContext. +// Argument pathParamAllocSize must be value that is stored in Echo.contextPathParamAllocSize field and is used +// to preallocate PathParams slice. +func NewDefaultContext(e *Echo, pathParamAllocSize int) *DefaultContext { + p := make(PathParams, pathParamAllocSize) + return &DefaultContext{ + pathParams: &p, + store: make(Map), + echo: e, + } +} + +// Reset resets the context after request completes. It must be called along +// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. +// See `Echo#ServeHTTP()` +func (c *DefaultContext) Reset(r *http.Request, w http.ResponseWriter) { + c.request = r + c.response.reset(w) + c.query = nil + c.store = nil + + c.route = nil + c.path = "" + // NOTE: Don't reset because it has to have length of c.echo.contextPathParamAllocSize at all times + *c.pathParams = (*c.pathParams)[:0] + c.currentParams = nil +} + +func (c *DefaultContext) writeContentType(value string) { header := c.Response().Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } -func (c *context) Request() *http.Request { +// Request returns `*http.Request`. +func (c *DefaultContext) Request() *http.Request { return c.request } -func (c *context) SetRequest(r *http.Request) { +// SetRequest sets `*http.Request`. +func (c *DefaultContext) SetRequest(r *http.Request) { c.request = r } -func (c *context) Response() *Response { +// Response returns `*Response`. +func (c *DefaultContext) Response() *Response { return c.response } -func (c *context) SetResponse(r *Response) { +// SetResponse sets `*Response`. +func (c *DefaultContext) SetResponse(r *Response) { c.response = r } -func (c *context) IsTLS() bool { +// IsTLS returns true if HTTP connection is TLS otherwise false. +func (c *DefaultContext) IsTLS() bool { return c.request.TLS != nil } -func (c *context) IsWebSocket() bool { +// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. +func (c *DefaultContext) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) return strings.EqualFold(upgrade, "websocket") } -func (c *context) Scheme() string { +// Scheme returns the HTTP protocol scheme, `http` or `https`. +func (c *DefaultContext) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { @@ -274,7 +313,10 @@ func (c *context) Scheme() string { return "http" } -func (c *context) RealIP() string { +// RealIP returns the client's network address based on `X-Forwarded-For` +// or `X-Real-IP` request header. +// The behavior can be configured using `Echo#IPExtractor`. +func (c *DefaultContext) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } @@ -293,85 +335,116 @@ func (c *context) RealIP() string { return ra } -func (c *context) Path() string { +// Path returns the registered path for the handler. +func (c *DefaultContext) Path() string { return c.path } -func (c *context) SetPath(p string) { +// SetPath sets the registered path for the handler. +func (c *DefaultContext) SetPath(p string) { c.path = p } -func (c *context) Param(name string) string { - for i, n := range c.pnames { - if i < len(c.pvalues) { - if n == name { - return c.pvalues[i] - } - } - } - return "" +// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. +// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. +func (c *DefaultContext) RouteInfo() RouteInfo { + return c.route } -func (c *context) ParamNames() []string { - return c.pnames +// SetRouteInfo sets the route info of this request to the context. +func (c *DefaultContext) SetRouteInfo(ri RouteInfo) { + c.route = ri } -func (c *context) SetParamNames(names ...string) { - c.pnames = names +// RawPathParams returns raw path pathParams value. Allocation of PathParams is handled by Context. +func (c *DefaultContext) RawPathParams() *PathParams { + return c.pathParams +} - l := len(names) - if *c.echo.maxParam < l { - *c.echo.maxParam = l - } +// SetRawPathParams replaces any existing param values with new values for this context lifetime (request). +// +// DO NOT USE! +// Do not set any other value than what you got from RawPathParams as allocation of PathParams is handled by Context. +// If you mess up size of pathParams size your application will panic/crash during routing +func (c *DefaultContext) SetRawPathParams(params *PathParams) { + c.pathParams = params +} - if len(c.pvalues) < l { - // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overriden in a Context#SetParamValues - newPvalues := make([]string, l) - copy(newPvalues, c.pvalues) - c.pvalues = newPvalues +// PathParam returns path parameter by name. +func (c *DefaultContext) PathParam(name string) string { + if c.currentParams != nil { + return c.currentParams.Get(name, "") } -} -func (c *context) ParamValues() []string { - return c.pvalues[:len(c.pnames)] + return c.pathParams.Get(name, "") } -func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times - // It will brake the Router#Find code - limit := len(values) - if limit > *c.echo.maxParam { - limit = *c.echo.maxParam - } - for i := 0; i < limit; i++ { - c.pvalues[i] = values[i] +// PathParamDefault does not exist as expecting empty path param makes no sense + +// PathParams returns path parameter values. +func (c *DefaultContext) PathParams() PathParams { + if c.currentParams != nil { + return c.currentParams } + + result := make(PathParams, len(*c.pathParams)) + copy(result, *c.pathParams) + return result } -func (c *context) QueryParam(name string) string { +// SetPathParams sets path parameters for current request. +func (c *DefaultContext) SetPathParams(params PathParams) { + c.currentParams = params +} + +// QueryParam returns the query param for the provided name. +func (c *DefaultContext) QueryParam(name string) string { if c.query == nil { c.query = c.request.URL.Query() } return c.query.Get(name) } -func (c *context) QueryParams() url.Values { +// QueryParamDefault returns the query param or default value for the provided name. +// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string +func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string { + value := c.QueryParam(name) + if value == "" { + value = defaultValue + } + return value +} + +// QueryParams returns the query parameters as `url.Values`. +func (c *DefaultContext) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } -func (c *context) QueryString() string { +// QueryString returns the URL query string. +func (c *DefaultContext) QueryString() string { return c.request.URL.RawQuery } -func (c *context) FormValue(name string) string { +// FormValue returns the form field value for the provided name. +func (c *DefaultContext) FormValue(name string) string { return c.request.FormValue(name) } -func (c *context) FormParams() (url.Values, error) { +// FormValueDefault returns the form field value or default value for the provided name. +// Note: FormValueDefault does not distinguish if form had no value by that name or value was empty string +func (c *DefaultContext) FormValueDefault(name, defaultValue string) string { + value := c.FormValue(name) + if value == "" { + value = defaultValue + } + return value +} + +// FormValues returns the form field values as `url.Values`. +func (c *DefaultContext) FormValues() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { if err := c.request.ParseMultipartForm(defaultMemory); err != nil { return nil, err @@ -384,7 +457,8 @@ func (c *context) FormParams() (url.Values, error) { return c.request.Form, nil } -func (c *context) FormFile(name string) (*multipart.FileHeader, error) { +// FormFile returns the multipart form file for the provided name. +func (c *DefaultContext) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) if err != nil { return nil, err @@ -393,30 +467,36 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) { return fh, nil } -func (c *context) MultipartForm() (*multipart.Form, error) { +// MultipartForm returns the multipart form. +func (c *DefaultContext) MultipartForm() (*multipart.Form, error) { err := c.request.ParseMultipartForm(defaultMemory) return c.request.MultipartForm, err } -func (c *context) Cookie(name string) (*http.Cookie, error) { +// Cookie returns the named cookie provided in the request. +func (c *DefaultContext) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -func (c *context) SetCookie(cookie *http.Cookie) { +// SetCookie adds a `Set-Cookie` header in HTTP response. +func (c *DefaultContext) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } -func (c *context) Cookies() []*http.Cookie { +// Cookies returns the HTTP cookies sent with the request. +func (c *DefaultContext) Cookies() []*http.Cookie { return c.request.Cookies() } -func (c *context) Get(key string) interface{} { +// Get retrieves data from the context. +func (c *DefaultContext) Get(key string) interface{} { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } -func (c *context) Set(key string, val interface{}) { +// Set saves data in the context. +func (c *DefaultContext) Set(key string, val interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -426,18 +506,24 @@ func (c *context) Set(key string, val interface{}) { c.store[key] = val } -func (c *context) Bind(i interface{}) error { - return c.echo.Binder.Bind(i, c) +// Bind binds the request body into provided type `i`. The default binder +// does it based on Content-Type header. +func (c *DefaultContext) Bind(i interface{}) error { + return c.echo.Binder.Bind(c, i) } -func (c *context) Validate(i interface{}) error { +// Validate validates provided `i`. It is usually called after `Context#Bind()`. +// Validator must be registered using `Echo#Validator`. +func (c *DefaultContext) Validate(i interface{}) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -func (c *context) Render(code int, name string, data interface{}) (err error) { +// Render renders a template with data and sends a text/html response with status +// code. Renderer must be registered using `Echo.Renderer`. +func (c *DefaultContext) Render(code int, name string, data interface{}) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } @@ -448,19 +534,22 @@ func (c *context) Render(code int, name string, data interface{}) (err error) { return c.HTMLBlob(code, buf.Bytes()) } -func (c *context) HTML(code int, html string) (err error) { +// HTML sends an HTTP response with status code. +func (c *DefaultContext) HTML(code int, html string) (err error) { return c.HTMLBlob(code, []byte(html)) } -func (c *context) HTMLBlob(code int, b []byte) (err error) { +// HTMLBlob sends an HTTP blob response with status code. +func (c *DefaultContext) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } -func (c *context) String(code int, s string) (err error) { +// String sends a string response with status code. +func (c *DefaultContext) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) { +func (c *DefaultContext) jsonPBlob(code int, callback string, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -479,13 +568,14 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error return } -func (c *context) json(code int, i interface{}, indent string) error { +func (c *DefaultContext) json(code int, i interface{}, indent string) error { c.writeContentType(MIMEApplicationJSONCharsetUTF8) c.response.Status = code return c.echo.JSONSerializer.Serialize(c, i, indent) } -func (c *context) JSON(code int, i interface{}) (err error) { +// JSON sends a JSON response with status code. +func (c *DefaultContext) JSON(code int, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -493,19 +583,25 @@ func (c *context) JSON(code int, i interface{}) (err error) { return c.json(code, i, indent) } -func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) { +// JSONPretty sends a pretty-print JSON with status code. +func (c *DefaultContext) JSONPretty(code int, i interface{}, indent string) (err error) { return c.json(code, i, indent) } -func (c *context) JSONBlob(code int, b []byte) (err error) { +// JSONBlob sends a JSON blob response with status code. +func (c *DefaultContext) JSONBlob(code int, b []byte) (err error) { return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) } -func (c *context) JSONP(code int, callback string, i interface{}) (err error) { +// JSONP sends a JSONP response with status code. It uses `callback` to construct +// the JSONP payload. +func (c *DefaultContext) JSONP(code int, callback string, i interface{}) (err error) { return c.jsonPBlob(code, callback, i) } -func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { +// JSONPBlob sends a JSONP blob response with status code. It uses `callback` +// to construct the JSONP payload. +func (c *DefaultContext) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { @@ -518,7 +614,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { return } -func (c *context) xml(code int, i interface{}, indent string) (err error) { +func (c *DefaultContext) xml(code int, i interface{}, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -531,7 +627,8 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) { return enc.Encode(i) } -func (c *context) XML(code int, i interface{}) (err error) { +// XML sends an XML response with status code. +func (c *DefaultContext) XML(code int, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -539,11 +636,13 @@ func (c *context) XML(code int, i interface{}) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) { +// XMLPretty sends a pretty-print XML with status code. +func (c *DefaultContext) XMLPretty(code int, i interface{}, indent string) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLBlob(code int, b []byte) (err error) { +// XMLBlob sends an XML blob response with status code. +func (c *DefaultContext) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { @@ -553,39 +652,86 @@ func (c *context) XMLBlob(code int, b []byte) (err error) { return } -func (c *context) Blob(code int, contentType string, b []byte) (err error) { +// Blob sends a blob response with status code and content type. +func (c *DefaultContext) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } -func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { +// Stream sends a streaming response with status code and content type. +func (c *DefaultContext) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } -func (c *context) Attachment(file, name string) error { +// File sends a response with the content of the file. +func (c *DefaultContext) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *DefaultContext) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return err + } + } + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil +} + +// Attachment sends a response as attachment, prompting client to save the file. +func (c *DefaultContext) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } -func (c *context) Inline(file, name string) error { +// Inline sends a response as inline, opening the file in the browser. +func (c *DefaultContext) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } -func (c *context) contentDisposition(file, name, dispositionType string) error { +func (c *DefaultContext) contentDisposition(file, name, dispositionType string) error { c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) return c.File(file) } -func (c *context) NoContent(code int) error { +// NoContent sends a response with no body and a status code. +func (c *DefaultContext) NoContent(code int) error { c.response.WriteHeader(code) return nil } -func (c *context) Redirect(code int, url string) error { +// Redirect redirects the request to a provided URL with status code. +func (c *DefaultContext) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } @@ -594,45 +740,7 @@ func (c *context) Redirect(code int, url string) error { return nil } -func (c *context) Error(err error) { - c.echo.HTTPErrorHandler(err, c) -} - -func (c *context) Echo() *Echo { +// Echo returns the `Echo` instance. +func (c *DefaultContext) Echo() *Echo { return c.echo } - -func (c *context) Handler() HandlerFunc { - return c.handler -} - -func (c *context) SetHandler(h HandlerFunc) { - c.handler = h -} - -func (c *context) Logger() Logger { - res := c.logger - if res != nil { - return res - } - return c.echo.Logger -} - -func (c *context) SetLogger(l Logger) { - c.logger = l -} - -func (c *context) Reset(r *http.Request, w http.ResponseWriter) { - c.request = r - c.response.reset(w) - c.query = nil - c.handler = NotFoundHandler - c.store = nil - c.path = "" - c.pnames = nil - c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam at all times - for i := 0; i < *c.echo.maxParam; i++ { - c.pvalues[i] = "" - } -} diff --git a/context_fs.go b/context_fs.go deleted file mode 100644 index 11ee84bcd..000000000 --- a/context_fs.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -import ( - "net/http" - "os" - "path/filepath" -) - -func (c *context) File(file string) (err error) { - f, err := os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.Join(file, indexPage) - f, err = os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return - } - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return -} diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go deleted file mode 100644 index c1c724afd..000000000 --- a/context_fs_go1.16.go +++ /dev/null @@ -1,52 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "errors" - "io" - "io/fs" - "net/http" - "path/filepath" -) - -func (c *context) File(file string) error { - return fsFile(c, file, c.echo.Filesystem) -} - -// FileFS serves file from given file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (c *context) FileFS(file string, filesystem fs.FS) error { - return fsFile(c, file, filesystem) -} - -func fsFile(c Context, file string, filesystem fs.FS) error { - f, err := filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. - f, err = filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return err - } - } - ff, ok := f.(io.ReadSeeker) - if !ok { - return errors.New("file does not implement io.ReadSeeker") - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) - return nil -} diff --git a/context_fs_go1.16_test.go b/context_fs_go1.16_test.go deleted file mode 100644 index 027d1c483..000000000 --- a/context_fs_go1.16_test.go +++ /dev/null @@ -1,135 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestContext_File(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok, from default file system", - whenFile: "_fixture/images/walle.png", - whenFS: nil, - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "ok, from custom file system", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - if tc.whenFS != nil { - e.Filesystem = tc.whenFS - } - - handler := func(ec Context) error { - return ec.(*context).File(tc.whenFile) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestContext_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - handler := func(ec Context) error { - return ec.(*context).FileFS(tc.whenFile, tc.whenFS) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} diff --git a/context_test.go b/context_test.go index a8b9a9946..dca680f9b 100644 --- a/context_test.go +++ b/context_test.go @@ -8,33 +8,33 @@ import ( "errors" "fmt" "io" + "io/fs" "math" "mime/multipart" "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "text/template" "time" - "github.com/labstack/gommon/log" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} var testUser = user{1, "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -46,9 +46,10 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -60,9 +61,10 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -73,7 +75,7 @@ func BenchmarkAllocXML(b *testing.B) { } func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { - c := context{request: &http.Request{ + c := DefaultContext{request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }} for i := 0; i < b.N; i++ { @@ -104,18 +106,16 @@ func TestContext(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Render @@ -126,106 +126,106 @@ func TestContext(t *testing.T) { } c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } c.echo.Renderer = nil err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) + assert.Error(t, err) // JSON rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } // JSON with "?pretty" req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // reset // JSONPretty rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } // JSON (error) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // JSONP rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) callback := "callback" err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) } // XML rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // XML with "?pretty" req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // XML (error) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // XML response write error - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) c.response.Writer = responseWriterErr{} err = c.XML(0, 0) - testify.Error(t, err) + assert.Error(t, err) // XMLPretty rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } t.Run("empty indent", func(t *testing.T) { @@ -237,166 +237,157 @@ func TestContext(t *testing.T) { t.Run("json", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New JSONBlob with empty indent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) enc := json.NewEncoder(buf) enc.SetIndent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) } }) t.Run("xml", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New XMLBlob with empty indent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) enc := xml.NewEncoder(buf) enc.Indent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) } }) }) // Legacy JSONBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) } // Legacy JSONPBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) callback = "callback" data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } // Legacy XMLBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // String rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // HTML rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // Stream rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) r := strings.NewReader("response from a stream") err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) } // Attachment rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // Inline rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // NoContent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) - - // Error - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) // Reset - c.SetParamNames("foo") - c.SetParamValues("bar") + c.pathParams = &PathParams{ + {Name: "foo", Value: "bar"}, + } c.Set("foe", "ban") c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) + assert.Equal(t, 0, len(c.PathParams())) + assert.Equal(t, 0, len(c.store)) + assert.Equal(t, nil, c.RouteInfo()) + assert.Equal(t, 0, len(c.QueryParams())) } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -404,12 +395,11 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.Error(t, err) { + assert.False(t, c.response.Committed) } } @@ -421,24 +411,22 @@ func TestContextCookie(t *testing.T) { req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -453,104 +441,95 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") -} - -func TestContextPath(t *testing.T) { - e := New() - r := e.Router() - - handler := func(c Context) error { return c.String(http.StatusOK, "OK") } - - r.Add(http.MethodGet, "/users/:id", handler) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1", c) - - assert := testify.New(t) - - assert.Equal("/users/:id", c.Path()) - - r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } func TestContextPathParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) + c := e.NewContext(req, nil).(*DefaultContext) + params := &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + } // ParamNames - c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) - - // ParamValues - c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + c.pathParams = params + assert.EqualValues(t, *params, c.PathParams()) // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.Equal(t, "501", c.PathParam("fid")) + assert.Equal(t, "", c.PathParam("undefined")) } func TestContextGetAndSetParam(t *testing.T) { e := New() r := e.Router() - r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) + _, err := r.Add(Route{ + Method: http.MethodGet, + Path: "/:foo", + Name: "", + Handler: func(Context) error { return nil }, + Middlewares: nil, + }) + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) - c.SetParamNames("foo") + + params := &PathParams{{Name: "foo", Value: "101"}} + // ParamNames + c.(*DefaultContext).pathParams = params // round-trip param values with modification - paramVals := c.ParamValues() - testify.EqualValues(t, []string{""}, c.ParamValues()) - paramVals[0] = "bar" - c.SetParamValues(paramVals...) - testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + paramVals := c.PathParams() + assert.Equal(t, *params, c.PathParams()) + + paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context + assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams()) + + pathParams := PathParams{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + c.SetPathParams(pathParams) + assert.Equal(t, pathParams, c.PathParams()) // shouldn't explode during Reset() afterwards! - testify.NotPanics(t, func() { - c.Reset(nil, nil) + assert.NotPanics(t, func() { + c.(ServableContext).Reset(nil, nil) }) + assert.Equal(t, PathParams{}, c.PathParams()) + assert.Len(t, *c.(*DefaultContext).pathParams, 0) + assert.Equal(t, cap(*c.(*DefaultContext).pathParams), 1) } // Issue #1655 -func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { - assert := testify.New(t) - +func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) { e := New() - assert.Equal(0, *e.maxParam) - - expectedOneParam := []string{"one"} - expectedTwoParams := []string{"one", "two"} - expectedThreeParams := []string{"one", "two", ""} - expectedABCParams := []string{"A", "B", "C"} + c := e.NewContext(nil, nil).(*DefaultContext) - c := e.NewContext(nil, nil) - c.SetParamNames("1", "2") - c.SetParamValues(expectedTwoParams...) - assert.Equal(2, *e.maxParam) - assert.EqualValues(expectedTwoParams, c.ParamValues()) - - c.SetParamNames("1") - assert.Equal(2, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are - assert.EqualValues(expectedOneParam, c.ParamValues()) - - c.SetParamNames("1", "2", "3") - assert.Equal(3, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam - assert.EqualValues(expectedThreeParams, c.ParamValues()) + assert.Equal(t, 0, e.contextPathParamAllocSize) + expectedTwoParams := &PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + } + c.SetRawPathParams(expectedTwoParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, *expectedTwoParams, c.PathParams()) - c.SetParamValues("A", "B", "C", "D") - assert.Equal(3, *e.maxParam) - // Here D shouldn't be returned - assert.EqualValues(expectedABCParams, c.ParamValues()) + expectedThreeParams := PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + {Name: "3", Value: "three"}, + } + c.SetPathParams(expectedThreeParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, expectedThreeParams, c.PathParams()) } func TestContextFormValue(t *testing.T) { @@ -564,25 +543,29 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) + + // FormValueDefault + assert.Equal(t, "Jon Snow", c.FormValueDefault("name", "nope")) + assert.Equal(t, "default", c.FormValueDefault("missing", "default")) - // FormParams - params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + // FormValues + values, err := c.FormValues() + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, - }, params) + }, values) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) - params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + values, err = c.FormValues() + assert.Nil(t, values) + assert.Error(t, err) } func TestContextQueryParam(t *testing.T) { @@ -594,11 +577,15 @@ func TestContextQueryParam(t *testing.T) { c := e.NewContext(req, nil) // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + assert.Equal(t, "Jon Snow", c.QueryParam("name")) + assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) + + // QueryParamDefault + assert.Equal(t, "Jon Snow", c.QueryParamDefault("name", "nope")) + assert.Equal(t, "default", c.QueryParamDefault("missing", "default")) // QueryParams - testify.Equal(t, url.Values{ + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, c.QueryParams()) @@ -609,7 +596,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -618,8 +605,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -634,8 +621,8 @@ func TestContextMultipartForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) } } @@ -644,22 +631,22 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextStore(t *testing.T) { - var c Context = new(context) + var c Context = new(DefaultContext) c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} - c := &context{ + c := &DefaultContext{ echo: e, } @@ -671,42 +658,6 @@ func BenchmarkContext_Store(b *testing.B) { } } -func TestContextHandler(t *testing.T) { - e := New() - r := e.Router() - b := new(bytes.Buffer) - - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) - return err - }) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/handler", c) - err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) -} - -func TestContext_SetHandler(t *testing.T) { - var c Context = new(context) - - testify.Nil(t, c.Handler()) - - c.SetHandler(func(c Context) error { - return nil - }) - testify.NotNil(t, c.Handler()) -} - -func TestContext_Path(t *testing.T) { - path := "/pa/th" - - var c Context = new(context) - - c.SetPath(path) - testify.Equal(t, path, c.Path()) -} - type validator struct{} func (*validator) Validate(i interface{}) error { @@ -717,10 +668,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -728,21 +679,21 @@ func TestContext_QueryString(t *testing.T) { queryString := "query=string&var=val" - req := httptest.NewRequest(GET, "/?"+queryString, nil) + req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { - var c Context = new(context) + var c Context = new(DefaultContext) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) - req := httptest.NewRequest(GET, "/path", nil) + req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { @@ -751,7 +702,7 @@ func TestContext_Scheme(t *testing.T) { s string }{ { - &context{ + &DefaultContext{ request: &http.Request{ TLS: &tls.ConnectionState{}, }, @@ -759,7 +710,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedProto: []string{"https"}}, }, @@ -767,7 +718,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, }, @@ -775,7 +726,7 @@ func TestContext_Scheme(t *testing.T) { "http", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, }, @@ -783,7 +734,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXUrlScheme: []string{"https"}}, }, @@ -791,7 +742,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{}, }, "http", @@ -799,44 +750,44 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context - ws testify.BoolAssertionFunc + ws assert.BoolAssertionFunc }{ { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, - testify.True, + assert.True, }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, - testify.True, + assert.True, }, { - &context{ + &DefaultContext{ request: &http.Request{}, }, - testify.False, + assert.False, }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, }, } @@ -849,30 +800,14 @@ func TestContext_IsWebSocket(t *testing.T) { func TestContext_Bind(t *testing.T) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) -} - -func TestContext_Logger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - testify.NotNil(t, log1) - - log2 := log.New("echo2") - c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) - - // Resetting the context returns the initial logger - c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.NoError(t, err) + assert.Equal(t, &user{1, "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { @@ -881,7 +816,7 @@ func TestContext_RealIP(t *testing.T) { s string }{ { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }, @@ -889,7 +824,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, }, @@ -897,7 +832,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, }, @@ -905,7 +840,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"192.168.0.1"}, @@ -915,7 +850,7 @@ func TestContext_RealIP(t *testing.T) { "192.168.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ RemoteAddr: "89.89.89.89:1654", }, @@ -925,6 +860,128 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) + } +} + +func TestContext_File(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec Context) error { + return ec.(*DefaultContext).File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec Context) error { + return ec.(*DefaultContext).FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) } } diff --git a/echo.go b/echo.go index 2e63cc6b1..29fa22813 100644 --- a/echo.go +++ b/echo.go @@ -5,12 +5,12 @@ Example: package main - import ( - "net/http" - - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - ) + import ( + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "log" + "net/http" + ) // Handler func hello(c echo.Context) error { @@ -29,7 +29,9 @@ Example: e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":8080"); err != http.ErrServerClosed { + log.Fatal(err) + } } Learn more at https://echo.labstack.com @@ -37,125 +39,89 @@ Learn more at https://echo.labstack.com package echo import ( - "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" "io" - "io/ioutil" - stdLog "log" - "net" + "io/fs" "net/http" - "reflect" - "runtime" + "net/url" + "os" + "os/signal" + "path/filepath" + "strings" "sync" - "time" - - "github.com/labstack/gommon/color" - "github.com/labstack/gommon/log" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) -type ( - // Echo is the top-level framework instance. - Echo struct { - filesystem - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex - StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - notFoundHandler HandlerFunc - pool sync.Pool - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - JSONSerializer JSONSerializer - Validator Validator - Renderer Renderer - Logger Logger - IPExtractor IPExtractor - ListenerNetwork string - } +// Echo is the top-level framework instance. +// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. +type Echo struct { + // premiddleware are middlewares that are run for every request before routing is done + premiddleware []MiddlewareFunc + // middleware are middlewares that are run after router found a matching route (not found and method not found are also matches) + middleware []MiddlewareFunc - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } + router Router + routers map[string]Router + routerCreator func(e *Echo) Router - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } + contextPool sync.Pool + // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info for context + // creation time so we can allocate path parameter values slice. + contextPathParamAllocSize int - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(next HandlerFunc) HandlerFunc + // NewContextFunc allows using custom context implementations, instead of default *echo.context + NewContextFunc func(e *Echo, pathParamAllocSize int) ServableContext + Debug bool + HTTPErrorHandler HTTPErrorHandler + Binder Binder + JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger + IPExtractor IPExtractor - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(c Context) error + // Filesystem is file system used by Static and File handlers to access files. + // Defaults to os.DirFS(".") + // + // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary + // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths + // including `assets/images` as their prefix. + Filesystem fs.FS +} - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(error, Context) +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(c Context, err error) - // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. - JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error - } +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} +// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} - // Common struct for Echo & Group. - common struct{} -) +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i interface{}) error +} -// HTTP methods -// NOTE: Deprecated, please use the stdlib constants directly instead. -const ( - CONNECT = http.MethodConnect - DELETE = http.MethodDelete - GET = http.MethodGet - HEAD = http.MethodHead - OPTIONS = http.MethodOptions - PATCH = http.MethodPatch - POST = http.MethodPost - // PROPFIND = "PROPFIND" - PUT = http.MethodPut - TRACE = http.MethodTrace -) +// Renderer is the interface that wraps the Render function. +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} + +// Map defines a generic map of type `map[string]interface{}`. +type Map map[string]interface{} // MIME types const ( @@ -245,297 +211,360 @@ const ( const ( // Version of Echo - Version = "4.6.3" - website = "https://echo.labstack.com" - // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo - banner = ` - ____ __ - / __/___/ / ___ - / _// __/ _ \/ _ \ -/___/\__/_//_/\___/ %s -High performance, minimalist Go web framework -%s -____________________________________O/_______ - O\ -` -) - -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) - -// Errors -var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrForbidden = NewHTTPError(http.StatusForbidden) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) - ErrBadRequest = NewHTTPError(http.StatusBadRequest) - ErrBadGateway = NewHTTPError(http.StatusBadGateway) - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") + Version = "5.0.0-alpha" ) -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } - - MethodNotAllowedHandler = func(c Context) error { - // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) - // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned - routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) - if ok && routerAllowMethods != "" { - c.Response().Header().Set(HeaderAllow, routerAllowMethods) - } - return ErrMethodNotAllowed - } -) +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} // New creates an instance of Echo. -func New() (e *Echo) { - e = &Echo{ - filesystem: createFilesystem(), - Server: new(http.Server), - TLSServer: new(http.Server), - AutoTLSManager: autocert.Manager{ - Prompt: autocert.AcceptTOS, +func New() *Echo { + logger := newJSONLogger(os.Stdout) + e := &Echo{ + Logger: logger, + Filesystem: newDefaultFS(), + Binder: &DefaultBinder{}, + JSONSerializer: &DefaultJSONSerializer{}, + + routers: make(map[string]Router), + routerCreator: func(ec *Echo) Router { + return NewRouter(RouterConfig{}) }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), - ListenerNetwork: "tcp", } - e.Server.Handler = e - e.TLSServer.Handler = e - e.HTTPErrorHandler = e.DefaultHTTPErrorHandler - e.Binder = &DefaultBinder{} - e.JSONSerializer = &DefaultJSONSerializer{} - e.Logger.SetLevel(log.ERROR) - e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) - e.pool.New = func() interface{} { + + e.router = NewRouter(RouterConfig{}) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) + e.contextPool.New = func() interface{} { return e.NewContext(nil, nil) } - e.router = NewRouter(e) - e.routers = map[string]*Router{} - return + return e } -// NewContext returns a Context instance. +// NewContext returns a new Context instance. +// +// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { - return &context{ - request: r, - response: NewResponse(w, e), - store: make(Map), - echo: e, - pvalues: make([]string, *e.maxParam), - handler: NotFoundHandler, + var c Context + if e.NewContextFunc != nil { + c = e.NewContextFunc(e, e.contextPathParamAllocSize) + } else { + c = NewDefaultContext(e, e.contextPathParamAllocSize) } + c.SetRequest(r) + c.SetResponse(NewResponse(w, e)) + return c } // Router returns the default router. -func (e *Echo) Router() *Router { +func (e *Echo) Router() Router { return e.router } // Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { +func (e *Echo) Routers() map[string]Router { return e.routers } -// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response -// with status code. +// RouterFor returns Router for given host. +func (e *Echo) RouterFor(host string) Router { + return e.routers[host] +} + +// ResetRouterCreator resets callback for creating new router instances. +// Note: current (default) router is immediately replaced with router created with creator func and vhost routers are cleared. +func (e *Echo) ResetRouterCreator(creator func(e *Echo) Router) { + e.routerCreator = creator + e.router = creator(e) + e.routers = make(map[string]Router) +} + +// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response +// with status code. `exposeError` parameter decides if returned message will contain also error message or not // -// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) +// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "commited" the // response and status code header has been sent to the client. -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - - if c.Response().Committed { - return - } - - he, ok := err.(*HTTPError) - if ok { - if he.Internal != nil { - if herr, ok := he.Internal.(*HTTPError); ok { - he = herr - } +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { + return func(c Context, err error) { + if c.Response().Committed { + return } - } else { - he = &HTTPError{ + + he := &HTTPError{ Code: http.StatusInternalServerError, Message: http.StatusText(http.StatusInternalServerError), } - } + if errors.As(err, &he) { + if he.Internal != nil { // max 2 levels of checks even if internal could have also internal + errors.As(he.Internal, &he) + } + } - // Issue #1426 - code := he.Code - message := he.Message - if m, ok := he.Message.(string); ok { - if e.Debug { - message = Map{"message": m, "error": err.Error()} - } else { - message = Map{"message": m} + // Issue #1426 + code := he.Code + message := he.Message + if m, ok := he.Message.(string); ok { + if exposeError { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } } - } - // Send response - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) + // Send response + var cErr error + if c.Request().Method == http.MethodHead { // Issue #608 + cErr = c.NoContent(he.Code) + } else { + cErr = c.JSON(code, message) + } + if cErr != nil { + c.Echo().Logger.Error(err) // truly rare case. ala client already disconnected + } } } -// Pre adds middleware to the chain which is run before router. +// Pre adds middleware to the chain which is run before router tries to find matching route. +// Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } -// Use adds middleware to the chain which is run after router. +// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler -// in the router with optional route-level middleware. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// Any registers a new route for all supported HTTP methods and path with matching handler +// in the router with optional route-level middleware. Panics on error. +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// handler in the router with optional route-level middleware. Panics on error. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris +} + +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string) RouteInfo { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + ) +} + +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c Context) error { + p := c.PathParam("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) + fi, err := fs.Stat(fileSystem, name) + if err != nil { + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, p+"/") + } + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c Context) error { + return fsFile(c, file, filesystem) } - return routes } -func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, - m ...MiddlewareFunc) *Route { - return get(path, func(c Context) error { +// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { return c.File(file) - }, m...) + } + return e.Add(http.MethodGet, path, handler, middleware...) } -// File registers a new route with path to serve a static file with optional route-level middleware. -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { - return e.file(path, file, e.GET, m...) +// AddRoute registers a new Route with default host Router +func (e *Echo) AddRoute(route Routable) (RouteInfo, error) { + return e.add("", route) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) +func (e *Echo) add(host string, route Routable) (RouteInfo, error) { router := e.findRouter(host) - router.Add(method, path, func(c Context) error { - h := applyMiddleware(handler, middleware...) - return h(c) - }) - r := &Route{ - Method: method, - Path: path, - Name: name, + ri, err := router.Add(route) + if err != nil { + return nil, err + } + + paramsCount := len(ri.Params()) + if paramsCount > e.contextPathParamAllocSize { + e.contextPathParamAllocSize = paramsCount } - e.router.routes[method+path] = r - return r + return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - return e.add("", method, path, handler, middleware...) +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := e.add( + "", + Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + Name: "", + }, + ) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } // Host creates a new router group for the provided host and optional host-level middleware. func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) + e.routers[name] = e.routerCreator(e) g = &Group{host: name, echo: e} g.Use(m...) return @@ -548,326 +577,82 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates a URI from handler. -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { - name := handlerName(handler) - return e.Reverse(name, params...) -} - -// URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) -} - -// Reverse generates an URL from route name and provided parameters. -func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() -} - -// Routes returns the registered routes. -func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes -} - // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. func (e *Echo) AcquireContext() Context { - return e.pool.Get().(Context) + return e.contextPool.Get().(Context) } // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. func (e *Echo) ReleaseContext(c Context) { - e.pool.Put(c) + e.contextPool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Acquire context - c := e.pool.Get().(*context) + var c ServableContext + if e.NewContextFunc != nil { + // NOTE: we are not casting always context to RoutableContext because casting to interface vs pointer to struct is + // "significantly" slower. Echo Context interface has way to many methods so these checks take time. + // These are benchmarks with 1.16: + // * interface extending another interface = +24% slower (3233 ns/op vs 2605 ns/op) + // * interface (not extending any, just methods)= +14% slower + // + // Quote from https://stackoverflow.com/a/31584377 + // "it's even worse with interface-to-interface assertion, because you also need to ensure that the type implements the interface." + // + // So most of the time we do not need custom context type and simple IF + cast to pointer to struct is fast enough. + c = e.contextPool.Get().(ServableContext) + } else { + c = e.contextPool.Get().(*DefaultContext) + } c.Reset(r, w) var h func(Context) error if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h = c.Handler() - h = applyMiddleware(h, e.middleware...) + h = applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) } else { - h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h := c.Handler() - h = applyMiddleware(h, e.middleware...) - return h(c) + h = func(cc Context) error { + // NOTE: router will be executed after pre middlewares have been run. We assume here that context we receive after pre middlewares + // is the same we began with. If not - this is use-case we do not support and is probably abuse from developer. + h1 := applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) + return h1(cc) } h = applyMiddleware(h, e.premiddleware...) } // Execute chain if err := h(c); err != nil { - e.HTTPErrorHandler(err, c) + e.HTTPErrorHandler(c, err) } - // Release context - e.pool.Put(c) + e.contextPool.Put(c) } -// Start starts an HTTP server. +// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by +// sending os.Interrupt signal with `ctrl+c`. +// +// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration +// options. +// +// In need of customization use: +// sc := echo.StartConfig{Address: ":8080"} +// if err := sc.Start(e); err != http.ErrServerClosed { +// log.Fatal(err) +// } +// // or standard library `http.Server` +// s := http.Server{Addr: ":8080", Handler: e} +// if err := s.ListenAndServe(); err != http.ErrServerClosed { +// log.Fatal(err) +// } func (e *Echo) Start(address string) error { - e.startupMutex.Lock() - e.Server.Addr = address - if err := e.configureServer(e.Server); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return e.Server.Serve(e.Listener) -} - -// StartTLS starts an HTTPS server. -// If `certFile` or `keyFile` is `string` the values are treated as file paths. -// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - e.startupMutex.Lock() - var cert []byte - if cert, err = filepathOrContent(certFile); err != nil { - e.startupMutex.Unlock() - return - } + sc := StartConfig{Address: address} + ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt) // start shutdown process on ctrl+c + defer cancel() + sc.GracefulContext = ctx - var key []byte - if key, err = filepathOrContent(keyFile); err != nil { - e.startupMutex.Unlock() - return - } - - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.Certificates = make([]tls.Certificate, 1) - if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - e.startupMutex.Unlock() - return - } - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { - switch v := fileOrContent.(type) { - case string: - return ioutil.ReadFile(v) - case []byte: - return v, nil - default: - return nil, ErrInvalidCertOrKeyType - } -} - -// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. -func (e *Echo) StartAutoTLS(address string) error { - e.startupMutex.Lock() - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func (e *Echo) configureTLS(address string) { - s := e.TLSServer - s.Addr = address - if !e.DisableHTTP2 { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") - } -} - -// StartServer starts a custom http server. -func (e *Echo) StartServer(s *http.Server) (err error) { - e.startupMutex.Lock() - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - if s.TLSConfig != nil { - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -func (e *Echo) configureServer(s *http.Server) (err error) { - // Setup - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = e - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if s.TLSConfig == nil { - if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - return nil - } - if e.TLSListener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.TLSListener = tls.NewListener(l, s.TLSConfig) - } - if !e.HidePort { - e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) - } - return nil -} - -// ListenerAddr returns net.Addr for Listener -func (e *Echo) ListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.Listener == nil { - return nil - } - return e.Listener.Addr() -} - -// TLSListenerAddr returns net.Addr for TLSListener -func (e *Echo) TLSListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.TLSListener == nil { - return nil - } - return e.TLSListener.Addr() -} - -// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { - e.startupMutex.Lock() - // Setup - s := e.Server - s.Addr = address - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = h2c.NewHandler(e, h2s) - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) - if err != nil { - e.startupMutex.Unlock() - return err - } - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -// Close immediately stops the server. -// It internally calls `http.Server#Close()`. -func (e *Echo) Close() error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Close(); err != nil { - return err - } - return e.Server.Close() -} - -// Shutdown stops the server gracefully. -// It internally calls `http.Server#Shutdown()`. -func (e *Echo) Shutdown(ctx stdContext.Context) error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Shutdown(ctx); err != nil { - return err - } - return e.Server.Shutdown(ctx) -} - -// NewHTTPError creates a new HTTPError instance. -func NewHTTPError(code int, message ...interface{}) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} - if len(message) > 0 { - he.Message = message[0] - } - return he -} - -// Error makes it compatible with `error` interface. -func (he *HTTPError) Error() string { - if he.Internal == nil { - return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) - } - return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) -} - -// SetInternal sets error to HTTPError.Internal -func (he *HTTPError) SetInternal(err error) *HTTPError { - he.Internal = err - return he -} - -// Unwrap satisfies the Go 1.13 error wrapper interface. -func (he *HTTPError) Unwrap() error { - return he.Internal + return sc.Start(e) } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. @@ -892,19 +677,7 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -// GetPath returns RawPath, if it's empty returns Path from URL -// Difference between RawPath and Path is: -// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// * RawPath is an optional field which only gets set if the default encoding is different from Path. -func GetPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - return path -} - -func (e *Echo) findRouter(host string) *Router { +func (e *Echo) findRouter(host string) Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { return r @@ -913,53 +686,59 @@ func (e *Echo) findRouter(host string) *Router { return e.router } -func handlerName(h HandlerFunc) string { - t := reflect.ValueOf(h).Type() - if t.Kind() == reflect.Func { - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() +func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) } - return t.String() + return h } -// // PathUnescape is wraps `url.PathUnescape` -// func PathUnescape(s string) (string, error) { -// return url.PathUnescape(s) -// } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener +// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` +// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` +// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break +// all old applications that rely on being able to traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + prefix string + fs fs.FS } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - if c, err = ln.AcceptTCP(); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { - return +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: os.DirFS(dir), } - // Ignore error from setting the KeepAlivePeriod as some systems, such as - // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP - _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) - return } -func newListener(address, network string) (*tcpKeepAliveListener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, ErrInvalidListenerNetwork - } - l, err := net.Listen(network, address) - if err != nil { - return nil, err +func (fs defaultFS) Open(name string) (fs.File, error) { + return fs.fs.Open(name) +} + +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to + // allow cases when root is given as `../somepath` which is not valid for fs.FS + root = filepath.Join(dFS.prefix, root) + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil } - return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil + return fs.Sub(currentFs, root) } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) + if err != nil { + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) } - return h + return subFs } diff --git a/echo_fs.go b/echo_fs.go deleted file mode 100644 index c3790545a..000000000 --- a/echo_fs.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -import ( - "net/http" - "net/url" - "os" - "path/filepath" -) - -type filesystem struct { -} - -func createFilesystem() filesystem { - return filesystem{} -} - -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { - if root == "" { - root = "." // For security we want to restrict to CWD. - } - return e.static(prefix, root, e.GET) -} - -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) - if err != nil { - return err - } - - name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security - fi, err := os.Stat(name) - if err != nil { - // The access path does not exist - return NotFoundHandler(c) - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return c.File(name) - } - // Handle added routes based on trailing slash: - // /prefix => exact route "/prefix" + any route "/prefix/*" - // /prefix/ => only any route "/prefix/*" - if prefix != "" { - if prefix[len(prefix)-1] == '/' { - // Only add any route for intentional trailing slash - return get(prefix+"*", h) - } - get(prefix, h) - } - return get(prefix+"/*", h) -} diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go deleted file mode 100644 index 435459de2..000000000 --- a/echo_fs_go1.16.go +++ /dev/null @@ -1,145 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "fmt" - "io/fs" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" -) - -type filesystem struct { - // Filesystem is file system used by Static and File handlers to access files. - // Defaults to os.DirFS(".") - // - // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary - // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths - // including `assets/images` as their prefix. - Filesystem fs.FS -} - -func createFilesystem() filesystem { - return filesystem{ - Filesystem: newDefaultFS(), - } -} - -// Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, fsRoot string) *Route { - subFs := MustSubFS(e.Filesystem, fsRoot) - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(subFs, false), - ) -} - -// StaticFS registers a new route with path prefix to serve static files from the provided file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// StaticDirectoryHandler creates handler function to serve files from provided file system -// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. -func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { - return func(c Context) error { - p := c.Param("*") - if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice - tmpPath, err := url.PathUnescape(p) - if err != nil { - return fmt.Errorf("failed to unescape path variable: %w", err) - } - p = tmpPath - } - - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) - fi, err := fs.Stat(fileSystem, name) - if err != nil { - return ErrNotFound - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return fsFile(c, name, fileSystem) - } -} - -// FileFS registers a new route with path to serve file from the provided file system. -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return e.GET(path, StaticFileHandler(file, filesystem), m...) -} - -// StaticFileHandler creates handler function to serve file from provided file system -func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { - return func(c Context) error { - return fsFile(c, file, filesystem) - } -} - -// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` -// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` -// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break -// all old applications that rely on being able to traverse up from current executable run path. -// NB: private because you really should use fs.FS implementation instances -type defaultFS struct { - prefix string - fs fs.FS -} - -func newDefaultFS() *defaultFS { - dir, _ := os.Getwd() - return &defaultFS{ - prefix: dir, - fs: os.DirFS(dir), - } -} - -func (fs defaultFS) Open(name string) (fs.File, error) { - return fs.fs.Open(name) -} - -func subFS(currentFs fs.FS, root string) (fs.FS, error) { - root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows - if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to - // allow cases when root is given as `../somepath` which is not valid for fs.FS - root = filepath.Join(dFS.prefix, root) - return &defaultFS{ - prefix: root, - fs: os.DirFS(root), - }, nil - } - return fs.Sub(currentFs, root) -} - -// MustSubFS creates sub FS from current filesystem or panic on failure. -// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. -// -// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with -// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to -// create sub fs which uses necessary prefix for directory path. -func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { - subFs, err := subFS(currentFs, fsRoot) - if err != nil { - panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) - } - return subFs -} diff --git a/echo_fs_go1.16_test.go b/echo_fs_go1.16_test.go deleted file mode 100644 index 07e516555..000000000 --- a/echo_fs_go1.16_test.go +++ /dev/null @@ -1,265 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" -) - -func TestEcho_StaticFS(t *testing.T) { - var testCases = []struct { - name string - givenPrefix string - givenFs fs.FS - givenFsRoot string - whenURL string - expectStatus int - expectHeaderLocation string - expectBodyStartsWith string - }{ - { - name: "ok", - givenPrefix: "/images", - givenFs: os.DirFS("./_fixture/images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "ok, from sub fs", - givenPrefix: "/images", - givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "No file", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/scripts"), - whenURL: "/images/bolt.png", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/images"), - whenURL: "/images/", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory Redirect", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: "/folder", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory Redirect with non-root path", - givenPrefix: "/static", - givenFs: os.DirFS("_fixture"), - whenURL: "/static", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/static/", - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory 404 (request URL without slash)", - givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Prefixed directory redirect (without slash redirect to slash)", - givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending with slash)", - givenPrefix: "/assets/", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending without slash)", - givenPrefix: "/assets", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Sub-directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/folder/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "do not allow directory traversal (backslash - windows separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/..\\middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "do not allow directory traversal (slash - unix separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/../middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - tmpFs := tc.givenFs - if tc.givenFsRoot != "" { - tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) - } - e.StaticFS(tc.givenPrefix, tmpFs) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectStatus, rec.Code) - body := rec.Body.String() - if tc.expectBodyStartsWith != "" { - assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) - } else { - assert.Equal(t, "", body) - } - - if tc.expectHeaderLocation != "" { - assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) - } else { - _, ok := rec.Result().Header["Location"] - assert.False(t, ok) - } - }) - } -} - -func TestEcho_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestEcho_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - expectError string - }{ - { - name: "panics for ../", - givenRoot: "../assets", - expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", - }, - { - name: "panics for /", - givenRoot: "/assets", - expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - assert.PanicsWithError(t, tc.expectError, func() { - e.Static("../assets", tc.givenRoot) - }) - }) - } -} diff --git a/echo_test.go b/echo_test.go index f175d765b..3d2eecf74 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,31 +3,25 @@ package echo import ( "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" - "io/ioutil" + "io/fs" "net" "net/http" "net/http/httptest" "net/url" "os" - "reflect" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/http2" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` @@ -61,16 +55,17 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - // DefaultHTTPErrorHandler - e.DefaultHTTPErrorHandler(errors.New("error"), c) + e.HTTPErrorHandler(c, errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } -func TestEchoStatic(t *testing.T) { +func TestEcho_StaticFS(t *testing.T) { var testCases = []struct { name string givenPrefix string - givenRoot string + givenFs fs.FS + givenFsRoot string whenURL string expectStatus int expectHeaderLocation string @@ -79,7 +74,15 @@ func TestEchoStatic(t *testing.T) { { name: "ok", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("./_fixture/images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, from sub fs", + givenPrefix: "/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), @@ -87,7 +90,7 @@ func TestEchoStatic(t *testing.T) { { name: "No file", givenPrefix: "/images", - givenRoot: "_fixture/scripts", + givenFs: os.DirFS("_fixture/scripts"), whenURL: "/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -95,7 +98,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("_fixture/images"), whenURL: "/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -103,7 +106,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture/"), whenURL: "/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -112,7 +115,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect with non-root path", givenPrefix: "/static", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/static/", @@ -121,7 +124,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -129,7 +132,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -138,7 +141,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -146,7 +149,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -154,7 +157,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -162,7 +165,7 @@ func TestEchoStatic(t *testing.T) { { name: "Sub-directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -170,7 +173,7 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -178,7 +181,7 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -188,10 +191,18 @@ func TestEchoStatic(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - e.Static(tc.givenPrefix, tc.givenRoot) + + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { @@ -210,44 +221,127 @@ func TestEchoStatic(t *testing.T) { } } -func TestEchoStaticRedirectIndex(t *testing.T) { - e := New() +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } - // HandlerFunc - e.Static("/static", "_fixture") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - errCh := make(chan error) + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() - go func() { - errCh <- e.Start(":0") - }() + e.ServeHTTP(rec, req) - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) + assert.Equal(t, tc.expectCode, rec.Code) - addr := e.ListenerAddr().String() - if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, true, strings.HasPrefix(string(body), "")) - } else { - assert.Fail(t, err.Error()) - } +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", + }, + { + name: "panics for /", + givenRoot: "/assets", + expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", + }, + } - } else { - assert.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + assert.PanicsWithError(t, tc.expectError, func() { + e.Static("../assets", tc.givenRoot) + }) + }) } +} - if err := e.Close(); err != nil { - t.Fatal(err) +func TestEchoStaticRedirectIndex(t *testing.T) { + e := New() + + // HandlerFunc + ri := e.Static("/static", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/static*", ri.Path()) + assert.Equal(t, "GET:/static*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + addr, err := startOnRandomPort(ctx, e) + if err != nil { + assert.Fail(t, err.Error()) } + + code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(body, "")) + assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { e := New() - e.File("/walle", "_fixture/images/walle.png") + ri := e.File("/walle", "_fixture/images/walle.png") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/walle", ri.Path()) + assert.Equal(t, "GET:/walle", ri.Name()) + assert.Nil(t, ri.Params()) + c, b := request(http.MethodGet, "/walle", e) assert.Equal(t, http.StatusOK, c) assert.NotEmpty(t, b) @@ -259,7 +353,8 @@ func TestEchoMiddleware(t *testing.T) { e.Pre(func(next HandlerFunc) HandlerFunc { return func(c Context) error { - assert.Empty(t, c.Path()) + // before route match is found RouteInfo does not exist + assert.Equal(t, nil, c.RouteInfo()) buf.WriteString("-1") return next(c) } @@ -304,7 +399,7 @@ func TestEchoMiddlewareError(t *testing.T) { return errors.New("error") } }) - e.GET("/", NotFoundHandler) + e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -359,128 +454,202 @@ func TestEchoWrapMiddleware(t *testing.T) { } } +func TestEchoGet_routeInfoIsImmutable(t *testing.T) { + e := New() + ri := e.GET("/test", handlerFunc) + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err := e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) + + rInfo := ri.(routeInfo) + rInfo.name = "changed" // this change should not change other returned values + + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err = e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) +} + func TestEchoConnect(t *testing.T) { e := New() - testMethod(t, http.MethodConnect, "/", e) + + ri := e.CONNECT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodConnect+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() - testMethod(t, http.MethodDelete, "/", e) + + ri := e.DELETE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodDelete+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() - testMethod(t, http.MethodGet, "/", e) + + ri := e.GET("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodGet+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodGet, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() - testMethod(t, http.MethodHead, "/", e) + + ri := e.HEAD("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodHead+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() - testMethod(t, http.MethodOptions, "/", e) + + ri := e.OPTIONS("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodOptions+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() - testMethod(t, http.MethodPatch, "/", e) + + ri := e.PATCH("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPatch+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() - testMethod(t, http.MethodPost, "/", e) + + ri := e.POST("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPost+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() - testMethod(t, http.MethodPut, "/", e) + + ri := e.PUT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPut+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() - testMethod(t, http.MethodTrace, "/", e) + + ri := e.TRACE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodTrace+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoAny(t *testing.T) { // JFC e := New() - e.Any("/", func(c Context) error { + ris := e.Any("/", func(c Context) error { return c.String(http.StatusOK, "Any") }) + assert.Len(t, ris, 11) } func TestEchoMatch(t *testing.T) { // JFC e := New() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { + ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { return c.String(http.StatusOK, "Match") }) + assert.Len(t, ris, 2) } -func TestEchoURL(t *testing.T) { - e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getAny := func(Context) error { return nil } - getFile := func(Context) error { return nil } - - e.GET("/static/file", static) - e.GET("/users/:id", getUser) - e.GET("/documents/*", getAny) - g := e.Group("/group") - g.GET("/users/:uid/files/:fid", getFile) - - assert := assert.New(t) - - assert.Equal("/static/file", e.URL(static)) - assert.Equal("/users/:id", e.URL(getUser)) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) - assert.Equal("/documents/*", e.URL(getAny)) - assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) -} - -func TestEchoRoutes(t *testing.T) { - e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } - } -} - -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEcho_Routers_HandleHostsProperly(t *testing.T) { e := New() h := e.Host("route.com") routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + {Method: http.MethodGet, Path: "/users/:user/events"}, + {Method: http.MethodGet, Path: "/users/:user/events/public"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/refs"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/tags"}, } for _, r := range routes { h.Add(r.Method, r.Path, func(c Context) error { @@ -488,17 +657,22 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { }) } - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { + routers := e.Routers() + + routeCom, ok := routers["route.com"] + assert.True(t, ok) + + if assert.Equal(t, len(routes), len(routeCom.Routes())) { + for _, r := range routeCom.Routes() { found := false for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { + if r.Method() == rr.Method && r.Path() == rr.Path { found = true break } } if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + t.Errorf("Route %s %s not found", r.Method(), r.Path()) } } } @@ -510,7 +684,7 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { return c.String(http.StatusOK, "/with/slash") }) e.GET("/:id", func(c Context) error { - return c.String(http.StatusOK, c.Param("id")) + return c.String(http.StatusOK, c.PathParam("id")) }) var testCases = []struct { @@ -547,8 +721,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } func TestEchoHost(t *testing.T) { - assert := assert.New(t) - okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } @@ -643,8 +815,8 @@ func TestEchoHost(t *testing.T) { e.ServeHTTP(rec, req) - assert.Equal(tc.expectStatus, rec.Code) - assert.Equal(tc.expectBody, rec.Body.String()) + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } @@ -732,710 +904,157 @@ func TestEchoMethodNotAllowed(t *testing.T) { func TestEchoContext(t *testing.T) { e := New() c := e.AcquireContext() - assert.IsType(t, new(context), c) + assert.IsType(t, new(DefaultContext), c) e.ReleaseContext(c) } -func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(5 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - var addr net.Addr - if isTLS { - addr = e.TLSListenerAddr() - } else { - addr = e.ListenerAddr() - } - if addr != nil && strings.Contains(addr.String(), ":") { - return nil // was started - } - case err := <-errChan: - if err == http.ErrServerClosed { - return nil - } - return err - } - } -} - -func TestEchoStart(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - err := e.Start(":0") - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - assert.NoError(t, e.Close()) -} - -func TestEcho_StartTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - certFile string - keyFile string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid certFile", - addr: ":0", - certFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, invalid keyFile", - addr: ":0", - keyFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, failed to create cert out of certFile and keyFile", - addr: ":0", - keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key - expectError: "tls: found a certificate rather than a key in the PEM for the private key", - }, - { - name: "nok, invalid tls address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - certFile := "_fixture/certs/cert.pem" - if tc.certFile != "" { - certFile = tc.certFile - } - keyFile := "_fixture/certs/key.pem" - if tc.keyFile != "" { - keyFile = tc.keyFile - } - - err := e.StartTLS(tc.addr, certFile, keyFile) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - if _, ok := err.(*os.PathError); ok { - assert.Error(t, err) // error messages for unix and windows are different. so test only error type here - } else { - assert.EqualError(t, err, tc.expectError) - } - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEchoStartTLSAndStart(t *testing.T) { - // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server +func TestEcho_Start(t *testing.T) { e := New() e.GET("/", func(c Context) error { - return c.String(http.StatusOK, "OK") + return c.String(http.StatusTeapot, "OK") }) - - errTLSChan := make(chan error) - go func() { - certFile := "_fixture/certs/cert.pem" - keyFile := "_fixture/certs/key.pem" - err := e.StartTLS("localhost:", certFile, keyFile) - if err != nil { - errTLSChan <- err - } - }() - - err := waitForServerStart(e, errTLSChan, true) - assert.NoError(t, err) - defer func() { - if err := e.Shutdown(stdContext.Background()); err != nil { - t.Error(err) - } - }() - - // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }} - res, err := client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - errChan := make(chan error) + rndPort, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer rndPort.Close() + errChan := make(chan error, 1) go func() { - err := e.Start("localhost:") - if err != nil { - errChan <- err - } + errChan <- e.Start(rndPort.Addr().String()) }() - err = waitForServerStart(e, errChan, false) - assert.NoError(t, err) - // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS - res, err = http.Get("http://" + e.ListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - // see if HTTPS works after HTTP listener is also added - res, err = client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + select { + case <-time.After(250 * time.Millisecond): + t.Fatal("start did not error out") + case err := <-errChan: + assert.Contains(t, err.Error(), "bind: address already in use") + } } -func TestEchoStartTLSByteString(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) +func request(method, path string, e *Echo) (int, string) { + req := httptest.NewRequest(method, path, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec.Code, rec.Body.String() +} - testCases := []struct { - cert interface{} - key interface{} - expectedErr error - name string +func TestDefaultHTTPErrorHandler(t *testing.T) { + var testCases = []struct { + name string + givenExposeError bool + givenLoggerFunc bool + whenMethod string + whenError error + expectBody string + expectStatus int + expectLogged string }{ { - cert: "_fixture/certs/cert.pem", - key: "_fixture/certs/key.pem", - expectedErr: nil, - name: `ValidCertAndKeyFilePath`, + name: "ok, expose error = true, HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error","message":"my_error"}` + "\n", }, { - cert: cert, - key: key, - expectedErr: nil, - name: `ValidCertAndKeyByteString`, + name: "ok, expose error = true, HTTPError + internal error", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(errors.New("internal_error")), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error, internal=internal_error","message":"my_error"}` + "\n", }, { - cert: cert, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidKeyType`, + name: "ok, expose error = true, HTTPError + internal HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"error":"code=418, message=my_error, internal=code=425, message=early_error","message":"early_error"}` + "\n", }, { - cert: 0, - key: key, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertType`, + name: "ok, expose error = false, HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", }, { - cert: 0, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertAndKeyTypes`, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.name, func(t *testing.T) { - e := New() - e.HideBanner = true - - errChan := make(chan error, 0) - - go func() { - errChan <- e.StartTLS(":0", test.cert, test.key) - }() - - err := waitForServerStart(e, errChan, true) - if test.expectedErr != nil { - assert.EqualError(t, err, test.expectedErr.Error()) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartAutoTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", + name: "ok, expose error = false, HTTPError + internal HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"message":"early_error"}` + "\n", }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "ok, expose error = true, Error", + givenExposeError: true, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error, 0) - - go func() { - errChan <- e.StartAutoTLS(tc.addr) - }() - - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartH2CServer(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ { - name: "ok", - addr: ":0", + name: "ok, expose error = false, Error", + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"message":"Internal Server Error"}` + "\n", }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "ok, http.HEAD, expose error = true, Error", + givenExposeError: true, + whenMethod: http.MethodHead, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: ``, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) e := New() - e.Debug = true - h2s := &http2.Server{} - - errChan := make(chan error) - go func() { - err := e.StartH2CServer(tc.addr, h2s) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func testMethod(t *testing.T, method, path string, e *Echo) { - p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { - return c.String(http.StatusOK, method) - }) - i := interface{}(e) - reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) - _, body := request(method, path, e) - assert.Equal(t, method, body) -} - -func request(method, path string, e *Echo) (int, string) { - req := httptest.NewRequest(method, path, nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - return rec.Code, rec.Body.String() -} - -func TestHTTPError(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Equal(t, "code=400, message=map[code:12]", err.Error()) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) - }) -} - -func TestHTTPError_Unwrap(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Nil(t, errors.Unwrap(err)) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) -} - -func TestDefaultHTTPErrorHandler(t *testing.T) { - e := New() - e.Debug = true - e.Any("/plain", func(c Context) error { - return errors.New("An error occurred") - }) - e.Any("/badrequest", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, "Invalid request") - }) - e.Any("/servererror", func(c Context) error { - return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ - "code": 33, - "message": "Something bad happened", - "error": "stackinfo", - }) - }) - e.Any("/early-return", func(c Context) error { - c.String(http.StatusOK, "OK") - return errors.New("ERROR") - }) - e.GET("/internal-error", func(c Context) error { - err := errors.New("internal error message body") - return NewHTTPError(http.StatusBadRequest).SetInternal(err) - }) - - // With Debug=true plain response contains error message - c, b := request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) - // and special handling for HTTPError - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) - // complex errors are serialized to pretty JSON - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) - // if the body is already set HTTPErrorHandler should not add anything to response body - c, b = request(http.MethodGet, "/early-return", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, "OK", b) - // internal error should be reflected in the message - c, b = request(http.MethodGet, "/internal-error", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) - - e.Debug = false - // With Debug=false the error response is shortened - c, b = request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) - // No difference for error response with non plain string errors - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) -} - -func TestEchoClose(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - assert.NoError(t, e.Close()) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -func TestEchoShutdown(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) - defer cancel() - assert.NoError(t, e.Shutdown(ctx)) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -var listenerNetworkTests = []struct { - test string - network string - address string -}{ - {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, - {"tcp ipv6 address", "tcp", "[::1]:1323"}, - {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, - {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, -} - -func supportsIPv6() bool { - addrs, _ := net.InterfaceAddrs() - for _, addr := range addrs { - // Check if any interface has local IPv6 assigned - if strings.Contains(addr.String(), "::1") { - return true - } - } - return false -} - -func TestEchoListenerNetwork(t *testing.T) { - hasIPv6 := supportsIPv6() - for _, tt := range listenerNetworkTests { - if !hasIPv6 && strings.Contains(tt.address, "::") { - t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") - continue - } - t.Run(tt.test, func(t *testing.T) { - e := New() - e.ListenerNetwork = tt.network - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") + e.Logger = &jsonLogger{writer: buf} + e.Any("/path", func(c Context) error { + return tc.whenError }) - errCh := make(chan error) - - go func() { - errCh <- e.Start(tt.address) - }() + e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, "OK", string(body)) - } else { - assert.Fail(t, err.Error()) - } - - } else { - assert.Fail(t, err.Error()) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod } + c, b := request(method, "/path", e) - if err := e.Close(); err != nil { - t.Fatal(err) - } + assert.Equal(t, tc.expectStatus, c) + assert.Equal(t, tc.expectBody, b) + assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func TestEchoListenerNetworkInvalid(t *testing.T) { - e := New() - e.ListenerNetwork = "unix" - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) -} - -func TestEchoReverse(t *testing.T) { - assert := assert.New(t) - - e := New() - dummyHandler := func(Context) error { return nil } - - e.GET("/static", dummyHandler).Name = "/static" - e.GET("/static/*", dummyHandler).Name = "/static/*" - e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" - e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" - e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) - - assert.Equal("/params/:foo", e.Reverse("/params/:foo")) - assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) - assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) - assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) - assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) - assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +type myCustomContext struct { + DefaultContext } -func TestEchoReverseHandleHostProperly(t *testing.T) { - assert := assert.New(t) - - dummyHandler := func(Context) error { return nil } - - e := New() - h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) -} - -func TestEcho_ListenerAddr(t *testing.T) { - e := New() - - addr := e.ListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) +func (c *myCustomContext) QueryParam(name string) string { + return "prefix_" + c.DefaultContext.QueryParam(name) } -func TestEcho_TLSListenerAddr(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - +func TestEcho_customContext(t *testing.T) { e := New() - - addr := e.TLSListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.StartTLS(":0", cert, key) - }() - - err = waitForServerStart(e, errCh, true) - assert.NoError(t, err) -} - -func TestEcho_StartServer(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - certs, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) - - var testCases = []struct { - name string - addr string - TLSConfig *tls.Config - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "ok, start with TLS", - addr: ":0", - TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - { - name: "nok, invalid tls address", - addr: "nope", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - expectError: "listen tcp: address nope: missing port in address", - }, + e.NewContextFunc = func(ec *Echo, pathParamAllocSize int) ServableContext { + return &myCustomContext{ + DefaultContext: *NewDefaultContext(ec, pathParamAllocSize), + } } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Debug = true - - server := new(http.Server) - server.Addr = tc.addr - if tc.TLSConfig != nil { - server.TLSConfig = tc.TLSConfig - } - - errCh := make(chan error) - go func() { - errCh <- e.StartServer(server) - }() + e.GET("/info/:id/:file", func(c Context) error { + return c.String(http.StatusTeapot, c.QueryParam("param")) + }) - err := waitForServerStart(e, errCh, tc.TLSConfig != nil) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - assert.NoError(t, e.Close()) - }) - } + status, body := request(http.MethodGet, "/info/1/a.csv?param=123", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "prefix_123", body) } -func benchmarkEchoRoutes(b *testing.B, routes []*Route) { +func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest("GET", "/", nil) u := req.URL diff --git a/go.mod b/go.mod index 4de2bdde1..339d24ef2 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,19 @@ -module github.com/labstack/echo/v4 +module github.com/labstack/echo/v5 go 1.17 require ( - github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.3.1 + github.com/golang-jwt/jwt/v4 v4.2.0 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 + golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.11 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/text v0.3.3 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index f66734243..3290b99b8 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= -github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= -github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= +github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -18,25 +14,15 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/group.go b/group.go index bba470ce8..4f04a73a2 100644 --- a/group.go +++ b/group.go @@ -1,98 +1,121 @@ package echo import ( + "io/fs" "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + host string + prefix string + middleware []MiddlewareFunc + echo *Echo +} // Use implements `Echo#Use()` for sub-routes within the Group. +// Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) - if len(g.middleware) == 0 { - return - } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris } -// Match implements `Echo#Match()` for sub-routes within the Group. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. +// Important! Group middlewares are only executed in case there was exact route match and not +// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add +// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) @@ -102,18 +125,57 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { return } -// File implements `Echo#File()` for sub-routes within the Group. -func (g *Group) File(path, file string) { - g.file(path, file, g.GET) +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(pathPrefix, fsRoot string) RouteInfo { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + return g.StaticFS(pathPrefix, subFs) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo { + return g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return g.GET(path, StaticFileHandler(file, filesystem), m...) } -// Add implements `Echo#Add()` for sub-routes within the Group. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - // Combine into a new slice to avoid accidentally passing the same slice for +// File implements `Echo#File()` for sub-routes within the Group. Panics on error. +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { + return c.File(file) + } + return g.Add(http.MethodGet, path, handler, middleware...) +} + +// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := g.AddRoute(Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri +} + +// AddRoute registers a new Routable with Router +func (g *Group) AddRoute(route Routable) (RouteInfo, error) { + // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) - m = append(m, g.middleware...) - m = append(m, middleware...) - return g.echo.add(g.host, method, g.prefix+path, handler, m...) + groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) + return g.echo.add(g.host, groupRoute) } diff --git a/group_fs.go b/group_fs.go deleted file mode 100644 index 0a1ce4a94..000000000 --- a/group_fs.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) -} diff --git a/group_fs_go1.16.go b/group_fs_go1.16.go deleted file mode 100644 index 2ba52b5e2..000000000 --- a/group_fs_go1.16.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "io/fs" - "net/http" -) - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, fsRoot string) { - subFs := MustSubFS(g.echo.Filesystem, fsRoot) - g.StaticFS(pathPrefix, subFs) -} - -// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { - g.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// FileFS implements `Echo#FileFS()` for sub-routes within the Group. -func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return g.GET(path, StaticFileHandler(file, filesystem), m...) -} diff --git a/group_fs_go1.16_test.go b/group_fs_go1.16_test.go deleted file mode 100644 index d0caa33db..000000000 --- a/group_fs_go1.16_test.go +++ /dev/null @@ -1,106 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestGroup_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/assets/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - g := e.Group("/assets") - g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestGroup_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - expectError string - }{ - { - name: "panics for ../", - givenRoot: "../images", - expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", - }, - { - name: "panics for /", - givenRoot: "/images", - expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - g := e.Group("/assets") - - assert.PanicsWithError(t, tc.expectError, func() { - g.Static("/images", tc.givenRoot) - }) - }) - } -} diff --git a/group_test.go b/group_test.go index c51fd91eb..3914c0bd8 100644 --- a/group_test.go +++ b/group_test.go @@ -1,31 +1,70 @@ package echo import ( + "github.com/stretchr/testify/assert" + "io/fs" "io/ioutil" "net/http" "net/http/httptest" + "os" + "strings" "testing" - - "github.com/stretchr/testify/assert" ) -// TODO: Fix me -func TestGroup(t *testing.T) { - g := New().Group("/group") - h := func(Context) error { return nil } - g.CONNECT("/", h) - g.DELETE("/", h) - g.GET("/", h) - g.HEAD("/", h) - g.OPTIONS("/", h) - g.PATCH("/", h) - g.POST("/", h) - g.PUT("/", h) - g.TRACE("/", h) - g.Any("/", h) - g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) - g.Static("/static", "/tmp") - g.File("/walle", "_fixture/images//walle.png") +func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware it will not be executed when there are no routes under that group + _ = e.Group("/group", mw) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware and routes when we have no match on some route the middlewares for that + // group will not be executed + g := e.Group("/group", mw) + g.GET("/yes", handlerFunc) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_multiLevelGroup(t *testing.T) { + e := New() + + api := e.Group("/api") + users := api.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + status, body := request(http.MethodGet, "/api/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) } func TestGroupFile(t *testing.T) { @@ -92,11 +131,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { } m2 := func(next HandlerFunc) HandlerFunc { return func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } } h := func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } g.Use(m1) g.GET("/help", h, m2) @@ -119,3 +158,535 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_CONNECT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.CONNECT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_DELETE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.DELETE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_HEAD(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.HEAD("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodHead+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_OPTIONS(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.OPTIONS("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PATCH(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PATCH("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_POST(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.POST("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPost+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PUT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PUT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPut+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_TRACE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.TRACE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_Any(t *testing.T) { + e := New() + + users := e.Group("/users") + ris := users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 11) + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_AnyWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Match(t *testing.T) { + e := New() + + myMethods := []string{http.MethodGet, http.MethodPost} + users := e.Group("/users") + ris := users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 2) + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_MatchWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + myMethods := []string{http.MethodGet, http.MethodPost} + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Static(t *testing.T) { + e := New() + + g := e.Group("/books") + ri := g.Static("/download", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/books/download*", ri.Path()) + assert.Equal(t, "GET:/books/download*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.True(t, strings.HasPrefix(body, "")) +} + +func TestGroup_StaticMultiTest(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "nok, without prefix does not serve dir index", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/test/", // `/test` + `*` creates route `/test*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/test/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/test/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + g := e.Group("/test") + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../images", + expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", + }, + { + name: "panics for /", + givenRoot: "/images", + expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.PanicsWithError(t, tc.expectError, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} diff --git a/httperror.go b/httperror.go new file mode 100644 index 000000000..5c217dac1 --- /dev/null +++ b/httperror.go @@ -0,0 +1,74 @@ +package echo + +import ( + "errors" + "fmt" + "net/http" +) + +// Errors +var ( + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) + ErrNotFound = NewHTTPError(http.StatusNotFound) + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) + ErrForbidden = NewHTTPError(http.StatusForbidden) + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) + ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) + ErrBadRequest = NewHTTPError(http.StatusBadRequest) + ErrBadGateway = NewHTTPError(http.StatusBadGateway) + ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) + ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) + ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Code int `json:"-"` + Message interface{} `json:"message"` + Internal error `json:"-"` // Stores the error returned by an external dependency +} + +// NewHTTPError creates a new HTTPError instance. +func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used? + he := &HTTPError{Code: code, Message: http.StatusText(code)} + if len(message) > 0 { + he.Message = message[0] + } + return he +} + +// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set. +func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError { + he := NewHTTPError(code, message...) + he.Internal = internalError + return he +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } + return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) +} + +// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field +func (he *HTTPError) WithInternal(err error) *HTTPError { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + Internal: err, + } +} + +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal +} diff --git a/httperror_test.go b/httperror_test.go new file mode 100644 index 000000000..f9d340f11 --- /dev/null +++ b/httperror_test.go @@ -0,0 +1,52 @@ +package echo + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestHTTPError(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) +} + +func TestNewHTTPErrorWithInternal(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message") + assert.Equal(t, "code=400, message=test message, internal=test", he.Error()) +} + +func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test")) + assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error()) +} + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} diff --git a/json.go b/json.go index 16b2d0577..16074fa24 100644 --- a/json.go +++ b/json.go @@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { err := json.NewDecoder(c.Request().Body).Decode(i) if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) + return NewHTTPErrorWithInternal( + http.StatusBadRequest, + err, + fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset), + ) } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, + err, + fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()), + ) } return err } diff --git a/json_test.go b/json_test.go index 27ee43e73..ac64d2894 100644 --- a/json_test.go +++ b/json_test.go @@ -14,7 +14,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) assert := testify.New(t) @@ -40,7 +40,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Serialize(c, user{1, "Jon Snow"}, " ") if assert.NoError(err) { assert.Equal(userJSONPretty+"\n", rec.Body.String()) @@ -53,7 +53,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) assert := testify.New(t) @@ -81,7 +81,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var userUnmarshalSyntaxError = user{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalSyntaxError) assert.IsType(&HTTPError{}, err) assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") @@ -93,7 +93,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalTypeError) assert.IsType(&HTTPError{}, err) assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") diff --git a/log.go b/log.go index 3f8de5904..ce351881c 100644 --- a/log.go +++ b/log.go @@ -1,41 +1,148 @@ package echo import ( + "bytes" "io" - - "github.com/labstack/gommon/log" + "strconv" + "sync" + "time" ) -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) +//----------------------------------------------------------------------------- +// Example for Zap (https://github.com/uber-go/zap) +//func main() { +// e := echo.New() +// logger, _ := zap.NewProduction() +// e.Logger = &ZapLogger{logger: logger} +//} +//type ZapLogger struct { +// logger *zap.Logger +//} +// +//func (l *ZapLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZapLogger) Error(err error) { +// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo")) +//} + +//----------------------------------------------------------------------------- +// Example for Zerolog (https://github.com/rs/zerolog) +//func main() { +// e := echo.New() +// logger := zerolog.New(os.Stdout) +// e.Logger = &ZeroLogger{logger: &logger} +//} +// +//type ZeroLogger struct { +// logger *zerolog.Logger +//} +// +//func (l *ZeroLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZeroLogger) Error(err error) { +// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error()) +//} + +//----------------------------------------------------------------------------- +// Example for Logrus (https://github.com/sirupsen/logrus) +//func main() { +// e := echo.New() +// e.Logger = &LogrusLogger{logger: logrus.New()} +//} +// +//type LogrusLogger struct { +// logger *logrus.Logger +//} +// +//func (l *LogrusLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *LogrusLogger) Error(err error) { +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err) +//} + +// Logger defines the logging interface that Echo uses internally in few places. +// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice. +type Logger interface { + // Write provides writer interface for http.Server `ErrorLog` and for logging startup messages. + // `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers, + // and underlying FileSystem errors. + // `logger` middleware will use this method to write its JSON payload. + Write(p []byte) (n int, err error) + // Error logs the error + Error(err error) +} + +// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only +// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting. +// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases. +type jsonLogger struct { + writer io.Writer + bufferPool sync.Pool + lock sync.Mutex + + timeNow func() time.Time +} + +func newJSONLogger(writer io.Writer) *jsonLogger { + return &jsonLogger{ + writer: writer, + bufferPool: sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 256)) + }, + }, + timeNow: time.Now, } -) +} + +func (l *jsonLogger) Write(p []byte) (n int, err error) { + pLen := len(p) + if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message + (p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') || + (p[0] == '{' && p[pLen-1] == '}') { + return l.write(p) + } + // we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is + // called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably + // deserves at least WARN level. + return l.printf("INFO", string(p)) +} + +func (l *jsonLogger) Error(err error) { + _, _ = l.printf("ERROR", err.Error()) +} + +func (l *jsonLogger) printf(level string, message string) (n int, err error) { + buf := l.bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer l.bufferPool.Put(buf) + + buf.WriteString(`{"time":"`) + buf.WriteString(l.timeNow().Format(time.RFC3339Nano)) + buf.WriteString(`","level":"`) + buf.WriteString(level) + buf.WriteString(`","prefix":"echo","message":`) + + buf.WriteString(strconv.Quote(message)) + buf.WriteString("}\n") + + return l.write(buf.Bytes()) +} + +func (l *jsonLogger) write(p []byte) (int, error) { + l.lock.Lock() + defer l.lock.Unlock() + return l.writer.Write(p) +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 000000000..c7b4674e9 --- /dev/null +++ b/log_test.go @@ -0,0 +1,87 @@ +package echo + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +type noOpLogger struct { +} + +func (l *noOpLogger) Write(p []byte) (n int, err error) { + return 0, err +} + +func (l *noOpLogger) Error(err error) { +} + +func TestJsonLogger_Write(t *testing.T) { + var testCases = []struct { + name string + when []byte + expect string + }{ + { + name: "ok, write non JSONlike message", + when: []byte("version: %v, build: %v"), + expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: %v, build: %v"}` + "\n", + }, + { + name: "ok, write quoted message", + when: []byte(`version: "%v"`), + expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: \"%v\""}` + "\n", + }, + { + name: "ok, write JSON", + when: []byte(`{"version": 123}` + "\n"), + expect: `{"version": 123}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0).UTC() + } + + _, err := logger.Write(tc.when) + + result := buf.String() + assert.Equal(t, tc.expect, result) + assert.NoError(t, err) + }) + } +} + +func TestJsonLogger_Error(t *testing.T) { + var testCases = []struct { + name string + whenError error + expect string + }{ + { + name: "ok", + whenError: ErrForbidden, + expect: `{"time":"2021-09-07T20:09:37Z","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0).UTC() + } + + logger.Error(tc.whenError) + + result := buf.String() + assert.Equal(t, tc.expect, result) + }) + } +} diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md new file mode 100644 index 000000000..77cb226dd --- /dev/null +++ b/middleware/DEVELOPMENT.md @@ -0,0 +1,11 @@ +# Development Guidelines for middlewares + +## Best practices: + +* Do not use `panic` in middleware creator functions in case of invalid configuration. +* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead + because previous middlewares up in call chain could have logic for dealing with returned errors. +* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they + want to create middleware with panics or with returning errors on configuration errors. +* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. + diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8cf1ed9fc..82e2fbf7a 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,64 +1,59 @@ package middleware import ( + "bytes" "encoding/base64" + "errors" + "fmt" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers + // this function would be called once for each header until first valid result is returned + // Required. + Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } + // Realm is a string to define realm attribute of BasicAuthWithConfig. + // Default value "Restricted". + Realm string +} - // BasicAuthValidator defines a function to validate BasicAuth credentials. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) +// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. +type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) - // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } if config.Realm == "" { config.Realm = defaultRealm @@ -70,29 +65,33 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return err + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = fmt.Errorf("invalid basic auth value: %w", errDecode) + continue } - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } } + if lastError != nil { + return lastError + } + realm := defaultRealm if config.Realm != defaultRealm { realm = strconv.Quote(config.Realm) @@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0a..9580dff0b 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -2,70 +2,157 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { + validatorFunc := func(c echo.Context, u, p string) (bool, error) { if u == "joe" && p == "secret" { return true, nil } + if u == "error" { + return false, errors.New(p) + } return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + defaultConfig := BasicAuthConfig{Validator: validatorFunc} - assert := assert.New(t) + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, multiple", + givenConfig: defaultConfig, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " NOT_BASE64", + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "invalid basic auth value: illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + config := tc.givenConfig - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + mw, err := config.ToMiddleware() + assert.NoError(t, err) - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + h := mw(func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err = h(c) + + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) +func TestBasicAuth_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuth(nil) + assert.NotNil(t, mw) + }) + + mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }) + assert.NotNil(t, mw) +} + +func TestBasicAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) + assert.NotNil(t, mw) + }) - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }}) + assert.NotNil(t, mw) } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index ebd0d0ab2..390c37d64 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,71 +3,65 @@ package middleware import ( "bufio" "bytes" + "errors" "io" "io/ioutil" "net" "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } - - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) + // Handler receives request and response payload. + // Required. + Handler BodyDumpHandler +} - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte) -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - c := DefaultBodyDumpConfig - c.Handler = handler - return BodyDumpWithConfig(c) + return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration +func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { - panic("echo: body-dump middleware requires a handler function") + return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } // Request reqBody := []byte{} - if c.Request().Body != nil { // Read + if c.Request().Body != nil { reqBody, _ = ioutil.ReadAll(c.Request().Body) } c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset @@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} c.Response().Writer = writer - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) // Callback config.Handler(c, reqBody, resBody.Bytes()) - return + return err } - } + }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e6e00f726..323f46c15 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) { requestBody = string(reqBody) responseBody = string(resBody) - }) - - assert := assert.New(t) + }}.ToMiddleware() + assert.NoError(t, err) - if assert.NoError(mw(h)(c)) { - assert.Equal(requestBody, hw) - assert.Equal(responseBody, hw) - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.String()) + if assert.NoError(t, mw(h)(c)) { + assert.Equal(t, requestBody, hw) + assert.Equal(t, responseBody, hw) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.String()) } - // Must set default skipper - BodyDumpWithConfig(BodyDumpConfig{ - Skipper: nil, +} + +func TestBodyDump_skipper(t *testing.T) { + e := echo.New() + + isCalled := false + mw, err := BodyDumpConfig{ + Skipper: func(c echo.Context) bool { + return true + }, Handler: func(c echo.Context, reqBody, resBody []byte) { - requestBody = string(reqBody) - responseBody = string(resBody) + isCalled = true }, - }) + }.ToMiddleware() + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + return errors.New("some error") + } + + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.False(t, isCalled) } -func TestBodyDumpFails(t *testing.T) { +func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) @@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) { return errors.New("some error") } - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware() + assert.NoError(t, err) - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ + mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) + assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ - Skipper: func(c echo.Context) bool { - return true - }, - Handler: func(c echo.Context, reqBody, resBody []byte) { - }, - }) + mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}) + assert.NotNil(t, mw) + }) +} - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } +func TestBodyDump_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BodyDump(nil) + assert.NotNil(t, mw) + }) + + assert.NotPanics(t, func() { + BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) }) } diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd595..6b32f9d43 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,98 +1,83 @@ package middleware import ( - "fmt" "io" "sync" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/bytes" + "github.com/labstack/echo/v5" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } - - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context - } -) + // LimitBytes is maximum allowed size in bytes for a request body + LimitBytes int64 +} -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 + context echo.Context +} // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the -// size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it +// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, -// G, T or P. -func BodyLimit(limit string) echo.MiddlewareFunc { - c := DefaultBodyLimitConfig - c.Limit = limit - return BodyLimitWithConfig(c) +func BodyLimit(limitBytes int64) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } -// BodyLimitWithConfig returns a BodyLimit middleware with config. -// See: `BodyLimit()`. +// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for +// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. +// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which +// makes it super secure. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration +func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyLimitConfig.Skipper + config.Skipper = DefaultSkipper } - - limit, err := bytes.Parse(config.Limit) - if err != nil { - panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) + pool := sync.Pool{ + New: func() interface{} { + return &limitedReader{BodyLimitConfig: config} + }, } - config.limit = limit - pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() // Based on content length - if req.ContentLength > config.limit { + if req.ContentLength > config.LimitBytes { return echo.ErrStatusRequestEntityTooLarge } // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r.Reset(c, req.Body) defer pool.Put(r) req.Body = r return next(c) } - } + }, nil } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.limit { + if r.read > r.LimitBytes { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -102,16 +87,8 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) { r.reader = reader r.context = context r.read = 0 } - -func limitedReaderPool(c BodyLimitConfig) sync.Pool { - return sync.Pool{ - New: func() interface{} { - return &limitedReader{BodyLimitConfig: c} - }, - } -} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 0e8642a06..f367f9382 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -7,11 +7,11 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestBodyLimit(t *testing.T) { +func TestBodyLimitConfig_ToMiddleware(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) @@ -25,33 +25,42 @@ func TestBodyLimit(t *testing.T) { return c.String(http.StatusOK, string(body)) } - assert := assert.New(t) - // Based on content length (within limit) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.Bytes()) + mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } // Based on content read (overlimit) - he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he := mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, World!", rec.Body.String()) - } + + mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he = mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) } func TestBodyLimitReader(t *testing.T) { @@ -61,9 +70,8 @@ func TestBodyLimitReader(t *testing.T) { rec := httptest.NewRecorder() config := BodyLimitConfig{ - Skipper: DefaultSkipper, - Limit: "2B", - limit: 2, + Skipper: DefaultSkipper, + LimitBytes: 2, } reader := &limitedReader{ BodyLimitConfig: config, @@ -78,8 +86,80 @@ func TestBodyLimitReader(t *testing.T) { // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) } + +func TestBodyLimit_skipper(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw, err := BodyLimitConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + LimitBytes: 2, + }.ToMiddleware() + assert.NoError(t, err) + + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimitWithConfig(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimit(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimit(2 * MB) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} diff --git a/middleware/compress.go b/middleware/compress.go index ac6672e9d..d383cac63 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "compress/gzip" + "errors" "io" "io/ioutil" "net" @@ -10,54 +11,49 @@ import ( "strings" "sync" - "github.com/labstack/echo/v4" -) - -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - } - - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - wroteBody bool - } + "github.com/labstack/echo/v5" ) const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - } -) +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -// Gzip returns a middleware which compresses HTTP response using gzip compression -// scheme. + // Gzip compression level. + // Optional. Default value -1. + Level int +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter + wroteBody bool +} + +// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(DefaultGzipConfig) + return GzipWithConfig(GzipConfig{}) } -// GzipWithConfig return Gzip middleware with config. -// See: `Gzip()`. +// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration +func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression + return nil, errors.New("invalid gzip level") } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = -1 } pool := gzipCompressPool(config) @@ -98,7 +94,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index b62bffef5..d6b4f60ed 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,94 +3,128 @@ package middleware import ( "bytes" "compress/gzip" - "io" "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" + "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestGzip(t *testing.T) { +func TestGzip_NoAcceptEncodingHeader(t *testing.T) { + // Skip if no Accept-Encoding header + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - // Skip if no Accept-Encoding header + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) +} + +func TestMustGzipWithConfig_panics(t *testing.T) { + assert.Panics(t, func() { + GzipWithConfig(GzipConfig{Level: 999}) + }) +} + +func TestGzip_AcceptEncodingHeader(t *testing.T) { h := Gzip()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - assert.Equal("test", rec.Body.String()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - // Gzip - req = httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal("test", buf.String()) - } + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - chunkBuf := make([]byte, 5) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) +} - // Gzip chunked - req = httptest.NewRequest(http.MethodGet, "/", nil) +func TestGzip_chunked(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - c = e.NewContext(req, rec) - Gzip()(func(c echo.Context) error { + chunkChan := make(chan struct{}) + waitChan := make(chan struct{}) + h := Gzip()(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("first\n")) c.Response().Flush() - // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - r.Reset(rec.Body) - - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write and flush the second part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("second\n")) c.Response().Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write the final part of the data and return - c.Response().Write([]byte("test")) + c.Response().Write([]byte("third")) + + chunkChan <- struct{}{} return nil - })(c) + }) + + go func() { + err := h(c) + chunkChan <- struct{}{} + assert.NoError(t, err) + }() + + <-chunkChan // wait for first write + waitChan <- struct{}{} + + <-chunkChan // wait for second write + waitChan <- struct{}{} + + <-chunkChan // wait for final write in handler + <-chunkChan // wait for return from handler + time.Sleep(5 * time.Millisecond) // to have time for flushing + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) buf := new(bytes.Buffer) - defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "first\nsecond\nthird", buf.String()) } -func TestGzipNoContent(t *testing.T) { +func TestGzip_NoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) { } } -func TestGzipEmpty(t *testing.T) { +func TestGzip_Empty(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -127,7 +161,7 @@ func TestGzipEmpty(t *testing.T) { } } -func TestGzipErrorReturned(t *testing.T) { +func TestGzip_ErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) e.GET("/", func(c echo.Context) error { @@ -141,31 +175,25 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } -func TestGzipErrorReturnedInvalidConfig(t *testing.T) { - e := echo.New() - // Invalid level - e.Use(GzipWithConfig(GzipConfig{Level: 12})) - e.GET("/", func(c echo.Context) error { - c.Response().Write([]byte("test")) - return nil - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), "gzip") +func TestGzipWithConfig_invalidLevel(t *testing.T) { + mw, err := GzipConfig{Level: 12}.ToMiddleware() + assert.EqualError(t, err, "invalid gzip level") + assert.Nil(t, mw) } // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() + e.Filesystem = os.DirFS("../") + e.Use(Gzip()) - e.Static("/test", "../_fixture/images") + e.Static("/test", "_fixture/images") req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) // Data is written out in chunks when Content-Length == "", so only // validate the content length if it's not set. diff --git a/middleware/cors.go b/middleware/cors.go index 16259512a..78b44975d 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -6,67 +6,63 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. - AllowOrigins []string `yaml:"allow_origins"` - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` - - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. - // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value - // from `Allow` header that echo.Router set into context. - AllowMethods []string `yaml:"allow_methods"` - - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. - // Optional. Default value []string{}. - AllowHeaders []string `yaml:"allow_headers"` - - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. - // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - AllowCredentials bool `yaml:"allow_credentials"` - - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. - ExposeHeaders []string `yaml:"expose_headers"` - - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. - MaxAge int `yaml:"max_age"` - } -) +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // AllowOrigin defines a list of origins that may access the resource. + // Optional. Default value []string{"*"}. + AllowOrigins []string + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) + + // AllowMethods defines a list methods allowed when accessing the resource. + // This is used in response to a preflight request. + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + AllowMethods []string + + // AllowHeaders defines a list of request headers that can be used when + // making the actual request. This is in response to a preflight request. + // Optional. Default value []string{}. + AllowHeaders []string + + // AllowCredentials indicates whether or not the response to the request + // can be exposed when the credentials flag is true. When used as part of + // a response to a preflight request, this indicates whether or not the + // actual request can be made using credentials. + // Optional. Default value false. + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + AllowCredentials bool + + // ExposeHeaders defines a whitelist headers that clients are allowed to + // access. + // Optional. Default value []string{}. + ExposeHeaders []string + + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached. + // Optional. Default value 0. + MaxAge int +} -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS @@ -74,9 +70,14 @@ func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } -// CORSWithConfig returns a CORS middleware with config. +// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. // See: `CORS()`. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration +func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCORSConfig.Skipper @@ -172,7 +173,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { checkPatterns := false if allowOrigin == "" { // to avoid regex cost by invalid (long) domains (253 is domain name max limit) - if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { + if len(origin) <= (5+3+253) && strings.Contains(origin, "://") { checkPatterns = true } } @@ -230,5 +231,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } return c.NoContent(http.StatusNoContent) } - } + }, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index daadbab6e..2299a885d 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -17,7 +17,7 @@ func TestCORS(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := CORS()(echo.NotFoundHandler) + h := CORS()(func(c echo.Context) error { return echo.ErrNotFound }) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -26,7 +26,7 @@ func TestCORS(t *testing.T) { req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - h = CORS()(echo.NotFoundHandler) + h = CORS()(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) @@ -38,7 +38,7 @@ func TestCORS(t *testing.T) { AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: 3600, - })(echo.NotFoundHandler) + })(func(c echo.Context) error { return echo.ErrNotFound }) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -55,7 +55,7 @@ func TestCORS(t *testing.T) { AllowCredentials: true, MaxAge: 3600, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) @@ -73,7 +73,7 @@ func TestCORS(t *testing.T) { AllowCredentials: true, MaxAge: 3600, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) @@ -90,7 +90,7 @@ func TestCORS(t *testing.T) { cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"*"}, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) @@ -104,7 +104,7 @@ func TestCORS(t *testing.T) { cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"http://*.example.com"}, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -324,7 +324,9 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) } - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) h(c) assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) @@ -511,11 +513,11 @@ func Test_allowOriginFunc(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) - cors := CORSWithConfig(CORSConfig{ - AllowOriginFunc: allowOriginFunc, - }) - h := cors(echo.NotFoundHandler) - err := h(c) + cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware() + assert.NoError(t, err) + + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) + err = h(c) expected, expectedErr := allowOriginFunc(origin) if expectedErr != nil { diff --git a/middleware/csrf.go b/middleware/csrf.go index 61299f5ca..acab8790b 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -5,91 +5,93 @@ import ( "net/http" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" or "header::" - // - "query:" - // - "form:" - // Multiple sources example: - // - "header:X-CSRF-Token,query:csrf" - TokenLookup string `yaml:"token_lookup"` - - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` - - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` - - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` - - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` - - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` - - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` - - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` - - // Indicates SameSite mode of the CSRF cookie. - // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` - } -) +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // TokenLength is the length of the generated token. + TokenLength uint8 + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string `yaml:"token_lookup"` + + // Generator defines a function to generate token. + // Optional. Defaults tp randomString(TokenLength). + Generator func() string + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite +} // ErrCSRFInvalid is returned when CSRF check fails var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - CookieSameSite: http.SameSiteDefaultMode, - } -) +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - c := DefaultCSRFConfig - return CSRFWithConfig(c) + return CSRFWithConfig(DefaultCSRFConfig) } -// CSRFWithConfig returns a CSRF middleware with config. -// See `CSRF()`. +// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration +func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper @@ -97,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.Generator == nil { + config.Generator = createRandomStringGenerator(config.TokenLength) + } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -113,9 +118,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - extractors, err := createExtractors(config.TokenLookup, "") + extractors, err := createExtractors(config.TokenLookup) if err != nil { - panic(err) + return nil, err } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -126,7 +131,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = random.String(config.TokenLength) // Generate token + token = config.Generator() // Generate token } else { token = k.Value // Reuse token } @@ -157,17 +162,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if lastTokenErr != nil { return lastTokenErr } else if lastExtractorErr != nil { - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") - } else if lastExtractorErr == errFormExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") - } else { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) - } - return lastExtractorErr + return echo.ErrBadRequest.WithInternal(lastExtractorErr) } } @@ -197,7 +192,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } func validateCSRFToken(token, clientToken string) bool { diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 9aff82a98..f8af5e9cc 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -7,22 +7,22 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCSRF_tokenExtractors(t *testing.T) { var testCases = []struct { - name string - whenTokenLookup string - whenCookieName string - givenCSRFCookie string - givenMethod string - givenQueryTokens map[string][]string - givenFormTokens map[string][]string - givenHeaderTokens map[string][]string - expectError string + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + expectToMiddlewareError string }{ { name: "ok, multiple token lookups sources, succeeds on last one", @@ -70,7 +70,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the form parameter", + expectError: "code=400, message=Bad Request, internal=missing value in the form", }, { name: "ok, token from POST header", @@ -106,7 +106,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in request header", + expectError: "code=400, message=Bad Request, internal=missing value in request header", }, { name: "ok, token from PUT query param", @@ -142,7 +142,15 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the query string", + expectError: "code=400, message=Bad Request, internal=missing value in the query string", + }, + { + name: "nok, invalid TokenLookup", + whenTokenLookup: "q", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", }, } @@ -186,16 +194,23 @@ func TestCSRF_tokenExtractors(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + config := CSRFConfig{ TokenLookup: tc.whenTokenLookup, CookieName: tc.whenCookieName, - }) + } + csrf, err := config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) + return + } else if err != nil { + assert.NoError(t, err) + } h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) - err := h(c) + err = h(c) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -219,6 +234,24 @@ func TestCSRF(t *testing.T) { h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") +} + +func TestMustCSRFWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + csrf := CSRFWithConfig(CSRFConfig{ + TokenLength: 16, + }) + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Generate CSRF token + h(c) + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") + // Without CSRF cookie req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() @@ -233,7 +266,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(32) + token := randomString(16) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { @@ -302,9 +335,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + csrf, err := CSRFConfig{ CookieSameSite: http.SameSiteNoneMode, - }) + }.ToMiddleware() + assert.NoError(t, err) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") diff --git a/middleware/decompress.go b/middleware/decompress.go index 88ec70982..dcf7172fa 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -6,21 +6,19 @@ import ( "net/http" "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // DecompressConfig defines the config for Decompress middleware. - DecompressConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers - GzipDecompressPool Decompressor - } -) + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor +} -//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers @@ -28,14 +26,6 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -var ( - //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, - } -) - // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } @@ -44,19 +34,23 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} } -//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { - return DecompressWithConfig(DefaultDecompressConfig) + return DecompressWithConfig(DecompressConfig{}) } -//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration +func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper } if config.GzipDecompressPool == nil { - config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + config.GzipDecompressPool = &DefaultGzipDecompressPool{} } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -95,5 +89,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 51fa6b0f1..c35ed6fa3 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -11,54 +11,82 @@ import ( "sync" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestDecompress(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - // Skip if no Content-Encoding header h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) - - // Decompress + // Decompress request body body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } -func TestDecompressDefaultConfig(t *testing.T) { +func TestDecompress_skippedIfNoHeader(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + // Skip if no Content-Encoding header + h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + })(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) // Decompress body := `{"name": "echo"}` @@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { @@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) assert.NoError(t, err) @@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) { h := Decompress()(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) - if assert.NoError(t, h(c)) { + + err := h(c) + + if assert.NoError(t, err) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) @@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } @@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) reqBody, err := ioutil.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) reqBody, err := ioutil.ReadAll(c.Request().Body) assert.NoError(t, err) diff --git a/middleware/extractor.go b/middleware/extractor.go index a57ed4e13..94134f7e2 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -1,9 +1,8 @@ package middleware import ( - "errors" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/textproto" "strings" ) @@ -14,17 +13,27 @@ const ( extractorLimit = 20 ) -var errHeaderExtractorValueMissing = errors.New("missing value in request header") -var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") -var errQueryExtractorValueMissing = errors.New("missing value in the query string") -var errParamExtractorValueMissing = errors.New("missing value in path params") -var errCookieExtractorValueMissing = errors.New("missing value in cookies") -var errFormExtractorValueMissing = errors.New("missing value in the form") +// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups +type ValueExtractorError struct { + message string +} + +// Error returns errors text +func (e *ValueExtractorError) Error() string { + return e.message +} + +var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} +var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"} +var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} +var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} +var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} +var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. type ValuesExtractor func(c echo.Context) ([]string, error) -func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { +func createExtractors(lookups string) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil } @@ -49,15 +58,6 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err prefix := "" if len(parts) > 2 { prefix = parts[2] - } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { - // backwards compatibility for JWT and KeyAuth: - // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc - // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that - // behaviour for default values and Authorization header. - prefix = authScheme - if !strings.HasSuffix(prefix, " ") { - prefix += " " - } } extractors = append(extractors, valuesFromHeader(parts[1], prefix)) } @@ -125,10 +125,9 @@ func valuesFromQuery(param string) ValuesExtractor { func valuesFromParam(param string) ValuesExtractor { return func(c echo.Context) ([]string, error) { result := make([]string, 0) - paramVales := c.ParamValues() - for i, p := range c.ParamNames() { - if param == p { - result = append(result, paramVales[i]) + for i, p := range c.PathParams() { + if param == p.Name { + result = append(result, p.Value) if i >= extractorLimit-1 { break } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index ae4b30a8a..439c4d8fe 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -2,7 +2,7 @@ package middleware import ( "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -11,27 +11,11 @@ import ( "testing" ) -type pathParam struct { - name string - value string -} - -func setPathParams(c echo.Context, params []pathParam) { - names := make([]string, 0, len(params)) - values := make([]string, 0, len(params)) - for _, pp := range params { - names = append(names, pp.name) - values = append(values, pp.value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) -} - func TestCreateExtractors(t *testing.T) { var testCases = []struct { name string givenRequest func() *http.Request - givenPathParams []pathParam + givenPathParams echo.PathParams whenLoopups string expectValues []string expectCreateError string @@ -72,8 +56,8 @@ func TestCreateExtractors(t *testing.T) { }, { name: "ok, param", - givenPathParams: []pathParam{ - {name: "id", value: "123"}, + givenPathParams: echo.PathParams{ + {Name: "id", Value: "123"}, }, whenLoopups: "param:id", expectValues: []string{"123"}, @@ -103,12 +87,12 @@ func TestCreateExtractors(t *testing.T) { req = tc.givenRequest() } rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + c.SetRawPathParams(&tc.givenPathParams) } - extractors, err := createExtractors(tc.whenLoopups, "") + extractors, err := createExtractors(tc.whenLoopups) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return @@ -315,19 +299,19 @@ func TestValuesFromQuery(t *testing.T) { } func TestValuesFromParam(t *testing.T) { - examplePathParams := []pathParam{ - {name: "id", value: "123"}, - {name: "gid", value: "456"}, - {name: "gid", value: "789"}, + examplePathParams := echo.PathParams{ + {Name: "id", Value: "123"}, + {Name: "gid", Value: "456"}, + {Name: "gid", Value: "789"}, } - examplePathParams20 := make([]pathParam, 0) + examplePathParams20 := make(echo.PathParams, 0) for i := 1; i < 25; i++ { - examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) + examplePathParams20 = append(examplePathParams20, echo.PathParam{Name: "id", Value: fmt.Sprintf("%v", i)}) } var testCases = []struct { name string - givenPathParams []pathParam + givenPathParams echo.PathParams whenName string expectValues []string expectError string @@ -375,9 +359,9 @@ func TestValuesFromParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + c.SetRawPathParams(&tc.givenPathParams) } extractor := valuesFromParam(tc.whenName) diff --git a/middleware/jwt.go b/middleware/jwt.go index bec5167e2..40b45e77e 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,150 +1,98 @@ -//go:build go1.15 -// +build go1.15 - package middleware import ( "errors" - "fmt" - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/http" - "reflect" ) -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next - // middleware or handler. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. - ContinueOnIgnoredError bool - - // Signing key to validate token. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} - - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. - // Not used if custom ParseTokenFunc is set. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. - // If prefix is left empty the whole value is returned. - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiple sources example: - // - "header:Authorization,cookie:myowncookie" - TokenLookup string - - // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. - // This is one of the two options to provide a token extractor. - // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. - // You can also provide both if you want. - TokenLookupFuncs []ValuesExtractor - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the public key for a token validation. - // The function shall take care of verifying the signing algorithm and selecting the proper key. - // A user-defined KeyFunc can be useful if tokens are issued by an external party. - // Used by default ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither SigningKeys nor SigningKey is provided. - // Not used if custom ParseTokenFunc is set. - // Default to an internal implementation verifying the signing algorithm and selecting the proper key. - KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token - // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) - } +// JWTConfig defines the config for JWT middleware. +type JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeFunc defines a function which is executed just before the middleware. + BeforeFunc BeforeFunc + + // SuccessHandler defines a function which is executed for a valid token. + SuccessHandler JWTSuccessHandler + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom JWT error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain. + ErrorHandler JWTErrorHandlerWithContext + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. + ContinueOnIgnoredError bool + + // Context key to store user information from the token into context. + // Optional. Default value "user". + ContextKey string + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. + // If prefix is left empty the whole value is returned. + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiple sources example: + // - "header:Authorization:Bearer ,cookie:myowncookie" + TokenLookup string + + // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. + // This is one of the two options to provide a token extractor. + // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. + // You can also provide both if you want. + TokenLookupFuncs []ValuesExtractor + + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(c echo.Context, auth string) (interface{}, error) +} - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(c echo.Context) +// JWTSuccessHandler defines a function which is executed for a valid token. +type JWTSuccessHandler func(c echo.Context) - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(err error) error +// JWTErrorHandler defines a function which is executed for an invalid token. +type JWTErrorHandler func(err error) error - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(err error, c echo.Context) error -) +// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. +type JWTErrorHandlerWithContext func(c echo.Context, err error) error -// Algorithms const ( + // AlgorithmHS256 is token signing algorithm AlgorithmHS256 = "HS256" ) -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") -) +// ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request +var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt") -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - TokenLookupFuncs: nil, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, - } -) +// ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired +var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") + +// DefaultJWTConfig is the default JWT auth middleware config. +var DefaultJWTConfig = JWTConfig{ + Skipper: DefaultSkipper, + ContextKey: "user", + TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", +} // JWT returns a JSON Web Token (JWT) auth middleware. // @@ -153,48 +101,40 @@ var ( // For missing token, it returns "400 - Bad Request" error. // // See: https://jwt.io/introduction -// See `JWTConfig.TokenLookup` -func JWT(key interface{}) echo.MiddlewareFunc { +func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc { c := DefaultJWTConfig - c.SigningKey = key + c.ParseTokenFunc = parseTokenFunc return JWTWithConfig(c) } -// JWTWithConfig returns a JWT auth middleware with config. -// See: `JWT()`. +// JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid. +// +// For valid token, it sets the user in context and calls next handler. +// For invalid token, it returns "401 - Unauthorized" error. +// For missing token, it returns "400 - Bad Request" error. +// +// See: https://jwt.io/introduction func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration +func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { - panic("echo: jwt middleware requires signing key") - } - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod + if config.ParseTokenFunc == nil { + return nil, errors.New("echo jwt middleware requires parse token function") } if config.ContextKey == "" { config.ContextKey = DefaultJWTConfig.ContextKey } - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 { config.TokenLookup = DefaultJWTConfig.TokenLookup } - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } - if config.KeyFunc == nil { - config.KeyFunc = config.defaultKeyFunc - } - if config.ParseTokenFunc == nil { - config.ParseTokenFunc = config.defaultParseToken - } - - extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) + extractors, err := createExtractors(config.TokenLookup) if err != nil { - panic(err) + return nil, err } if len(config.TokenLookupFuncs) > 0 { extractors = append(config.TokenLookupFuncs, extractors...) @@ -209,17 +149,16 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.BeforeFunc != nil { config.BeforeFunc(c) } - var lastExtractorErr error var lastTokenErr error for _, extractor := range extractors { - auths, err := extractor(c) - if err != nil { - lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth) + auths, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr continue } for _, auth := range auths { - token, err := config.ParseTokenFunc(auth, c) + token, err := config.ParseTokenFunc(c, auth) if err != nil { lastTokenErr = err continue @@ -232,69 +171,23 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return next(c) } } - // we are here only when we did not successfully extract or parse any of the tokens + + // prioritize token errors over extracting errors err := lastTokenErr - if err == nil { // prioritize token errors over extracting errors + if err == nil { err = lastExtractorErr } if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - tmpErr := config.ErrorHandlerWithContext(err, c) + tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - - // backwards compatible errors codes - if lastTokenErr != nil { - return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, - Internal: err, - } - } - return err // this is lastExtractorErr value - } - } -} - -func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { - token := new(jwt.Token) - var err error - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil -} - -// defaultKeyFunc returns a signing key of the given token. -func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil + if lastTokenErr == nil { + return ErrJWTMissing.WithInternal(err) } + return ErrJWTInvalid.WithInternal(err) } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil + }, nil } diff --git a/middleware/jwt_external_test.go b/middleware/jwt_external_test.go new file mode 100644 index 000000000..1b92f188f --- /dev/null +++ b/middleware/jwt_external_test.go @@ -0,0 +1,76 @@ +package middleware_test + +import ( + "errors" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "net/http" + "net/http/httptest" +) + +// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc +// +// signingKey is signing key to validate token. +// This is one of the options to provide a token validation key. +// The order of precedence is a user-defined SigningKeys and SigningKey. +// Required if signingKeys is not provided. +// +// signingKeys is Map of signing keys to validate token with kid field usage. +// This is one of the options to provide a token validation key. +// The order of precedence is a user-defined SigningKeys and SigningKey. +// Required if signingKey is not provided +func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) { + // keyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != middleware.AlgorithmHS256 { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + if len(signingKeys) == 0 { + return signingKey, nil + } + + if kid, ok := t.Header["kid"].(string); ok { + if key, ok := signingKeys[kid]; ok { + return key, nil + } + } + return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) + } + + return func(c echo.Context, auth string) (interface{}, error) { + token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + } +} + +func ExampleJWTConfig_withJWTGoAsTokenParser() { + mw := middleware.JWTWithConfig(middleware.JWTConfig{ + ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil), + }) + + e := echo.New() + e.Use(mw) + + e.GET("/", func(c echo.Context) error { + user := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusTeapot, user.Claims) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + fmt.Printf("status: %v, body: %v", res.Code, res.Body.String()) + // Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"} +} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index eee9df966..5e5b99121 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,6 +1,3 @@ -//go:build go1.15 -// +build go1.15 - package middleware import ( @@ -12,11 +9,32 @@ import ( "strings" "testing" - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) +func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) { + // This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != signingMethod { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + return signingKey, nil + } + + return func(c echo.Context, auth string) (interface{}, error) { + token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + } +} + // jwtCustomInfo defines some custom types we're going to use within our tokens. type jwtCustomInfo struct { Name string `json:"name"` @@ -25,7 +43,7 @@ type jwtCustomInfo struct { // jwtCustomClaims are custom claims expanding default ones. type jwtCustomClaims struct { - *jwt.StandardClaims + *jwt.RegisteredClaims jwtCustomInfo } @@ -37,7 +55,7 @@ func TestJWT(t *testing.T) { return c.JSON(http.StatusOK, token.Claims) }) - e.Use(JWT([]byte("secret"))) + e.Use(JWT(createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")))) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") @@ -49,247 +67,197 @@ func TestJWT(t *testing.T) { assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) } -func TestJWTRace(t *testing.T) { +func TestJWT_combinations(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss" - validKey := []byte("secret") - - h := JWTWithConfig(JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: validKey, - })(handler) - - makeReq := func(token string) echo.Context { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token) - c := e.NewContext(req, res) - assert.NoError(t, h(c)) - return c - } - - c := makeReq(initialToken) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(*jwtCustomClaims) - assert.Equal(t, claims.Name, "John Doe") - - makeReq(raceToken) - user = c.Get("user").(*jwt.Token) - claims = user.Claims.(*jwtCustomClaims) - // Initial context should still be "John Doe", not "Race Condition" - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) -} - -func TestJWTConfig(t *testing.T) { handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" validKey := []byte("secret") invalidKey := []byte("invalid-key") - validAuth := DefaultJWTConfig.AuthScheme + " " + token - - testCases := []struct { - name string - expPanic bool - expErrCode int // 0 for Success - config JWTConfig - reqURL string // "/" if empty - hdrAuth string - hdrCookie string // test.Request doesn't provide SetCookie(); use name=val - formValues map[string]string + validAuth := "Bearer " + token + + var testCases = []struct { + name string + config JWTConfig + reqURL string // "/" if empty + hdrAuth string + hdrCookie string // test.Request doesn't provide SetCookie(); use name=val + formValues map[string]string + expectPanic bool + expectToMiddlewareError string + expectError string }{ { - name: "No signing key provided", - expPanic: true, + name: "No signing key provided", + expectToMiddlewareError: "echo jwt middleware requires parse token function", }, { - name: "Unexpected signing method", - expErrCode: http.StatusBadRequest, + name: "invalid TokenLookup", config: JWTConfig{ - SigningKey: validKey, - SigningMethod: "RS256", + ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey), + TokenLookup: "q", }, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", }, { - name: "Invalid key", - expErrCode: http.StatusUnauthorized, - hdrAuth: validAuth, - config: JWTConfig{SigningKey: invalidKey}, + name: "Unexpected signing method", + hdrAuth: validAuth, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey), + }, + expectError: "code=401, message=invalid or expired jwt, internal=unexpected jwt signing method=HS256", + }, + { + name: "Invalid key", + hdrAuth: validAuth, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey), + }, + expectError: "code=401, message=invalid or expired jwt, internal=signature is invalid", }, { name: "Valid JWT", hdrAuth: validAuth, - config: JWTConfig{SigningKey: validKey}, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, }, { name: "Valid JWT with custom AuthScheme", hdrAuth: "Token" + " " + token, - config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, + config: JWTConfig{ + TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, }, { name: "Valid JWT with custom claims", hdrAuth: validAuth, config: JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), }, }, { - name: "Invalid Authorization header", - hdrAuth: "invalid-auth", - expErrCode: http.StatusBadRequest, - config: JWTConfig{SigningKey: validKey}, + name: "Invalid Authorization header", + hdrAuth: "invalid-auth", + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header", }, { - name: "Empty header auth field", - config: JWTConfig{SigningKey: validKey}, - expErrCode: http.StatusBadRequest, + name: "Empty header auth field", + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header", }, { name: "Valid query method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=" + token, }, { name: "Invalid query param name", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, - reqURL: "/?a=b&jwtxyz=" + token, - expErrCode: http.StatusBadRequest, + reqURL: "/?a=b&jwtxyz=" + token, + expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string", }, { name: "Invalid query param value", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, - reqURL: "/?a=b&jwt=invalid-token", - expErrCode: http.StatusUnauthorized, + reqURL: "/?a=b&jwt=invalid-token", + expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments", }, { name: "Empty query", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, - reqURL: "/?a=b", - expErrCode: http.StatusBadRequest, + reqURL: "/?a=b", + expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string", }, { - name: "Valid param method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "param:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "param:jwt", }, reqURL: "/" + token, + name: "Valid param method", }, { - name: "Valid cookie method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=" + token, + name: "Valid cookie method", }, { - name: "Multiple jwt lookuop", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt,cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt,cookie:jwt", }, hdrCookie: "jwt=" + token, + name: "Multiple jwt lookuop", }, { name: "Invalid token with cookie method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, - expErrCode: http.StatusUnauthorized, - hdrCookie: "jwt=invalid", + hdrCookie: "jwt=invalid", + expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments", }, { name: "Empty cookie", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, - expErrCode: http.StatusBadRequest, + expectError: "code=401, message=missing or malformed jwt, internal=missing value in cookies", }, { name: "Valid form method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": token}, }, { name: "Invalid token with form method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, - expErrCode: http.StatusUnauthorized, - formValues: map[string]string{"jwt": "invalid"}, + formValues: map[string]string{"jwt": "invalid"}, + expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments", }, { name: "Empty form field", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid JWT with a valid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return validKey, nil - }, + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, - }, - { - name: "Valid JWT with an invalid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return invalidKey, nil - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Token verification does not pass using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return nil, errors.New("faulty KeyFunc") - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Valid JWT with lower case AuthScheme", - hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, - config: JWTConfig{SigningKey: validKey}, + expectError: "code=401, message=missing or malformed jwt, internal=missing value in the form", }, } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := echo.New() if tc.reqURL == "" { tc.reqURL = "/" } @@ -312,128 +280,36 @@ func TestJWTConfig(t *testing.T) { c := e.NewContext(req, res) if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) + cc := c.(echo.ServableContext) + cc.SetPathParams(echo.PathParams{ + {Name: "jwt", Value: token}, + }) } - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.name) + mw, err := tc.config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) return } - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.name) + hErr := mw(handler)(c) + if tc.expectError != "" { + assert.EqualError(t, hErr, tc.expectError) return } + assert.NoError(t, hErr) - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.name) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.name) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.name) - assert.Equal(t, claims.Admin, true, tc.name) - default: - panic("unexpected type of claims") - } - } - }) - } -} - -func TestJWTwithKID(t *testing.T) { - test := assert.New(t) - - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" - secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" - wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" - staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" - validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} - invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} - staticSecret := []byte("static_secret") - invalidStaticSecret := []byte("invalid_secret") - - for _, tc := range []struct { - expErrCode int // 0 for Success - config JWTConfig - hdrAuth string - info string - }{ - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "First token valid", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Second token valid", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Wrong key id token", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: staticSecret}, - info: "Valid static secret token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: invalidStaticSecret}, - info: "Invalid static secret", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys first token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys second token", - }, - } { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - c := e.NewContext(req, res) - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - test.Equal(tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if test.NoError(h(c), tc.info) { user := c.Get("user").(*jwt.Token) switch claims := user.Claims.(type) { case jwt.MapClaims: - test.Equal(claims["name"], "John Doe", tc.info) + assert.Equal(t, claims["name"], "John Doe") case *jwtCustomClaims: - test.Equal(claims.Name, "John Doe", tc.info) - test.Equal(claims.Admin, true, tc.info) + assert.Equal(t, claims.Name, "John Doe") + assert.Equal(t, claims.Admin, true) default: panic("unexpected type of claims") } - } + }) } } @@ -444,7 +320,7 @@ func TestJWTConfig_skipper(t *testing.T) { Skipper: func(context echo.Context) bool { return true // skip everything }, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) isCalled := false @@ -472,11 +348,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) { BeforeFunc: func(context echo.Context) { isCalled = true }, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) @@ -493,18 +369,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "custom_error") }, }, @@ -539,23 +405,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) }, }, expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", - }, } for _, tc := range testCases { @@ -568,14 +424,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { config := tc.given parseTokenCalled := false - config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { + config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { parseTokenCalled = true return nil, errors.New("parsing failed") } e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) @@ -598,7 +454,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { signingKey := []byte("secret") config := JWTConfig{ - ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != "HS256" { return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) @@ -621,125 +477,130 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) } -func TestJWTConfig_TokenLookupFuncs(t *testing.T) { +func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) + success := c.Get("success").(string) + user := c.Get("user").(string) + return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user)) }) - e.Use(JWTWithConfig(JWTConfig{ - TokenLookupFuncs: []ValuesExtractor{ - func(c echo.Context) ([]string, error) { - return []string{c.Request().Header.Get("X-API-Key")}, nil - }, + mw, err := JWTConfig{ + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + return auth, nil }, - SigningKey: []byte("secret"), - })) + SuccessHandler: func(c echo.Context) { + c.Set("success", "yes") + }, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64") res := httptest.NewRecorder() e.ServeHTTP(res, req) - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) + assert.Equal(t, "yes:valid_token_base64", res.Body.String()) + assert.Equal(t, http.StatusTeapot, res.Code) } -func TestJWTConfig_SuccessHandler(t *testing.T) { +func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { var testCases = []struct { - name string - givenToken string - expectCalled bool - expectStatus int + name string + givenContinueOnIgnoredError bool + givenErrorHandler JWTErrorHandlerWithContext + givenTokenLookup string + whenAuthHeaders []string + whenCookies []string + whenParseReturn string + whenParseError error + expectHandlerCalled bool + expect string + expectCode int }{ { - name: "ok, success handler is called", - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectCalled: true, - expectStatus: http.StatusOK, + name: "ok, with valid JWT from auth header", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + return nil + }, + whenAuthHeaders: []string{"Bearer valid_token_base64"}, + whenParseReturn: "valid_token", + expectCode: http.StatusTeapot, + expect: "valid_token", }, { - name: "nok, success handler is not called", - givenToken: "x.x.x", - expectCalled: false, - expectStatus: http.StatusUnauthorized, + name: "ok, missing header, callNext and set public_token from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + if errors.Is(err, &ValueExtractorError{}) { + panic("must get ErrJWTMissing") + } + c.Set("user", "public_token") + return nil + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusTeapot, + expect: "public_token", }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - wasCalled := false - e.Use(JWTWithConfig(JWTConfig{ - SuccessHandler: func(c echo.Context) { - wasCalled = true - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectCalled, wasCalled) - assert.Equal(t, tc.expectStatus, res.Code) - }) - } -} - -func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { - var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenToken string - expectStatus int - expectBody string - }{ { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectStatus: http.StatusTeapot, - expectBody: "", + name: "ok, invalid token, callNext and set public_token from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + // this is probably not realistic usecase. on parse error you probably want to return error + if err.Error() != "parser_error" { + panic("must get parser_error") + } + c.Set("user", "public_token") + return nil + }, + whenAuthHeaders: []string{"Bearer invalid_header"}, + whenParseError: errors.New("parser_error"), + expectCode: http.StatusTeapot, + expect: "public_token", }, { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenToken: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", + name: "nok, invalid token, return error from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + if err.Error() != "parser_error" { + panic("must get parser_error") + } + return err + }, + whenAuthHeaders: []string{"Bearer invalid_header"}, + whenParseError: errors.New("parser_error"), + expectCode: http.StatusInternalServerError, + expect: "{\"message\":\"Internal Server Error\"}\n", }, { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenToken: "", - expectStatus: http.StatusTeapot, - expectBody: "public-token", + name: "nok, ContinueOnIgnoredError but return error from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + return echo.ErrUnauthorized.WithInternal(err) + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusUnauthorized, + expect: "{\"message\":\"Unauthorized\"}\n", }, { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenToken: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", + name: "nok, ContinueOnIgnoredError=false", + givenContinueOnIgnoredError: false, + givenErrorHandler: func(c echo.Context, err error) error { + return echo.ErrUnauthorized.WithInternal(err) + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusUnauthorized, + expect: "{\"message\":\"Unauthorized\"}\n", }, } @@ -748,32 +609,56 @@ func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) + token := c.Get("user").(string) + return c.String(http.StatusTeapot, token) }) - e.Use(JWTWithConfig(JWTConfig{ - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, c echo.Context) error { - if err == ErrJWTMissing { - c.Set("test", "public-token") - return nil - } - return echo.ErrUnauthorized + mw, err := JWTConfig{ + ContinueOnIgnoredError: tc.givenContinueOnIgnoredError, + TokenLookup: tc.givenTokenLookup, + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + return tc.whenParseReturn, tc.whenParseError }, - })) + ErrorHandler: tc.givenErrorHandler, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenToken != "" { - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) + for _, a := range tc.whenAuthHeaders { + req.Header.Add(echo.HeaderAuthorization, a) } res := httptest.NewRecorder() - e.ServeHTTP(res, req) - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) + assert.Equal(t, tc.expect, res.Body.String()) + assert.Equal(t, tc.expectCode, res.Code) }) } } + +func TestJWTConfig_TokenLookupFuncs(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) + }) + + e.Use(JWTWithConfig(JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + TokenLookupFuncs: []ValuesExtractor{ + func(c echo.Context) ([]string, error) { + return []string{c.Request().Header.Get("X-API-Key")}, nil + }, + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) +} diff --git a/middleware/key_auth.go b/middleware/key_auth.go index e8a6b0853..77a001ea8 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,81 +2,69 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" + "fmt" + "github.com/labstack/echo/v5" "net/http" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // KeyLookup is a string in the form of ":" or ":,:" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. - // - "query:" - // - "form:" - // - "cookie:" - // Multiple sources example: - // - "header:Authorization,header:X-Api-Key" - KeyLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator - - // ErrorHandler defines a function which is executed for an invalid key. - // It may be used to define a custom error. - ErrorHandler KeyAuthErrorHandler - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandler to set a default public key auth value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. - ContinueOnIgnoredError bool - } +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(auth string, c echo.Context) (bool, error) +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator func(c echo.Context, key string) (bool, error) - // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(err error, c echo.Context) error -) +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(c echo.Context, err error) error -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// ErrKeyMissing denotes an error raised when key value could not be extracted from request +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") -// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups -type ErrKeyAuthMissing struct { - Err error -} +// ErrInvalidKey denotes an error raised when key value is invalid by validator +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") -// Error returns errors text -func (e *ErrKeyAuthMissing) Error() string { - return e.Err.Error() -} - -// Unwrap unwraps error -func (e *ErrKeyAuthMissing) Unwrap() error { - return e.Err +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", } // KeyAuth returns an KeyAuth middleware. @@ -90,27 +78,33 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware with config. -// See `KeyAuth()`. +// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. +// +// For first valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration +func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } - // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultKeyAuthConfig.AuthScheme - } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - panic("echo: key-auth middleware requires a validator function") + return nil, errors.New("echo key-auth middleware requires a validator function") } - extractors, err := createExtractors(config.KeyLookup, config.AuthScheme) + extractors, err := createExtractors(config.KeyLookup) if err != nil { - panic(err) + return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err) + } + if len(extractors) == 0 { + return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -122,59 +116,41 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { - keys, err := extractor(c) - if err != nil { - lastExtractorErr = err + keys, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr continue } for _, key := range keys { - valid, err := config.Validator(key, c) + valid, err := config.Validator(c, key) if err != nil { lastValidatorErr = err continue } - if valid { - return next(c) + if !valid { + lastValidatorErr = ErrInvalidKey + continue } - lastValidatorErr = errors.New("invalid key") + return next(c) } } - // we are here only when we did not successfully extract and validate any of keys + // prioritize validator errors over extracting errors err := lastValidatorErr - if err == nil { // prioritize validator errors over extracting errors - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - err = errors.New("missing key in the query string") - } else if lastExtractorErr == errCookieExtractorValueMissing { - err = errors.New("missing key in cookies") - } else if lastExtractorErr == errFormExtractorValueMissing { - err = errors.New("missing key in the form") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - err = errors.New("missing key in request header") - } else if lastExtractorErr == errHeaderExtractorValueInvalid { - err = errors.New("invalid key in the request header") - } else { - err = lastExtractorErr - } - err = &ErrKeyAuthMissing{Err: err} + if err == nil { + err = lastExtractorErr } - if config.ErrorHandler != nil { - tmpErr := config.ErrorHandler(err, c) + tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - if lastValidatorErr != nil { // prioritize validator errors over extracting errors - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "Unauthorized", - Internal: lastValidatorErr, - } + if lastValidatorErr == nil { + return ErrKeyMissing.WithInternal(err) } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + return echo.ErrUnauthorized.WithInternal(err) } - } + }, nil } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index ff8968c38..1b64865fb 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -7,11 +7,11 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func testKeyValidator(key string, c echo.Context) (bool, error) { +func testKeyValidator(c echo.Context, key string) (bool, error) { switch key { case "valid-key": return true, nil @@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, internal=invalid key", + expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -84,24 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, - expectError: "code=400, message=invalid key in the request header", + expectError: "code=401, message=missing key, internal=invalid value in request header", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", - }, - { - name: "ok, custom key lookup from multiple places, query and header", - givenRequest: func(req *http.Request) { - req.URL.RawQuery = "key=invalid-key" - req.Header.Set("API-Key", "valid-key") - }, - whenConfig: func(conf *KeyAuthConfig) { - conf.KeyLookup = "query:key,header:API-Key" - }, - expectHandlerCalled: true, + expectError: "code=401, message=missing key, internal=missing value in request header", }, { name: "ok, custom key lookup, header", @@ -121,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", + expectError: "code=401, message=missing key, internal=missing value in request header", }, { name: "ok, custom key lookup, query", @@ -141,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the query string", + expectError: "code=401, message=missing key, internal=missing value in the query string", }, { name: "ok, custom key lookup, form", @@ -166,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the form", + expectError: "code=401, message=missing key, internal=missing value in the form", }, { name: "ok, custom key lookup, cookie", @@ -190,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies", + expectError: "code=401, message=missing key, internal=missing value in cookies", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, internal=missing key in request header", + expectError: "code=418, message=custom, internal=missing value in request header", }, { name: "nok, custom errorHandler, error from validator", @@ -211,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError @@ -269,108 +258,96 @@ func TestKeyAuthWithConfig(t *testing.T) { } } -func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { - assert.PanicsWithError( - t, - "extractor source for lookup could not be split into needed parts: a", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - KeyLookup: "a", - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { - assert.PanicsWithValue( - t, - "echo: key-auth middleware requires a validator function", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: nil, - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { +func TestKeyAuthWithConfig_errors(t *testing.T) { var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenKey string - expectStatus int - expectBody string + name string + whenConfig KeyAuthConfig + expectError string }{ { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenKey: "valid-key", - expectStatus: http.StatusTeapot, - expectBody: "", + name: "ok, no error", + whenConfig: KeyAuthConfig{ + Validator: func(c echo.Context, key string) (bool, error) { + return false, nil + }, + }, }, { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenKey: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", + name: "ok, missing validator func", + whenConfig: KeyAuthConfig{ + Validator: nil, + }, + expectError: "echo key-auth middleware requires a validator function", }, { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenKey: "", - expectStatus: http.StatusTeapot, - expectBody: "public-auth", + name: "ok, extractor source can not be split", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope", + Validator: func(c echo.Context, key string) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", }, { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenKey: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", + name: "ok, no extractors", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope:nope", + Validator: func(c echo.Context, key string) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create extractors from KeyLookup string", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := echo.New() + mw, err := tc.whenConfig.ToMiddleware() + if tc.expectError != "" { + assert.Nil(t, mw) + assert.EqualError(t, err, tc.expectError) + } else { + assert.NotNil(t, mw) + assert.NoError(t, err) + } + }) + } +} - e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) - }) +func TestMustKeyAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + KeyAuthWithConfig(KeyAuthConfig{}) + }) +} - e.Use(KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - ErrorHandler: func(err error, c echo.Context) error { - if _, ok := err.(*ErrKeyAuthMissing); ok { - c.Set("test", "public-auth") - return nil - } - return echo.ErrUnauthorized - }, - KeyLookup: "header:X-API-Key", - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - })) +func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { + handlerCalled := false + var authValue string + handler := func(c echo.Context) error { + handlerCalled = true + authValue = c.Get("auth").(string) + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(c echo.Context, err error) error { + // could check error to decide if we can swallow the error + c.Set("auth", "public") + return nil + }, + ContinueOnIgnoredError: true, + })(handler) - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenKey != "" { - req.Header.Set("X-API-Key", tc.givenKey) - } - res := httptest.NewRecorder() + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // no auth header this time + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - e.ServeHTTP(res, req) + err := middlewareChain(c) - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) - }) - } + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, "public", authValue) } diff --git a/middleware/logger.go b/middleware/logger.go index 9baac4769..bd2d3d932 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,88 +3,86 @@ package middleware import ( "bytes" "encoding/json" + "fmt" "io" "strconv" "strings" "sync" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/color" + "github.com/labstack/echo/v5" "github.com/valyala/fasttemplate" ) -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` - - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` - - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) +// LoggerConfig defines the config for Logger middleware. +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) + // Tags to construct the logger format. + // + // - time_unix + // - time_unix_nano + // - time_rfc3339 + // - time_rfc3339_nano + // - time_custom + // - id (Request ID) + // - remote_ip + // - uri + // - host + // - method + // - path + // - protocol + // - referer + // - user_agent + // - status + // - error + // - latency (In nanoseconds) + // - latency_human (Human readable) + // - bytes_in (Bytes received) + // - bytes_out (Bytes sent) + // - header: + // - query: + // - form: + // + // Example "${remote_ip} ${status}" + // + // Optional. Default value DefaultLoggerConfig.Format. + Format string + + // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. + CustomTimeFormat string + + // Output is a writer where logs in JSON format are written. + // Optional. Default destination `echo.Logger.Infof()` + Output io.Writer + + template *fasttemplate.Template + pool *sync.Pool +} + +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","level":"INFO","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", +} // Logger returns a middleware that logs HTTP requests. func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } -// LoggerWithConfig returns a Logger middleware with config. -// See: `Logger()`. +// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration +func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultLoggerConfig.Skipper @@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if config.Format == "" { config.Format = DefaultLoggerConfig.Format } - if config.Output == nil { - config.Output = DefaultLoggerConfig.Output - } config.template = fasttemplate.New(config.Format, "${", "}") - config.colorer = color.New() - config.colorer.SetOutput(config.Output) config.pool = &sync.Pool{ New: func() interface{} { return bytes.NewBuffer(make([]byte, 256)) @@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() + start := time.Now() - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) stop := time.Now() + buf := config.pool.Get().(*bytes.Buffer) buf.Reset() defer config.pool.Put(buf) - if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { + _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { switch tag { case "time_unix": return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) @@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "user_agent": return buf.WriteString(req.UserAgent()) case "status": - n := res.Status - s := config.colorer.Green(n) - switch { - case n >= 500: - s = config.colorer.Red(n) - case n >= 400: - s = config.colorer.Yellow(n) - case n >= 300: - s = config.colorer.Cyan(n) + status := res.Status + if err != nil { + if httpErr, ok := err.(*echo.HTTPError); ok { + status = httpErr.Code + } } - return buf.WriteString(s) + return buf.WriteString(strconv.Itoa(status)) case "error": if err != nil { // Error may contain invalid JSON e.g. `"` @@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case strings.HasPrefix(tag, "form:"): return buf.Write([]byte(c.FormValue(tag[5:]))) case strings.HasPrefix(tag, "cookie:"): - cookie, err := c.Cookie(tag[7:]) - if err == nil { + cookie, cookieErr := c.Cookie(tag[7:]) + if cookieErr == nil { return buf.Write([]byte(cookie.Value)) } } } return 0, nil - }); err != nil { - return + }) + if tmplErr != nil { + if err != nil { + return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err) + } + return fmt.Errorf("failed to create log from template: %w", tmplErr) } - if config.Output == nil { - _, err = c.Logger().Output().Write(buf.Bytes()) - return + if config.Output != nil { + if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil { + return lErr + } + } else { + if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil { + return lErr + } } - _, err = config.Output.Write(buf.Bytes()) - return + return err } - } + }, nil } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 394f62712..2f1230dda 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -12,7 +12,7 @@ import ( "time" "unsafe" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} ip := "127.0.0.1" h := Logger()(func(c echo.Context) error { return c.String(http.StatusOK, "test") diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2ed..202862f3b 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -3,31 +3,27 @@ package middleware import ( "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and @@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a MethodOverride middleware with config. -// See: `MethodOverride()`. +// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration +func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b1581..266a575ba 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - m(h)(c) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_formParam(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + // Override with form parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec = httptest.NewRecorder() + m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - m(h)(c) + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_queryParam(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Override with query parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - m(h)(c) + m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_ignoreGet(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Ignore `GET` - req = httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index f250ca49a..2f8c8b5c8 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -6,17 +6,14 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(c echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing the middleware. +type Skipper func(c echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(c echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error func DefaultSkipper(echo.Context) bool { return false } + +func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} diff --git a/middleware/proxy.go b/middleware/proxy.go index 6cfd6731e..1efbc2432 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "math/rand" @@ -15,90 +16,86 @@ import ( "sync/atomic" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer - - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string - - // RegexRewrite defines rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRewrite map[*regexp.Regexp]string - - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string - - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper - - // ModifyResponse defines function to modify response from ProxyTarget. - ModifyResponse func(*http.Response) error - } +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string + + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper + + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta echo.Map +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) *ProxyTarget +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.RWMutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.RWMutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - *commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + *commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - *commonBalancer - i uint32 - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + *commonBalancer + i uint32 +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", - } -) +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { +func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { @@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware with config. -// See: `Proxy()` +// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. +// +// Proxy middleware forwards the request to upstream server using a configured load balancing technique. func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration +func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } + if config.ContextKey == "" { + config.ContextKey = DefaultProxyConfig.ContextKey + } if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + return nil, errors.New("echo proxy middleware requires balancer") } if config.Rewrite != nil { @@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Proxy switch { case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) + proxyRaw(c, tgt).ServeHTTP(res, req) case req.Header.Get(echo.HeaderAccept) == "text/event-stream": default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) + proxyHTTP(c, tgt, config).ServeHTTP(res, req) } if e, ok := c.Get("_error").(error); ok { err = e @@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { return } - } + }, nil } // StatusCodeContextCanceled is a custom HTTP status code for situations @@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { +func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 7939fc5c2..1d0dee91e 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -55,7 +55,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -77,7 +77,7 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -113,15 +113,20 @@ func TestProxy(t *testing.T) { return nil } } - rrb1 := NewRoundRobinBalancer(targets) e = echo.New() e.Use(contextObserver) - e.Use(Proxy(rrb1)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } +func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { + assert.Panics(t, func() { + ProxyWithConfig(ProxyConfig{Balancer: nil}) + }) +} + func TestProxyRealIPHeader(t *testing.T) { // Setup upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) { url, _ := url.Parse(upstream.URL) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { rb := NewRandomBalancer(nil) assert.True(t, rb.AddTarget(target)) e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(req.Context()) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index be2b348db..09237f05b 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -1,47 +1,42 @@ package middleware import ( + "errors" "net/http" "sync" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "golang.org/x/time/rate" ) -type ( - // RateLimiterStore is the interface to be implemented by custom stores. - RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method - Allow(identifier string) (bool, error) - } -) +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + Allow(identifier string) (bool, error) +} -type ( - // RateLimiterConfig defines the configuration for the rate limiter - RateLimiterConfig struct { - Skipper Skipper - BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor - IdentifierExtractor Extractor - // Store defines a store for the rate limiter - Store RateLimiterStore - // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error - // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error - } - // Extractor is used to extract data from echo.Context - Extractor func(context echo.Context) (string, error) -) +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error +} -// errors -var ( - // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded - ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") - // ErrExtractorError denotes an error raised when extractor function is unsuccessful - ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") -) +// Extractor is used to extract data from echo.Context +type Extractor func(context echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ @@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration +func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } @@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { - panic("Store configuration must be provided") + return nil, errors.New("echo rate limiter store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -137,36 +137,32 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { identifier, err := config.IdentifierExtractor(c) if err != nil { - c.Error(config.ErrorHandler(c, err)) - return nil + return config.ErrorHandler(c, err) } - if allow, err := config.Store.Allow(identifier); !allow { - c.Error(config.DenyHandler(c, identifier, err)) - return nil + if allow, allowErr := config.Store.Allow(identifier); !allow { + return config.DenyHandler(c, identifier, allowErr) } return next(c) } - } + }, nil } -type ( - // RateLimiterMemoryStore is the built-in store implementation for RateLimiter - RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit //for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit + burst int + expiresIn time.Duration + lastCleanup time.Time +} - burst int - expiresIn time.Duration - lastCleanup time.Time - } - // Visitor signifies a unique user's limiter details - Visitor struct { - *rate.Limiter - lastSeen time.Time - } -) +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 89d9a6edc..de546a19c 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -10,8 +10,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - mw := RateLimiter(inMemoryStore) + mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - _ = mw(handler)(c) - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } -func TestRateLimiter_panicBehaviour(t *testing.T) { +func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { - RateLimiter(nil) + RateLimiterWithConfig(RateLimiterConfig{}) }) assert.NotPanics(t, func() { - RateLimiter(inMemoryStore) + RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) }) } @@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { id string @@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) + err := mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } @@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return id, nil }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"", http.StatusForbidden}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } @@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } } @@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return true }, @@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } @@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return false }, @@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) _ = mw(handler)(c) @@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ BeforeFunc: func(c echo.Context) { beforeRan = true }, @@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, true, beforeRan) } @@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/recover.go b/middleware/recover.go index a621a9efe..70e98b261 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,53 +4,35 @@ import ( "fmt" "runtime" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" ) -type ( - // LogErrorFunc defines a function for custom logging in the middleware. - LogErrorFunc func(c echo.Context, err error, stack []byte) error +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` - - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - - // LogErrorFunc defines a function for custom logging in the middleware. - // If it's set you don't need to provide LogLevel for config. - LogErrorFunc LogErrorFunc - } -) + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool +} -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. @@ -58,9 +40,13 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recover middleware with config. -// See: `Recover()`. +// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration +func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -70,49 +56,26 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } defer func() { if r := recover(); r != nil { - err, ok := r.(error) + tmpErr, ok := r.(error) if !ok { - err = fmt.Errorf("%v", r) + tmpErr = fmt.Errorf("%v", r) } - var stack []byte - var length int - if !config.DisablePrintStack { - stack = make([]byte, config.StackSize) - length = runtime.Stack(stack, !config.DisableStackAll) - stack = stack[:length] + stack := make([]byte, config.StackSize) + length := runtime.Stack(stack, !config.DisableStackAll) + tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length]) } - - if config.LogErrorFunc != nil { - err = config.LogErrorFunc(c, err, stack) - } else if !config.DisablePrintStack { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch config.LogLevel { - case log.DEBUG: - c.Logger().Debug(msg) - case log.INFO: - c.Logger().Info(msg) - case log.WARN: - c.Logger().Warn(msg) - case log.ERROR: - c.Logger().Error(msg) - case log.OFF: - // None. - default: - c.Logger().Print(msg) - } - } - c.Error(err) + err = tmpErr } }() return next(c) } - } + }, nil } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 9ac4feedc..a65df541f 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,135 +2,109 @@ package middleware import ( "bytes" - "errors" - "fmt" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c echo.Context) error { panic("test") - })) - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") -} - -func TestRecoverWithConfig_LogLevel(t *testing.T) { - tests := []struct { - logLevel log.Lvl - levelName string - }{{ - logLevel: log.DEBUG, - levelName: "DEBUG", - }, { - logLevel: log.INFO, - levelName: "INFO", - }, { - logLevel: log.WARN, - levelName: "WARN", - }, { - logLevel: log.ERROR, - levelName: "ERROR", - }, { - logLevel: log.OFF, - levelName: "OFF", - }} - - for _, tt := range tests { - tt := tt - t.Run(tt.levelName, func(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - config := DefaultRecoverConfig - config.LogLevel = tt.logLevel - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) - - h(c) - - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - if tt.logLevel == log.OFF { - assert.Empty(t, output) - } else { - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) - } - }) - } + }) + err := h(c) + assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Contains(t, buf.String(), "") // nothing is logged } -func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { +func TestRecover_skipper(t *testing.T) { e := echo.New() - e.Logger.SetLevel(log.DEBUG) - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testError := errors.New("test") - config := DefaultRecoverConfig - config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) - if errors.Is(err, testError) { - c.Logger().Debug(msg) - } else { - c.Logger().Error(msg) - } - return err + config := RecoverConfig{ + Skipper: func(c echo.Context) bool { + return true + }, } + h := RecoverWithConfig(config)(func(c echo.Context) error { + panic("testPANIC") + }) - t.Run("first branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic(testError) - })) + var err error + assert.Panics(t, func() { + err = h(c) + }) - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain +} - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"DEBUG"`) - }) +func TestRecoverWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenNoPanic bool + whenConfig RecoverConfig + expectErrContain string + expectErr string + }{ + { + name: "ok, default config", + whenConfig: DefaultRecoverConfig, + expectErrContain: "[PANIC RECOVER] testPANIC goroutine", + }, + { + name: "ok, no panic", + givenNoPanic: true, + whenConfig: DefaultRecoverConfig, + expectErrContain: "", + }, + { + name: "ok, DisablePrintStack", + whenConfig: RecoverConfig{ + DisablePrintStack: true, + }, + expectErr: "testPANIC", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() - t.Run("else branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("other") - })) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) + config := tc.whenConfig + h := RecoverWithConfig(config)(func(c echo.Context) error { + if tc.givenNoPanic { + return nil + } + panic("testPANIC") + }) - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"ERROR"`) - }) + err := h(c) + + if tc.expectErrContain != "" { + assert.Contains(t, err.Error(), tc.expectErrContain) + } else if tc.expectErr != "" { + assert.Contains(t, err.Error(), tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + }) + } } diff --git a/middleware/redirect.go b/middleware/redirect.go index 13877db38..bda5ac204 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,10 +1,11 @@ package middleware import ( + "errors" "net/http" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RedirectConfig defines the config for Redirect middleware. @@ -14,7 +15,9 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` + Code int + + redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri @@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// DefaultRedirectConfig is the default Redirect middleware config. -var DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, -} +// RedirectHTTPSConfig is the HTTPS Redirect middleware config. +var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} + +// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. +var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} + +// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. +var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} + +// RedirectWWWConfig is the WWW Redirect middleware config. +var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} + +// RedirectNonWWWConfig is the non WWW Redirect middleware config. +var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(DefaultRedirectConfig) + return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } -// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSRedirect()`. +// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" { - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPS + return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. @@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } -// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSWWWRedirect()`. +// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" && !strings.HasPrefix(host, www) { - return true, "https://www." + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPSWWW + return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } -// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSNonWWWRedirect()`. +// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if scheme != "https" { - host = strings.TrimPrefix(host, www) - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectNonHTTPSWWW + return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. @@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(DefaultRedirectConfig) + return WWWRedirectWithConfig(RedirectWWWConfig) } -// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `WWWRedirect()`. +// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if !strings.HasPrefix(host, www) { - return true, scheme + "://www." + host + uri - } - return false, "" - }) + config.redirect = redirectWWW + return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. @@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(DefaultRedirectConfig) + return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } -// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `NonWWWRedirect()`. +// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if strings.HasPrefix(host, www) { - return true, scheme + "://" + host[4:] + uri - } - return false, "" - }) + config.redirect = redirectNonWWW + return toMiddlewareOrPanic(config) } -func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { +// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration +func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRedirectConfig.Skipper + config.Skipper = DefaultSkipper } if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code + config.Code = http.StatusMovedPermanently + } + if config.redirect == nil { + return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := cb(scheme, host, req.RequestURI); ok { + if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } + }, nil +} + +var redirectHTTPS = func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" +} + +var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri + } + return false, "" +} + +var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" +} + +var redirectWWW = func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" +} + +var redirectNonWWW = func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } + return false, "" } diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 9d1b56205..9484bdf20 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) diff --git a/middleware/request_id.go b/middleware/request_id.go index 8c5ff6605..b553321ec 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -1,50 +1,42 @@ package middleware import ( - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). - Generator func() string + // Generator defines a function to generate an ID. + // Optional. Default value random.String(32). + Generator func() string - // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(c echo.Context, requestID string) - // TargetHeader defines what header to look for to populate the id - TargetHeader string - } -) - -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - TargetHeader: echo.HeaderXRequestID, - } -) + // TargetHeader defines what header to look for to populate the id + TargetHeader string +} // RequestID returns a X-Request-ID middleware. func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(DefaultRequestIDConfig) + return RequestIDWithConfig(RequestIDConfig{}) } -// RequestIDWithConfig returns a X-Request-ID middleware with config. +// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration. func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration +func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRequestIDConfig.Skipper + config.Skipper = DefaultSkipper } if config.Generator == nil { - config.Generator = generator + config.Generator = createRandomStringGenerator(32) } if config.TargetHeader == "" { config.TargetHeader = echo.HeaderXRequestID @@ -69,9 +61,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { return next(c) } - } -} - -func generator() string { - return random.String(32) + }, nil } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 21b777826..fd0ef5d56 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) { return c.String(http.StatusOK, "test") } - rid := RequestIDWithConfig(RequestIDConfig{}) + rid := RequestID() + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) +} + +func TestMustRequestIDWithConfig_skipper(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + generatorCalled := false + e.Use(RequestIDWithConfig(RequestIDConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Generator: func() string { + generatorCalled = true + return "customGenerator" + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "test", res.Body.String()) + + assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") + assert.False(t, generatorCalled) +} + +func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") +} + +func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + called := false + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + RequestIDHandler: func(c echo.Context, s string) { + called = true + }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, called) +} + +func TestRequestIDWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid, err := RequestIDConfig{}.ToMiddleware() + assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator and handler - customID := "customGenerator" - calledHandler := false + // Custom generator rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return customID }, - RequestIDHandler: func(_ echo.Context, id string) { - calledHandler = true - assert.Equal(t, customID, id) - }, + Generator: func() string { return "customGenerator" }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") - assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 1b3e3eaad..63b6402fb 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -2,7 +2,7 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/http" "time" ) @@ -24,6 +24,7 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // logger.Info(). +// Date("request_start", v.StartTime). // Str("URI", v.URI). // Int("status", v.Status). // Msg("request") @@ -39,6 +40,7 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // logger.Info("request", +// zap.Time("request_start", v.StartTime), // zap.String("URI", v.URI), // zap.Int("status", v.Status), // ) @@ -54,8 +56,9 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { // log.WithFields(logrus.Fields{ -// "URI": values.URI, -// "status": values.Status, +// "request_start": values.StartTime, +// "URI": values.URI, +// "status": values.Status, // }).Info("request") // // return nil @@ -158,15 +161,15 @@ type RequestLoggerValues struct { // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice - // of values is been logger for each given header. + // of values is what will be returned/logged for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter - // with same name so slice of values is been logger for each given query param name. + // with same name so slice of values is what will be returned/logged for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with - // same name so slice of values is been logger for each given form value name. + // same name so slice of values is what will be returned/logged for each given form value name. FormValues map[string][]string } diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 5118b1216..c5ddced75 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -1,7 +1,7 @@ package middleware import ( - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) { req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) c.SetPath("/test*") diff --git a/middleware/rewrite.go b/middleware/rewrite.go index e5b0a6b56..16677263f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,62 +1,58 @@ package middleware import ( + "errors" "regexp" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string - // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` - } -) - -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string +} // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := DefaultRewriteConfig + c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware with config. -// See: `Rewrite()`. +// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. +// +// Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - // Defaults - if config.Rules == nil && config.RegexRules == nil { - panic("echo: rewrite middleware requires url path rewrite rules or regex rules") - } + return toMiddlewareOrPanic(config) +} +// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration +func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Rules == nil && config.RegexRules == nil { + return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } if config.RegexRules == nil { @@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 0ac04bb2f..1f3419f04 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -8,7 +8,7 @@ import ( "regexp" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) { }, })) e.GET("/public/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) e.GET("/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) var testCases = []struct { @@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) { } } +func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { + assert.Panics(t, func() { + RewriteWithConfig(RewriteConfig{}) + }) +} + +func TestMustRewriteWithConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + givenSkipper func(c echo.Context) bool + whenURL string + expectURL string + expectStatus int + }{ + { + name: "not skipped", + whenURL: "/old", + expectURL: "/new", + expectStatus: http.StatusOK, + }, + { + name: "skipped", + givenSkipper: func(c echo.Context) bool { + return true + }, + whenURL: "/old", + expectURL: "/old", + expectStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig( + RewriteConfig{ + Skipper: tc.givenSkipper, + Rules: map[string]string{"/old": "/new"}}, + )) + + e.GET("/new", func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) + assert.Equal(t, tc.expectStatus, rec.Code) + }) + } +} + // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches - e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, - )) + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{"/old": "/new"}}), + ) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { + e.Add(http.MethodGet, "/new", func(c echo.Context) error { return c.NoContent(http.StatusOK) }) @@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ @@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { return c.String(http.StatusOK, "eng") }) diff --git a/middleware/secure.go b/middleware/secure.go index 6c4051723..571b35877 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -3,87 +3,83 @@ package middleware import ( "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` - - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` - - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,