diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml new file mode 100644 index 000000000..733a90eae --- /dev/null +++ b/.github/workflows/checks.yml @@ -0,0 +1,47 @@ +name: Run checks + +on: + push: + branches: + - master + pull_request: + branches: + - master + workflow_dispatch: + +permissions: + contents: read # to fetch code (actions/checkout) + +env: + # run static analysis only with the latest Go version + LATEST_GO_VERSION: "1.20" + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ env.LATEST_GO_VERSION }} + check-latest: true + + - name: Run golint + run: | + go install golang.org/x/lint/golint@latest + golint -set_exit_status ./... + + - name: Run staticcheck + run: | + go install honnef.co/go/tools/cmd/staticcheck@latest + staticcheck ./... + + - name: Run govulncheck + run: | + go version + go install golang.org/x/vuln/cmd/govulncheck@latest + govulncheck ./... + diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 69535f09c..139369fcc 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -4,66 +4,53 @@ on: push: branches: - master - paths: - - '**.go' - - 'go.*' - - '_fixture/**' - - '.github/**' - - 'codecov.yml' pull_request: branches: - master - paths: - - '**.go' - - 'go.*' - - '_fixture/**' - - '.github/**' - - 'codecov.yml' workflow_dispatch: +permissions: + contents: read # to fetch code (actions/checkout) + +env: + # run coverage and benchmarks only with the latest Go version + LATEST_GO_VERSION: "1.20" + jobs: test: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy - # Echo tests with last four major releases - go: [1.16, 1.17, 1.18] + # Echo tests with last four major releases (unless there are pressing vulnerabilities) + # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when + # we derive from last four major releases promise. + go: ["1.19", "1.20"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - name: Checkout Code uses: actions/checkout@v3 - with: - ref: ${{ github.ref }} - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - name: Install Dependencies - run: go install golang.org/x/lint/golint@latest - - name: Run Tests - run: | - golint -set_exit_status ./... - go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v1 + if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v3 with: token: fail_ci_if_error: false + benchmark: needs: test - strategy: - matrix: - os: [ubuntu-latest] - go: [1.18] - name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} - runs-on: ${{ matrix.os }} + name: Benchmark comparison + runs-on: ubuntu-latest steps: - name: Checkout Code (Previous) uses: actions/checkout@v3 @@ -79,7 +66,7 @@ jobs: - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: ${{ env.LATEST_GO_VERSION }} - name: Install Dependencies run: go install golang.org/x/perf/cmd/benchstat@latest diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 67d45ad78..000000000 --- a/.travis.yml +++ /dev/null @@ -1,21 +0,0 @@ -arch: - - amd64 - - ppc64le - -language: go -go: - - 1.14.x - - 1.15.x - - tip -env: - - GO111MODULE=on -install: - - go get -v golang.org/x/lint/golint -script: - - golint -set_exit_status ./... - - go test -race -coverprofile=coverage.txt -covermode=atomic ./... -after_success: - - bash <(curl -s https://codecov.io/bash) -matrix: - allow_failures: - - go: tip diff --git a/CHANGELOG.md b/CHANGELOG.md index ba75d71f6..594300420 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,161 @@ # Changelog + +## v4.11.1 - 2023-07-16 + +**Fixes** + +* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481) + + +## v4.11.0 - 2023-07-14 + + +**Fixes** + +* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409) +* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411) +* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456) +* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477) + + +**Enhancements** + +* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410) +* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424) +* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425) +* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429) +* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416) +* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436) +* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444) +* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414) +* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452) +* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267) +* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475) + + +## v4.10.2 - 2023-02-22 + +**Security** + +* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406) +* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405) + +**Enhancements** + +* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277) + + +## v4.10.1 - 2023-02-19 + +**Security** + +* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402) + + +**Enhancements** + +* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377) +* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385) +* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380) +* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394) + + +## v4.10.0 - 2022-12-27 + +**Security** + +* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead. + + JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using + which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain. + +* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are + several vulnerabilities fixed in these libraries. + + Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + + +**Enhancements** + +* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305) +* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336) +* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316) +* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338) +* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315) +* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329) +* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340) +* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342) +* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343) +* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345) +* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182) +* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350) +* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341) +* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162) +* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358) +* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362) +* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366) +* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337) + + +## v4.9.1 - 2022-10-12 + +**Fixes** + +* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295) + +**Enhancements** + +* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272) +* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291) +* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254) +* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275) +* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301) + +## v4.9.0 - 2022-09-04 + +**Security** + +* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260) + +**Enhancements** + +* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257) +* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247) + + +## v4.8.0 - 2022-08-10 + +**Most notable things** + +You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237) +```go +e.Add("COPY", "/*", func(c echo.Context) error + return c.String(http.StatusOK, "OK COPY") +}) +``` + +You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217) +```go +e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) + +g := e.Group("/images") +g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) +``` + +**Enhancements** + +* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127) +* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145) +* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187) +* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191) +* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176) +* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209) +* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217) +* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227) +* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237) + + ## v4.7.2 - 2022-03-16 **Fixes** diff --git a/LICENSE b/LICENSE index c46d0105f..2f18411bd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2021 LabStack +Copyright (c) 2022 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index a6c4aaa90..aaf72cdfb 100644 --- a/Makefile +++ b/Makefile @@ -10,8 +10,10 @@ check: lint vet race ## Check project init: @go install golang.org/x/lint/golint@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest lint: ## Lint the files + @staticcheck ${PKG_LIST} @golint -set_exit_status ${PKG_LIST} vet: ## Vet the files @@ -24,11 +26,11 @@ race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks - @go test -run="-" -bench=".*" ${PKG_LIST} + @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.16" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.16 +goversion ?= "1.20" +test_version: ## Run tests inside Docker with given version (defaults to 1.20 oldest supported). Example: make test_version goversion=1.20 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/README.md b/README.md index 8b2321f05..267ce4d08 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) -[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) @@ -11,13 +11,12 @@ ## Supported Go versions +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with +older versions. + As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: -- 1.9.7+ -- 1.10.3+ -- 1.14+ - Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -39,24 +38,13 @@ For older versions, please use the latest v3 tag. - Automatic TLS via Let’s Encrypt - HTTP/2 support -## Benchmarks - -Date: 2020/11/11
-Source: https://github.com/vishr/web-framework-benchmark
-Lower is better! - - - - -The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz - ## [Guide](https://echo.labstack.com/guide) ### Installation ```sh // go get github.com/labstack/echo/{version} -go get github.com/labstack/echo/v4 +go get github.com/labstack/echo/v5 ``` ### Example @@ -65,8 +53,8 @@ go get github.com/labstack/echo/v4 package main import ( - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" "net/http" ) @@ -82,7 +70,9 @@ func main() { e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":1323"); err != http.ErrServerClosed { + log.Fatal(err) + } } // Handler @@ -91,17 +81,30 @@ func hello(c echo.Context) error { } ``` -# Third-party middlewares - -| Repository | Description | -|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | -| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | -| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | -| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | -| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | -| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | -| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. +# Official middleware repositories + +Following list of middleware is maintained by Echo team. + +| Repository | Description | +|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | + +# Third-party middleware repositories + +Be careful when adding 3rd party middleware. Echo teams does not have time or manpower to guarantee safety and quality +of middlewares in this list. + +| Repository | Description | +|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | +| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | Please send a PR to add your own library here. diff --git a/bind.go b/bind.go index c841ca010..0a7eb8b42 100644 --- a/bind.go +++ b/bind.go @@ -11,42 +11,38 @@ import ( "strings" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(c Context, i interface{}) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} // BindPathParams binds path params to bindable object -func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { - names := c.ParamNames() - values := c.ParamValues() +func BindPathParams(c Context, i interface{}) error { params := map[string][]string{} - for i, name := range names { - params[name] = []string{values[i]} + for _, param := range c.PathParams() { + params[param.Name] = []string{param.Value} } - if err := b.bindData(i, params, "param"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(i, params, "param"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // BindQueryParams binds query params to bindable object -func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindQueryParams(c Context, i interface{}) error { + if err := bindData(i, c.QueryParams(), "query"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } @@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm -func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { +func BindBody(c Context, i interface{}) (err error) { req := c.Request() if req.ContentLength == 0 { return @@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { case *HTTPError: return err default: - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } } case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())) } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - params, err := c.FormParams() + values, err := c.FormValues() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } - if err = b.bindData(i, params, "form"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(i, values, "form"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } default: return ErrUnsupportedMediaType @@ -97,34 +93,34 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } // BindHeaders binds HTTP headers to a bindable object -func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindHeaders(c Context, i interface{}) error { + if err := bindData(i, c.Request().Header, "header"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous -// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - if err := b.BindPathParams(c, i); err != nil { +// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. +func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) { + if err := BindPathParams(c, i); err != nil { return err } // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) - method := c.Request().Method + method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { - if err = b.BindQueryParams(c, i); err != nil { + if err = BindQueryParams(c, i); err != nil { return err } } - return b.BindBody(c, i) + return BindBody(c, i) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { +func bindData(destination interface{}, data map[string][]string, tag string) error { if destination == nil || len(data) == 0 { return nil } @@ -167,10 +163,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri } if inputFieldName == "" { - // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). - // structs that implement BindUnmarshaler are binded only when they have explicit tag + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). + // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := bindData(structField.Addr().Interface(), data, tag); err != nil { return err } } @@ -180,10 +176,8 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To - // fix this we must check all of the map values in a - // case-insensitive search. + // Go json.Unmarshal supports case-insensitive binding. However, the url params are bound case-sensitive which + // is inconsistent. To fix this we must check all the map values in a case-insensitive search. for k, v := range data { if strings.EqualFold(k, inputFieldName) { inputValue = v @@ -297,7 +291,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { func setIntField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } intVal, err := strconv.ParseInt(value, 10, bitSize) if err == nil { @@ -308,7 +302,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error { func setUintField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } uintVal, err := strconv.ParseUint(value, 10, bitSize) if err == nil { @@ -319,7 +313,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error { func setBoolField(value string, field reflect.Value) error { if value == "" { - value = "false" + return nil } boolVal, err := strconv.ParseBool(value) if err == nil { @@ -330,7 +324,7 @@ func setBoolField(value string, field reflect.Value) error { func setFloatField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0.0" + return nil } floatVal, err := strconv.ParseFloat(value, bitSize) if err == nil { diff --git a/bind_test.go b/bind_test.go index 4ed8dbb50..e499cbffc 100644 --- a/bind_test.go +++ b/bind_test.go @@ -190,44 +190,39 @@ func TestToMultipleFields(t *testing.T) { } func TestBindJSON(t *testing.T) { - assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userJSON), nil, MIMEApplicationJSON) - testBindOkay(assert, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) - testBindArrayOkay(assert, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) - testBindArrayOkay(assert, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) - testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) + testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON) + testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) } func TestBindXML(t *testing.T) { - assert := assert.New(t) - - testBindOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) - testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) - testBindArrayOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) - testBindArrayOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) - testBindOkay(assert, strings.NewReader(userXML), nil, MIMETextXML) - testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMETextXML) - testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML) + testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) } func TestBindForm(t *testing.T) { - assert := assert.New(t) - - testBindOkay(assert, strings.NewReader(userForm), nil, MIMEApplicationForm) - testBindOkay(assert, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, MIMEApplicationForm) err := c.Bind(&[]struct{ Field string }{}) - assert.Error(err) + assert.Error(t, err) } func TestBindQueryParams(t *testing.T) { @@ -277,7 +272,7 @@ func TestBindHeaderParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) if assert.NoError(t, err) { assert.Equal(t, 2, u.ID) assert.Equal(t, "Jon Doe", u.Name) @@ -291,7 +286,7 @@ func TestBindHeaderParamBadType(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) assert.Error(t, err) httpErr, ok := err.(*HTTPError) @@ -300,6 +295,52 @@ func TestBindHeaderParamBadType(t *testing.T) { } } +func TestBind_CombineQueryWithHeaderParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil) + req.Header.Set("language", "de") + req.Header.Set("length", "99") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPathParams(PathParams{{ + Name: "id", + Value: "999", + }}) + + type SearchOpts struct { + ID int `param:"id"` + Length int `query:"length"` + Page int `query:"page"` + Search string `query:"search"` + Language string `query:"language" header:"language"` + } + + opts := SearchOpts{ + Length: 100, + Page: 0, + Search: "default value", + Language: "en", + } + err := c.Bind(&opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // bind from query + assert.Equal(t, 10, opts.Page) // bind from query + assert.Equal(t, 999, opts.ID) // bind from path param + assert.Equal(t, "et", opts.Language) // bind from query + assert.Equal(t, "default value", opts.Search) // default value stays + + // make sure another bind will not mess already set values unless there are new values + err = BindHeaders(c, &opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists + assert.Equal(t, 10, opts.Page) + assert.Equal(t, 999, opts.ID) + assert.Equal(t, "de", opts.Language) // header overwrites now this value + assert.Equal(t, "default value", opts.Search) +} + func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) @@ -317,20 +358,19 @@ func TestBindUnmarshalParam(t *testing.T) { err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) - assert := assert.New(t) - if assert.NoError(err) { + if assert.NoError(t, err) { // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) - assert.Equal(ts, result.T) - assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA) - assert.Equal([]Timestamp{ts, ts}, result.TA) - assert.Equal(Struct{""}, result.ST) // child struct does not have a field with matching tag - assert.Equal("baz", result.StWithTag.Foo) // child struct has field with matching tag + assert.Equal(t, ts, result.T) + assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) + assert.Equal(t, []Timestamp{ts, ts}, result.TA) + assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag + assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag } } func TestBindUnmarshalText(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -406,7 +446,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -426,48 +466,47 @@ func TestBindMultipartForm(t *testing.T) { mw.Close() body := bodyBuffer.Bytes() - assert := assert.New(t) - testBindOkay(assert, bytes.NewReader(body), nil, mw.FormDataContentType()) - testBindOkay(assert, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) } func TestBindUnsupportedMediaType(t *testing.T) { - assert := assert.New(t) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) } func TestBindbindData(t *testing.T) { - a := assert.New(t) ts := new(bindTestStruct) - b := new(DefaultBinder) - err := b.bindData(ts, values, "form") - a.NoError(err) - - a.Equal(0, ts.I) - a.Equal(int8(0), ts.I8) - a.Equal(int16(0), ts.I16) - a.Equal(int32(0), ts.I32) - a.Equal(int64(0), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(0), ts.UI8) - a.Equal(uint16(0), ts.UI16) - a.Equal(uint32(0), ts.UI32) - a.Equal(uint64(0), ts.UI64) - a.Equal(false, ts.B) - a.Equal(float32(0), ts.F32) - a.Equal(float64(0), ts.F64) - a.Equal("", ts.S) - a.Equal("", ts.cantSet) + err := bindData(ts, values, "form") + assert.NoError(t, err) + + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(0), ts.I8) + assert.Equal(t, int16(0), ts.I16) + assert.Equal(t, int32(0), ts.I32) + assert.Equal(t, int64(0), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(0), ts.UI8) + assert.Equal(t, uint16(0), ts.UI16) + assert.Equal(t, uint32(0), ts.UI32) + assert.Equal(t, uint64(0), ts.UI64) + assert.Equal(t, false, ts.B) + assert.Equal(t, float32(0), ts.F32) + assert.Equal(t, float64(0), ts.F64) + assert.Equal(t, "", ts.S) + assert.Equal(t, "", ts.cantSet) } func TestBindParam(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.SetPath("/users/:id/:name") - c.SetParamNames("id", "name") - c.SetParamValues("1", "Jon Snow") + cc := c.(RoutableContext) + cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"}) + cc.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + {Name: "name", Value: "Jon Snow"}, + }) u := new(user) err := c.Bind(u) @@ -478,9 +517,11 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.SetPath("/users/:id") - c2.SetParamNames("id") - c2.SetParamValues("1") + cc2 := c2.(RoutableContext) + cc2.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc2.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c2.Bind(u) @@ -492,15 +533,17 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - req2 := httptest.NewRequest(POST, "/", body) + req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.SetPath("/users/:id") - c3.SetParamNames("id") - c3.SetParamValues("1") + cc3 := c3.(RoutableContext) + cc3.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc3.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c3.Bind(u) @@ -528,7 +571,6 @@ func TestBindUnmarshalTypeError(t *testing.T) { } func TestBindSetWithProperType(t *testing.T) { - assert := assert.New(t) ts := new(bindTestStruct) typ := reflect.TypeOf(ts).Elem() val := reflect.ValueOf(ts).Elem() @@ -543,9 +585,9 @@ func TestBindSetWithProperType(t *testing.T) { } val := values[typeField.Name][0] err := setWithProperType(typeField.Type.Kind(), val, structField) - assert.NoError(err) + assert.NoError(t, err) } - assertBindTestStruct(assert, ts) + assertBindTestStruct(t, ts) type foo struct { Bar bytes.Buffer @@ -553,50 +595,118 @@ func TestBindSetWithProperType(t *testing.T) { v := &foo{} typ = reflect.TypeOf(v).Elem() val = reflect.ValueOf(v).Elem() - assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) + assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - assert := assert.New(t) +func TestSetIntField(t *testing.T) { + ts := new(bindTestStruct) + ts.I = 100 + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 100, ts.I) + + // second set with value sets the value + err = setIntField("5", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) +} +func TestSetUintField(t *testing.T) { ts := new(bindTestStruct) + ts.UI = 100 + val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(5, ts.I) - } - if assert.NoError(setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(0, ts.I) - } - // Uint - if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(uint(10), ts.UI) - } - if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(uint(0), ts.UI) - } + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(100), ts.UI) + + // second set with value sets the value + err = setUintField("5", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) +} - // Float - if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(float32(15.5), ts.F32) - } - if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(float32(0.0), ts.F32) - } +func TestSetFloatField(t *testing.T) { + ts := new(bindTestStruct) + ts.F32 = 100 - // Bool - if assert.NoError(setBoolField("true", val.FieldByName("B"))) { - assert.Equal(true, ts.B) - } - if assert.NoError(setBoolField("", val.FieldByName("B"))) { - assert.Equal(false, ts.B) - } + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(100), ts.F32) + + // second set with value sets the value + err = setFloatField("15.5", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) +} + +func TestSetBoolField(t *testing.T) { + ts := new(bindTestStruct) + ts.B = true + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // second set with value sets the value + err = setBoolField("true", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // fourth set to false + err = setBoolField("false", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, false, ts.B) +} + +func TestUnmarshalFieldNonPtr(t *testing.T) { + ts := new(bindTestStruct) + val := reflect.ValueOf(ts).Elem() ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(err) { - assert.Equal(ok, true) - assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) + if assert.NoError(t, err) { + assert.True(t, ok) + assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) } } @@ -604,35 +714,34 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() assert := assert.New(b) ts := new(bindTestStructWithTags) - binder := new(DefaultBinder) var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = bindData(ts, values, "form") } assert.NoError(err) - assertBindTestStruct(assert, (*bindTestStruct)(ts)) + assertBindTestStruct(b, (*bindTestStruct)(ts)) } -func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { - a.Equal(0, ts.I) - a.Equal(int8(8), ts.I8) - a.Equal(int16(16), ts.I16) - a.Equal(int32(32), ts.I32) - a.Equal(int64(64), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(8), ts.UI8) - a.Equal(uint16(16), ts.UI16) - a.Equal(uint32(32), ts.UI32) - a.Equal(uint64(64), ts.UI64) - a.Equal(true, ts.B) - a.Equal(float32(32.5), ts.F32) - a.Equal(float64(64.5), ts.F64) - a.Equal("test", ts.S) - a.Equal("", ts.GetCantSet()) +func assertBindTestStruct(t testing.TB, ts *bindTestStruct) { + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(8), ts.I8) + assert.Equal(t, int16(16), ts.I16) + assert.Equal(t, int32(32), ts.I32) + assert.Equal(t, int64(64), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(8), ts.UI8) + assert.Equal(t, uint16(16), ts.UI16) + assert.Equal(t, uint32(32), ts.UI32) + assert.Equal(t, uint64(64), ts.UI64) + assert.Equal(t, true, ts.B) + assert.Equal(t, float32(32.5), ts.F32) + assert.Equal(t, float64(64.5), ts.F64) + assert.Equal(t, "test", ts.S) + assert.Equal(t, "", ts.GetCantSet()) } -func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { +func testBindOkay(t testing.TB, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { @@ -644,13 +753,13 @@ func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctyp req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) - if assert.NoError(err) { - assert.Equal(1, u.ID) - assert.Equal("Jon Snow", u.Name) + if assert.NoError(t, err) { + assert.Equal(t, 1, u.ID) + assert.Equal(t, "Jon Snow", u.Name) } } -func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { +func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { @@ -662,14 +771,14 @@ func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, req.Header.Set(HeaderContentType, ctype) u := []user{} err := c.Bind(&u) - if assert.NoError(err) { - assert.Equal(1, len(u)) - assert.Equal(1, u[0].ID) - assert.Equal("Jon Snow", u[0].Name) + if assert.NoError(t, err) { + assert.Equal(t, 1, len(u)) + assert.Equal(t, 1, u[0].ID) + assert.Equal(t, "Jon Snow", u[0].Name) } } -func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) { +func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) { e := New() req := httptest.NewRequest(http.MethodPost, "/", r) rec := httptest.NewRecorder() @@ -681,14 +790,14 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte switch { case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML), strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - if assert.IsType(new(HTTPError), err) { - assert.Equal(http.StatusBadRequest, err.(*HTTPError).Code) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } default: - if assert.IsType(new(HTTPError), err) { - assert.Equal(ErrUnsupportedMediaType, err) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, ErrUnsupportedMediaType, err) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } } } @@ -840,8 +949,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("node_from_path") + cc := c.(RoutableContext) + cc.SetRawPathParams(&PathParams{ + {Name: "node", Value: "node_from_path"}, + }) } var bindTarget interface{} @@ -852,7 +963,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } b := new(DefaultBinder) - err := b.Bind(bindTarget, c) + err := b.Bind(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -1021,8 +1132,10 @@ func TestDefaultBinder_BindBody(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("real_node") + cc := c.(RoutableContext) + cc.SetRawPathParams(&PathParams{ + {Name: "node", Value: "real_node"}, + }) } var bindTarget interface{} @@ -1031,9 +1144,8 @@ func TestDefaultBinder_BindBody(t *testing.T) { } else { bindTarget = &Node{} } - b := new(DefaultBinder) - err := b.BindBody(c, bindTarget) + err := BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/binder.go b/binder.go index 5a6cf9d9b..5d357859d 100644 --- a/binder.go +++ b/binder.go @@ -123,10 +123,10 @@ func QueryParamsBinder(c Context) *ValueBinder { func PathParamsBinder(c Context) *ValueBinder { return &ValueBinder{ failFast: true, - ValueFunc: c.Param, + ValueFunc: c.PathParam, ValuesFunc: func(sourceParam string) []string { // path parameter should not have multiple values so getting values does not make sense but lets not error out here - value := c.Param(sourceParam) + value := c.PathParam(sourceParam) if value == "" { return nil } @@ -1236,7 +1236,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Second) } @@ -1247,7 +1247,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Second) } @@ -1257,7 +1257,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Millisecond) } @@ -1268,7 +1268,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Millisecond) } @@ -1280,8 +1280,8 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal -// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Nanosecond) } @@ -1294,8 +1294,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal -// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Nanosecond) } diff --git a/binder_external_test.go b/binder_external_test.go index f1aecb52b..585ade816 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -4,7 +4,7 @@ package echo_test import ( "encoding/base64" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "log" "net/http" "net/http/httptest" diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go deleted file mode 100644 index 018628c3a..000000000 --- a/binder_go1.15_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build go1.15 - -package echo - -/** - Since version 1.15 time.Time and time.Duration error message pattern has changed (values are wrapped now in \"\") - So pre 1.15 these tests fail with similar error: - - expected: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param" - actual : "code=400, message=failed to bind field value to Duration, internal=time: invalid duration nope, field=param" -*/ - -import ( - "errors" - "github.com/stretchr/testify/assert" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func createTestContext15(URL string, body io.Reader, pathParams map[string]string) Context { - e := New() - req := httptest.NewRequest(http.MethodGet, URL, body) - if body != nil { - req.Header.Set(HeaderContentType, MIMEApplicationJSON) - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) - for name, value := range pathParams { - names = append(names, name) - values = append(values, value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) - } - - return c -} - -func TestValueBinder_TimeError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue time.Time - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - dest := time.Time{} - var err error - if tc.whenMust { - err = b.MustTime("param", &dest, tc.whenLayout).BindError() - } else { - err = b.Time("param", &dest, tc.whenLayout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_TimesError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue []time.Time - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - layout := time.RFC3339 - if tc.whenLayout != "" { - layout = tc.whenLayout - } - - var dest []time.Time - var err error - if tc.whenMust { - err = b.MustTimes("param", &dest, layout).BindError() - } else { - err = b.Times("param", &dest, layout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue time.Duration - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - var dest time.Duration - var err error - if tc.whenMust { - err = b.MustDuration("param", &dest).BindError() - } else { - err = b.Duration("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationsError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue []time.Duration - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - var dest []time.Duration - var err error - if tc.whenMust { - err = b.MustDurations("param", &dest).BindError() - } else { - err = b.Durations("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/binder_test.go b/binder_test.go index 910bbfc50..3c17057c0 100644 --- a/binder_test.go +++ b/binder_test.go @@ -26,14 +26,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string) c := e.NewContext(req, rec) if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) + params := make(PathParams, 0) for name, value := range pathParams { - names = append(names, name) - values = append(values, value) + params = append(params, PathParam{ + Name: name, + Value: value, + }) } - c.SetParamNames(names...) - c.SetParamValues(values...) + cc := c.(RoutableContext) + cc.SetRawPathParams(¶ms) } return c @@ -2917,7 +2918,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) } } @@ -2984,7 +2985,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) if dest.Int64 != 1 { b.Fatalf("int64!=1") } @@ -3029,3 +3030,224 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { } } } + +func TestValueBinder_TimeError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue time.Time + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TimesError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue []time.Time + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Duration + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationsError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []time.Duration + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/context.go b/context.go index a4ecfadfc..50616aa5d 100644 --- a/context.go +++ b/context.go @@ -3,210 +3,214 @@ package echo import ( "bytes" "encoding/xml" + "errors" "fmt" "io" + "io/fs" "mime/multipart" "net" "net/http" "net/url" + "path/filepath" "strings" "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) - // SetResponse sets `*Response`. - SetResponse(r *Response) + // SetResponse sets `*Response`. + SetResponse(r *Response) - // Response returns `*Response`. - Response() *Response + // Response returns `*Response`. + Response() *Response - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - // Path returns the registered path for the handler. - Path() string + // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. + // In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. + RouteInfo() RouteInfo - // SetPath sets the registered path for the handler. - SetPath(p string) + // Path returns the registered path for the handler. + Path() string - // Param returns path parameter by name. - Param(name string) string + // PathParam returns path parameter by name. + PathParam(name string) string - // ParamNames returns path parameter names. - ParamNames() []string + // PathParamDefault returns the path parameter or default value for the provided name. + // + // Notes for DefaultRouter implementation: + // Path parameter could be empty for cases like that: + // * route `/release-:version/bin` and request URL is `/release-/bin` + // * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` + // but not when path parameter is last part of route path + // * route `/download/file.:ext` will not match request `/download/file.` + PathParamDefault(name string, defaultValue string) string - // SetParamNames sets path parameter names. - SetParamNames(names ...string) + // PathParams returns path parameter values. + PathParams() PathParams - // ParamValues returns path parameter values. - ParamValues() []string + // SetPathParams sets path parameters for current request. + SetPathParams(params PathParams) - // SetParamValues sets path parameter values. - SetParamValues(values ...string) + // QueryParam returns the query param for the provided name. + QueryParam(name string) string - // QueryParam returns the query param for the provided name. - QueryParam(name string) string + // QueryParamDefault returns the query param or default value for the provided name. + QueryParamDefault(name, defaultValue string) string - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // QueryString returns the URL query string. - QueryString() string + // QueryString returns the URL query string. + QueryString() string - // FormValue returns the form field value for the provided name. - FormValue(name string) string + // FormValue returns the form field value for the provided name. + FormValue(name string) string - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) + // FormValueDefault returns the form field value or default value for the provided name. + FormValueDefault(name, defaultValue string) string - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) + // FormValues returns the form field values as `url.Values`. + FormValues() (url.Values, error) - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) - // Get retrieves data from the context. - Get(key string) interface{} + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie - // Set saves data in the context. - Set(key string, val interface{}) + // Get retrieves data from the context. + Get(key string) interface{} - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. - Bind(i interface{}) error + // Set saves data in the context. + Set(key string, val interface{}) - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error + // Bind binds path params, query params and the request body into provided type `i`. The default binder + // binds body based on Content-Type header. + Bind(i interface{}) error - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i interface{}) error - // HTML sends an HTTP response with status code. - HTML(code int, html string) error + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data interface{}) error - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error + // HTML sends an HTTP response with status code. + HTML(code int, html string) error - // String sends a string response with status code. - String(code int, s string) error + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error + // String sends a string response with status code. + String(code int, s string) error - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error + // JSON sends a JSON response with status code. + JSON(code int, i interface{}) error - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i interface{}, indent string) error - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i interface{}) error - // XML sends an XML response with status code. - XML(code int, i interface{}) error + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error + // XML sends an XML response with status code. + XML(code int, i interface{}) error - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i interface{}, indent string) error - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error - // File sends a response with the content of the file. - File(file string) error + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error + // File sends a response with the content of the file. + File(file string) error - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error + // FileFS sends a response with the content of the file from given filesystem. + FileFS(file string, filesystem fs.FS) error - // NoContent sends a response with no body and a status code. - NoContent(code int) error + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error - // Error invokes the registered HTTP error handler. Generally used by middleware. - Error(err error) + // NoContent sends a response with no body and a status code. + NoContent(code int) error - // Handler returns the matched handler by router. - Handler() HandlerFunc + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) + // Error invokes the registered global HTTP error handler. Generally used by middleware. + // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and + // middlewares up in chain can not change Response status code or Response body anymore. + // + // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. + // Instead of calling this method in handler return your error and let it be handled by middlewares or global error handler. + Error(err error) - // Logger returns the `Logger` instance. - Logger() Logger - - // Set the logger - SetLogger(l Logger) + // Echo returns the `Echo` instance. + // + // WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them + // anywhere in your code after Echo server has started. + Echo() *Echo +} - // Echo returns the `Echo` instance. - Echo() *Echo +// ServableContext is interface that Echo context implementation must implement to be usable in middleware/handlers and +// be able to be routed by Router. +type ServableContext interface { + Context // minimal set of methods for middlewares and handler + RoutableContext // minimal set for routing. These methods should not be accessed in middlewares/handlers - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) - } - - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex - } -) + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) +} const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. @@ -221,39 +225,95 @@ const ( defaultIndent = " " ) -func (c *context) writeContentType(value string) { +// DefaultContext is default implementation of Context interface and can be embedded into structs to compose +// new Contexts with extended/modified behaviour. +type DefaultContext struct { + request *http.Request + response *Response + + route RouteInfo + path string + + // pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations. + pathParams *PathParams + // currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request. + // Lifecycle is not handle by Echo and could have excess allocations per served Request + currentParams PathParams + + query url.Values + store Map + echo *Echo + lock sync.RWMutex +} + +// NewDefaultContext creates new instance of DefaultContext. +// Argument pathParamAllocSize must be value that is stored in Echo.contextPathParamAllocSize field and is used +// to preallocate PathParams slice. +func NewDefaultContext(e *Echo, pathParamAllocSize int) *DefaultContext { + p := make(PathParams, pathParamAllocSize) + return &DefaultContext{ + pathParams: &p, + store: make(Map), + echo: e, + } +} + +// Reset resets the context after request completes. It must be called along +// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. +// See `Echo#ServeHTTP()` +func (c *DefaultContext) Reset(r *http.Request, w http.ResponseWriter) { + c.request = r + c.response.reset(w) + c.query = nil + c.store = nil + + c.route = nil + c.path = "" + // NOTE: Don't reset because it has to have length of c.echo.contextPathParamAllocSize at all times + *c.pathParams = (*c.pathParams)[:0] + c.currentParams = nil +} + +func (c *DefaultContext) writeContentType(value string) { header := c.Response().Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } -func (c *context) Request() *http.Request { +// Request returns `*http.Request`. +func (c *DefaultContext) Request() *http.Request { return c.request } -func (c *context) SetRequest(r *http.Request) { +// SetRequest sets `*http.Request`. +func (c *DefaultContext) SetRequest(r *http.Request) { c.request = r } -func (c *context) Response() *Response { +// Response returns `*Response`. +func (c *DefaultContext) Response() *Response { return c.response } -func (c *context) SetResponse(r *Response) { +// SetResponse sets `*Response`. +func (c *DefaultContext) SetResponse(r *Response) { c.response = r } -func (c *context) IsTLS() bool { +// IsTLS returns true if HTTP connection is TLS otherwise false. +func (c *DefaultContext) IsTLS() bool { return c.request.TLS != nil } -func (c *context) IsWebSocket() bool { +// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. +func (c *DefaultContext) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) return strings.EqualFold(upgrade, "websocket") } -func (c *context) Scheme() string { +// Scheme returns the HTTP protocol scheme, `http` or `https`. +func (c *DefaultContext) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { @@ -274,7 +334,10 @@ func (c *context) Scheme() string { return "http" } -func (c *context) RealIP() string { +// RealIP returns the client's network address based on `X-Forwarded-For` +// or `X-Real-IP` request header. +// The behavior can be configured using `Echo#IPExtractor`. +func (c *DefaultContext) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } @@ -282,96 +345,136 @@ func (c *context) RealIP() string { if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { i := strings.IndexAny(ip, ",") if i > 0 { - return strings.TrimSpace(ip[:i]) + xffip := strings.TrimSpace(ip[:i]) + xffip = strings.TrimPrefix(xffip, "[") + xffip = strings.TrimSuffix(xffip, "]") + return xffip } return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { + ip = strings.TrimPrefix(ip, "[") + ip = strings.TrimSuffix(ip, "]") return ip } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) return ra } -func (c *context) Path() string { +// Path returns the registered path for the handler. +func (c *DefaultContext) Path() string { return c.path } -func (c *context) SetPath(p string) { +// SetPath sets the registered path for the handler. +func (c *DefaultContext) SetPath(p string) { c.path = p } -func (c *context) Param(name string) string { - for i, n := range c.pnames { - if i < len(c.pvalues) { - if n == name { - return c.pvalues[i] - } - } - } - return "" +// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. +// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. +func (c *DefaultContext) RouteInfo() RouteInfo { + return c.route } -func (c *context) ParamNames() []string { - return c.pnames +// SetRouteInfo sets the route info of this request to the context. +func (c *DefaultContext) SetRouteInfo(ri RouteInfo) { + c.route = ri } -func (c *context) SetParamNames(names ...string) { - c.pnames = names +// RawPathParams returns raw path pathParams value. Allocation of PathParams is handled by Context. +func (c *DefaultContext) RawPathParams() *PathParams { + return c.pathParams +} - l := len(names) - if *c.echo.maxParam < l { - *c.echo.maxParam = l - } +// SetRawPathParams replaces any existing param values with new values for this context lifetime (request). +// +// DO NOT USE! +// Do not set any other value than what you got from RawPathParams as allocation of PathParams is handled by Context. +// If you mess up size of pathParams size your application will panic/crash during routing +func (c *DefaultContext) SetRawPathParams(params *PathParams) { + c.pathParams = params +} - if len(c.pvalues) < l { - // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overriden in a Context#SetParamValues - newPvalues := make([]string, l) - copy(newPvalues, c.pvalues) - c.pvalues = newPvalues +// PathParam returns path parameter by name. +func (c *DefaultContext) PathParam(name string) string { + if c.currentParams != nil { + return c.currentParams.Get(name, "") } + + return c.pathParams.Get(name, "") } -func (c *context) ParamValues() []string { - return c.pvalues[:len(c.pnames)] +// PathParamDefault returns the path parameter or default value for the provided name. +func (c *DefaultContext) PathParamDefault(name, defaultValue string) string { + return c.pathParams.Get(name, defaultValue) } -func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times - // It will brake the Router#Find code - limit := len(values) - if limit > *c.echo.maxParam { - limit = *c.echo.maxParam - } - for i := 0; i < limit; i++ { - c.pvalues[i] = values[i] +// PathParams returns path parameter values. +func (c *DefaultContext) PathParams() PathParams { + if c.currentParams != nil { + return c.currentParams } + + result := make(PathParams, len(*c.pathParams)) + copy(result, *c.pathParams) + return result } -func (c *context) QueryParam(name string) string { +// SetPathParams sets path parameters for current request. +func (c *DefaultContext) SetPathParams(params PathParams) { + c.currentParams = params +} + +// QueryParam returns the query param for the provided name. +func (c *DefaultContext) QueryParam(name string) string { if c.query == nil { c.query = c.request.URL.Query() } return c.query.Get(name) } -func (c *context) QueryParams() url.Values { +// QueryParamDefault returns the query param or default value for the provided name. +// Note: QueryParamDefault does not distinguish if query had no value by that name or value was empty string +// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamDefault("search", "1")` +func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string { + value := c.QueryParam(name) + if value == "" { + value = defaultValue + } + return value +} + +// QueryParams returns the query parameters as `url.Values`. +func (c *DefaultContext) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } -func (c *context) QueryString() string { +// QueryString returns the URL query string. +func (c *DefaultContext) QueryString() string { return c.request.URL.RawQuery } -func (c *context) FormValue(name string) string { +// FormValue returns the form field value for the provided name. +func (c *DefaultContext) FormValue(name string) string { return c.request.FormValue(name) } -func (c *context) FormParams() (url.Values, error) { +// FormValueDefault returns the form field value or default value for the provided name. +// Note: FormValueDefault does not distinguish if form had no value by that name or value was empty string +func (c *DefaultContext) FormValueDefault(name, defaultValue string) string { + value := c.FormValue(name) + if value == "" { + value = defaultValue + } + return value +} + +// FormValues returns the form field values as `url.Values`. +func (c *DefaultContext) FormValues() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { if err := c.request.ParseMultipartForm(defaultMemory); err != nil { return nil, err @@ -384,7 +487,8 @@ func (c *context) FormParams() (url.Values, error) { return c.request.Form, nil } -func (c *context) FormFile(name string) (*multipart.FileHeader, error) { +// FormFile returns the multipart form file for the provided name. +func (c *DefaultContext) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) if err != nil { return nil, err @@ -393,30 +497,36 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) { return fh, nil } -func (c *context) MultipartForm() (*multipart.Form, error) { +// MultipartForm returns the multipart form. +func (c *DefaultContext) MultipartForm() (*multipart.Form, error) { err := c.request.ParseMultipartForm(defaultMemory) return c.request.MultipartForm, err } -func (c *context) Cookie(name string) (*http.Cookie, error) { +// Cookie returns the named cookie provided in the request. +func (c *DefaultContext) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -func (c *context) SetCookie(cookie *http.Cookie) { +// SetCookie adds a `Set-Cookie` header in HTTP response. +func (c *DefaultContext) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } -func (c *context) Cookies() []*http.Cookie { +// Cookies returns the HTTP cookies sent with the request. +func (c *DefaultContext) Cookies() []*http.Cookie { return c.request.Cookies() } -func (c *context) Get(key string) interface{} { +// Get retrieves data from the context. +func (c *DefaultContext) Get(key string) interface{} { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } -func (c *context) Set(key string, val interface{}) { +// Set saves data in the context. +func (c *DefaultContext) Set(key string, val interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -426,18 +536,24 @@ func (c *context) Set(key string, val interface{}) { c.store[key] = val } -func (c *context) Bind(i interface{}) error { - return c.echo.Binder.Bind(i, c) +// Bind binds path params, query params and the request body into provided type `i`. The default binder +// binds body based on Content-Type header. +func (c *DefaultContext) Bind(i interface{}) error { + return c.echo.Binder.Bind(c, i) } -func (c *context) Validate(i interface{}) error { +// Validate validates provided `i`. It is usually called after `Context#Bind()`. +// Validator must be registered using `Echo#Validator`. +func (c *DefaultContext) Validate(i interface{}) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -func (c *context) Render(code int, name string, data interface{}) (err error) { +// Render renders a template with data and sends a text/html response with status +// code. Renderer must be registered using `Echo.Renderer`. +func (c *DefaultContext) Render(code int, name string, data interface{}) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } @@ -448,19 +564,22 @@ func (c *context) Render(code int, name string, data interface{}) (err error) { return c.HTMLBlob(code, buf.Bytes()) } -func (c *context) HTML(code int, html string) (err error) { +// HTML sends an HTTP response with status code. +func (c *DefaultContext) HTML(code int, html string) (err error) { return c.HTMLBlob(code, []byte(html)) } -func (c *context) HTMLBlob(code int, b []byte) (err error) { +// HTMLBlob sends an HTTP blob response with status code. +func (c *DefaultContext) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } -func (c *context) String(code int, s string) (err error) { +// String sends a string response with status code. +func (c *DefaultContext) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) { +func (c *DefaultContext) jsonPBlob(code int, callback string, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -479,13 +598,14 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error return } -func (c *context) json(code int, i interface{}, indent string) error { +func (c *DefaultContext) json(code int, i interface{}, indent string) error { c.writeContentType(MIMEApplicationJSONCharsetUTF8) c.response.Status = code return c.echo.JSONSerializer.Serialize(c, i, indent) } -func (c *context) JSON(code int, i interface{}) (err error) { +// JSON sends a JSON response with status code. +func (c *DefaultContext) JSON(code int, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -493,19 +613,25 @@ func (c *context) JSON(code int, i interface{}) (err error) { return c.json(code, i, indent) } -func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) { +// JSONPretty sends a pretty-print JSON with status code. +func (c *DefaultContext) JSONPretty(code int, i interface{}, indent string) (err error) { return c.json(code, i, indent) } -func (c *context) JSONBlob(code int, b []byte) (err error) { +// JSONBlob sends a JSON blob response with status code. +func (c *DefaultContext) JSONBlob(code int, b []byte) (err error) { return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) } -func (c *context) JSONP(code int, callback string, i interface{}) (err error) { +// JSONP sends a JSONP response with status code. It uses `callback` to construct +// the JSONP payload. +func (c *DefaultContext) JSONP(code int, callback string, i interface{}) (err error) { return c.jsonPBlob(code, callback, i) } -func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { +// JSONPBlob sends a JSONP blob response with status code. It uses `callback` +// to construct the JSONP payload. +func (c *DefaultContext) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { @@ -518,7 +644,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { return } -func (c *context) xml(code int, i interface{}, indent string) (err error) { +func (c *DefaultContext) xml(code int, i interface{}, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -531,7 +657,8 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) { return enc.Encode(i) } -func (c *context) XML(code int, i interface{}) (err error) { +// XML sends an XML response with status code. +func (c *DefaultContext) XML(code int, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -539,11 +666,13 @@ func (c *context) XML(code int, i interface{}) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) { +// XMLPretty sends a pretty-print XML with status code. +func (c *DefaultContext) XMLPretty(code int, i interface{}, indent string) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLBlob(code int, b []byte) (err error) { +// XMLBlob sends an XML blob response with status code. +func (c *DefaultContext) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { @@ -553,39 +682,86 @@ func (c *context) XMLBlob(code int, b []byte) (err error) { return } -func (c *context) Blob(code int, contentType string, b []byte) (err error) { +// Blob sends a blob response with status code and content type. +func (c *DefaultContext) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } -func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { +// Stream sends a streaming response with status code and content type. +func (c *DefaultContext) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } -func (c *context) Attachment(file, name string) error { +// File sends a response with the content of the file. +func (c *DefaultContext) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *DefaultContext) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return err + } + } + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil +} + +// Attachment sends a response as attachment, prompting client to save the file. +func (c *DefaultContext) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } -func (c *context) Inline(file, name string) error { +// Inline sends a response as inline, opening the file in the browser. +func (c *DefaultContext) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } -func (c *context) contentDisposition(file, name, dispositionType string) error { +func (c *DefaultContext) contentDisposition(file, name, dispositionType string) error { c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) return c.File(file) } -func (c *context) NoContent(code int) error { +// NoContent sends a response with no body and a status code. +func (c *DefaultContext) NoContent(code int) error { c.response.WriteHeader(code) return nil } -func (c *context) Redirect(code int, url string) error { +// Redirect redirects the request to a provided URL with status code. +func (c *DefaultContext) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } @@ -594,45 +770,17 @@ func (c *context) Redirect(code int, url string) error { return nil } -func (c *context) Error(err error) { - c.echo.HTTPErrorHandler(err, c) +// Error invokes the registered global HTTP error handler. Generally used by middleware. +// A side-effect of calling global error handler is that now Response has been committed (sent to the client) and +// middlewares up in chain can not change Response status code or Response body anymore. +// +// Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. +// Instead of calling this method in handler return your error and let it be handled by middlewares or global error handler. +func (c *DefaultContext) Error(err error) { + c.echo.HTTPErrorHandler(c, err) } -func (c *context) Echo() *Echo { +// Echo returns the `Echo` instance. +func (c *DefaultContext) Echo() *Echo { return c.echo } - -func (c *context) Handler() HandlerFunc { - return c.handler -} - -func (c *context) SetHandler(h HandlerFunc) { - c.handler = h -} - -func (c *context) Logger() Logger { - res := c.logger - if res != nil { - return res - } - return c.echo.Logger -} - -func (c *context) SetLogger(l Logger) { - c.logger = l -} - -func (c *context) Reset(r *http.Request, w http.ResponseWriter) { - c.request = r - c.response.reset(w) - c.query = nil - c.handler = NotFoundHandler - c.store = nil - c.path = "" - c.pnames = nil - c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam at all times - for i := 0; i < *c.echo.maxParam; i++ { - c.pvalues[i] = "" - } -} diff --git a/context_fs.go b/context_fs.go deleted file mode 100644 index 11ee84bcd..000000000 --- a/context_fs.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -import ( - "net/http" - "os" - "path/filepath" -) - -func (c *context) File(file string) (err error) { - f, err := os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.Join(file, indexPage) - f, err = os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return - } - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return -} diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go deleted file mode 100644 index c1c724afd..000000000 --- a/context_fs_go1.16.go +++ /dev/null @@ -1,52 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "errors" - "io" - "io/fs" - "net/http" - "path/filepath" -) - -func (c *context) File(file string) error { - return fsFile(c, file, c.echo.Filesystem) -} - -// FileFS serves file from given file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (c *context) FileFS(file string, filesystem fs.FS) error { - return fsFile(c, file, filesystem) -} - -func fsFile(c Context, file string, filesystem fs.FS) error { - f, err := filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. - f, err = filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return err - } - } - ff, ok := f.(io.ReadSeeker) - if !ok { - return errors.New("file does not implement io.ReadSeeker") - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) - return nil -} diff --git a/context_fs_go1.16_test.go b/context_fs_go1.16_test.go deleted file mode 100644 index 027d1c483..000000000 --- a/context_fs_go1.16_test.go +++ /dev/null @@ -1,135 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestContext_File(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok, from default file system", - whenFile: "_fixture/images/walle.png", - whenFS: nil, - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "ok, from custom file system", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - if tc.whenFS != nil { - e.Filesystem = tc.whenFS - } - - handler := func(ec Context) error { - return ec.(*context).File(tc.whenFile) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestContext_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - handler := func(ec Context) error { - return ec.(*context).FileFS(tc.whenFile, tc.whenFS) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} diff --git a/context_test.go b/context_test.go index a8b9a9946..9e06b3e9a 100644 --- a/context_test.go +++ b/context_test.go @@ -8,33 +8,33 @@ import ( "errors" "fmt" "io" + "io/fs" "math" "mime/multipart" "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "text/template" "time" - "github.com/labstack/gommon/log" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} var testUser = user{1, "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -46,9 +46,10 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -60,9 +61,10 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -73,7 +75,7 @@ func BenchmarkAllocXML(b *testing.B) { } func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { - c := context{request: &http.Request{ + c := DefaultContext{request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }} for i := 0; i < b.N; i++ { @@ -104,18 +106,16 @@ func TestContext(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Render @@ -126,106 +126,106 @@ func TestContext(t *testing.T) { } c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } c.echo.Renderer = nil err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) + assert.Error(t, err) // JSON rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } // JSON with "?pretty" req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // reset // JSONPretty rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } // JSON (error) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // JSONP rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) callback := "callback" err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) } // XML rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // XML with "?pretty" req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // XML (error) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // XML response write error - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) c.response.Writer = responseWriterErr{} err = c.XML(0, 0) - testify.Error(t, err) + assert.Error(t, err) // XMLPretty rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } t.Run("empty indent", func(t *testing.T) { @@ -237,166 +237,170 @@ func TestContext(t *testing.T) { t.Run("json", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New JSONBlob with empty indent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) enc := json.NewEncoder(buf) enc.SetIndent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) } }) t.Run("xml", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New XMLBlob with empty indent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) enc := xml.NewEncoder(buf) enc.Indent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) } }) }) // Legacy JSONBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) } // Legacy JSONPBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) callback = "callback" data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } // Legacy XMLBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // String rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // HTML rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // Stream rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) r := strings.NewReader("response from a stream") err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) } // Attachment rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // Inline rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // NoContent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) - - // Error - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) // Reset - c.SetParamNames("foo") - c.SetParamValues("bar") + c.pathParams = &PathParams{ + {Name: "foo", Value: "bar"}, + } c.Set("foe", "ban") c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) + assert.Equal(t, 0, len(c.PathParams())) + assert.Equal(t, 0, len(c.store)) + assert.Equal(t, nil, c.RouteInfo()) + assert.Equal(t, 0, len(c.QueryParams())) +} + +func TestContext_Error(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.Error(errors.New("error")) + + assert.True(t, c.Response().Committed) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, `{"message":"Internal Server Error"}`+"\n", rec.Body.String()) } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -404,12 +408,11 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.Error(t, err) { + assert.False(t, c.response.Committed) } } @@ -421,24 +424,22 @@ func TestContextCookie(t *testing.T) { req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -453,104 +454,215 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } -func TestContextPath(t *testing.T) { - e := New() - r := e.Router() +func TestContext_PathParams(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + expect PathParams + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + expect: PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + }, + { + name: "params is empty", + given: &PathParams{}, + expect: PathParams{}, + }, + } - handler := func(c Context) error { return c.String(http.StatusOK, "OK") } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - r.Add(http.MethodGet, "/users/:id", handler) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1", c) + c.(RoutableContext).SetRawPathParams(tc.given) - assert := testify.New(t) + assert.EqualValues(t, tc.expect, c.PathParams()) + }) + } +} - assert.Equal("/users/:id", c.Path()) +func TestContext_PathParam(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + whenParamName string + expect string + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "multiple same param values exists - return first", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "uid", Value: "202"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "param does not exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) - r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.EqualValues(t, tc.expect, c.PathParam(tc.whenParamName)) + }) + } } -func TestContextPathParam(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) +func TestContext_PathParamDefault(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "101", + }, + { + name: "param exists and is empty", + given: &PathParams{ + {Name: "uid", Value: ""}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "", // <-- this is different from QueryParamDefault behaviour + }, + { + name: "param does not exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + whenDefaultValue: "999", + expect: "999", + }, + } - // ParamNames - c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - // ParamValues - c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + c.(RoutableContext).SetRawPathParams(tc.given) - // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.EqualValues(t, tc.expect, c.PathParamDefault(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextGetAndSetParam(t *testing.T) { e := New() r := e.Router() - r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) + _, err := r.Add(Route{ + Method: http.MethodGet, + Path: "/:foo", + Name: "", + Handler: func(Context) error { return nil }, + Middlewares: nil, + }) + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) - c.SetParamNames("foo") + + params := &PathParams{{Name: "foo", Value: "101"}} + // ParamNames + c.(*DefaultContext).pathParams = params // round-trip param values with modification - paramVals := c.ParamValues() - testify.EqualValues(t, []string{""}, c.ParamValues()) - paramVals[0] = "bar" - c.SetParamValues(paramVals...) - testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + paramVals := c.PathParams() + assert.Equal(t, *params, c.PathParams()) + + paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context + assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams()) + + pathParams := PathParams{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + c.SetPathParams(pathParams) + assert.Equal(t, pathParams, c.PathParams()) // shouldn't explode during Reset() afterwards! - testify.NotPanics(t, func() { - c.Reset(nil, nil) + assert.NotPanics(t, func() { + c.(ServableContext).Reset(nil, nil) }) + assert.Equal(t, PathParams{}, c.PathParams()) + assert.Len(t, *c.(*DefaultContext).pathParams, 0) + assert.Equal(t, cap(*c.(*DefaultContext).pathParams), 1) } // Issue #1655 -func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { - assert := testify.New(t) - +func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) { e := New() - assert.Equal(0, *e.maxParam) + c := e.NewContext(nil, nil).(*DefaultContext) - expectedOneParam := []string{"one"} - expectedTwoParams := []string{"one", "two"} - expectedThreeParams := []string{"one", "two", ""} - expectedABCParams := []string{"A", "B", "C"} - - c := e.NewContext(nil, nil) - c.SetParamNames("1", "2") - c.SetParamValues(expectedTwoParams...) - assert.Equal(2, *e.maxParam) - assert.EqualValues(expectedTwoParams, c.ParamValues()) - - c.SetParamNames("1") - assert.Equal(2, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are - assert.EqualValues(expectedOneParam, c.ParamValues()) - - c.SetParamNames("1", "2", "3") - assert.Equal(3, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam - assert.EqualValues(expectedThreeParams, c.ParamValues()) - - c.SetParamValues("A", "B", "C", "D") - assert.Equal(3, *e.maxParam) - // Here D shouldn't be returned - assert.EqualValues(expectedABCParams, c.ParamValues()) + assert.Equal(t, 0, e.contextPathParamAllocSize) + expectedTwoParams := &PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + } + c.SetRawPathParams(expectedTwoParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, *expectedTwoParams, c.PathParams()) + + expectedThreeParams := PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + {Name: "3", Value: "three"}, + } + c.SetPathParams(expectedThreeParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, expectedThreeParams, c.PathParams()) } func TestContextFormValue(t *testing.T) { @@ -564,44 +676,154 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) + + // FormValueDefault + assert.Equal(t, "Jon Snow", c.FormValueDefault("name", "nope")) + assert.Equal(t, "default", c.FormValueDefault("missing", "default")) - // FormParams - params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + // FormValues + values, err := c.FormValues() + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, - }, params) + }, values) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) - params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + values, err = c.FormValues() + assert.Nil(t, values) + assert.Error(t, err) } -func TestContextQueryParam(t *testing.T) { - q := make(url.Values) - q.Set("name", "Jon Snow") - q.Set("email", "jon@labstack.com") - req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) - e := New() - c := e.NewContext(req, nil) +func TestContext_QueryParams(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect url.Values + }{ + { + name: "multiple values in url", + givenURL: "/?test=1&test=2&email=jon%40labstack.com", + expect: url.Values{ + "test": []string{"1", "2"}, + "email": []string{"jon@labstack.com"}, + }, + }, + { + name: "single value in url", + givenURL: "/?nope=1", + expect: url.Values{ + "nope": []string{"1"}, + }, + }, + { + name: "no query params in url", + givenURL: "/?", + expect: url.Values{}, + }, + } - // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) - // QueryParams - testify.Equal(t, url.Values{ - "name": []string{"Jon Snow"}, - "email": []string{"jon@labstack.com"}, - }, c.QueryParams()) + assert.Equal(t, tc.expect, c.QueryParams()) + }) + } +} + +func TestContext_QueryParam(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + expect: "1", + }, + { + name: "multiple values exists in url", + givenURL: "/?test=9&test=8", + whenParamName: "test", + expect: "9", // <-- first value in returned + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + expect: "", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) + }) + } +} + +func TestContext_QueryParamDefault(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "1", + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParamDefault(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextFormFile(t *testing.T) { @@ -609,7 +831,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -618,8 +840,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -634,8 +856,8 @@ func TestContextMultipartForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) } } @@ -644,22 +866,22 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextStore(t *testing.T) { - var c Context = new(context) + var c Context = new(DefaultContext) c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} - c := &context{ + c := &DefaultContext{ echo: e, } @@ -671,42 +893,6 @@ func BenchmarkContext_Store(b *testing.B) { } } -func TestContextHandler(t *testing.T) { - e := New() - r := e.Router() - b := new(bytes.Buffer) - - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) - return err - }) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/handler", c) - err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) -} - -func TestContext_SetHandler(t *testing.T) { - var c Context = new(context) - - testify.Nil(t, c.Handler()) - - c.SetHandler(func(c Context) error { - return nil - }) - testify.NotNil(t, c.Handler()) -} - -func TestContext_Path(t *testing.T) { - path := "/pa/th" - - var c Context = new(context) - - c.SetPath(path) - testify.Equal(t, path, c.Path()) -} - type validator struct{} func (*validator) Validate(i interface{}) error { @@ -717,10 +903,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -728,21 +914,21 @@ func TestContext_QueryString(t *testing.T) { queryString := "query=string&var=val" - req := httptest.NewRequest(GET, "/?"+queryString, nil) + req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { - var c Context = new(context) + var c Context = new(DefaultContext) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) - req := httptest.NewRequest(GET, "/path", nil) + req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { @@ -751,7 +937,7 @@ func TestContext_Scheme(t *testing.T) { s string }{ { - &context{ + &DefaultContext{ request: &http.Request{ TLS: &tls.ConnectionState{}, }, @@ -759,7 +945,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedProto: []string{"https"}}, }, @@ -767,7 +953,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, }, @@ -775,7 +961,7 @@ func TestContext_Scheme(t *testing.T) { "http", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, }, @@ -783,7 +969,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXUrlScheme: []string{"https"}}, }, @@ -791,7 +977,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{}, }, "http", @@ -799,44 +985,44 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context - ws testify.BoolAssertionFunc + ws assert.BoolAssertionFunc }{ { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, - testify.True, + assert.True, }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, - testify.True, + assert.True, }, { - &context{ + &DefaultContext{ request: &http.Request{}, }, - testify.False, + assert.False, }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, }, } @@ -849,30 +1035,14 @@ func TestContext_IsWebSocket(t *testing.T) { func TestContext_Bind(t *testing.T) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) -} - -func TestContext_Logger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - testify.NotNil(t, log1) - - log2 := log.New("echo2") - c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) - - // Resetting the context returns the initial logger - c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.NoError(t, err) + assert.Equal(t, &user{1, "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { @@ -881,7 +1051,7 @@ func TestContext_RealIP(t *testing.T) { s string }{ { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }, @@ -889,7 +1059,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, }, @@ -897,7 +1067,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, }, @@ -905,7 +1075,31 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &DefaultContext{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &DefaultContext{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &DefaultContext{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"192.168.0.1"}, @@ -915,7 +1109,18 @@ func TestContext_RealIP(t *testing.T) { "192.168.0.1", }, { - &context{ + &DefaultContext{ + request: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{"[2001:db8::1]"}, + }, + }, + }, + "2001:db8::1", + }, + + { + &DefaultContext{ request: &http.Request{ RemoteAddr: "89.89.89.89:1654", }, @@ -925,6 +1130,128 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) + } +} + +func TestContext_File(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec Context) error { + return ec.(*DefaultContext).File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec Context) error { + return ec.(*DefaultContext).FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) } } diff --git a/echo.go b/echo.go index 8829619c7..c76d9d804 100644 --- a/echo.go +++ b/echo.go @@ -3,158 +3,133 @@ Package echo implements high performance, minimalist Go web framework. Example: - package main + package main - import ( - "net/http" + import ( + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "log" + "net/http" + ) - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - ) + // Handler + func hello(c echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + } - // Handler - func hello(c echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") - } + func main() { + // Echo instance + e := echo.New() - func main() { - // Echo instance - e := echo.New() + // Middleware + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) - // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + // Routes + e.GET("/", hello) - // Routes - e.GET("/", hello) - - // Start server - e.Logger.Fatal(e.Start(":1323")) - } + // Start server + if err := e.Start(":8080"); err != http.ErrServerClosed { + log.Fatal(err) + } + } Learn more at https://echo.labstack.com */ package echo import ( - "bytes" stdContext "context" - "crypto/tls" + "encoding/json" "errors" "fmt" "io" - "io/ioutil" - stdLog "log" - "net" + "io/fs" "net/http" - "reflect" - "runtime" + "net/url" + "os" + "os/signal" + "path/filepath" + "strings" "sync" - "time" - - "github.com/labstack/gommon/color" - "github.com/labstack/gommon/log" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) -type ( - // Echo is the top-level framework instance. - Echo struct { - filesystem - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex - StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - JSONSerializer JSONSerializer - Validator Validator - Renderer Renderer - Logger Logger - IPExtractor IPExtractor - ListenerNetwork string - } +// Echo is the top-level framework instance. +// +// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these +// fields from handlers/middlewares and changing field values at the same time leads to data-races. +// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. +type Echo struct { + // premiddleware are middlewares that are run for every request before routing is done + premiddleware []MiddlewareFunc + // middleware are middlewares that are run after router found a matching route (not found and method not found are also matches) + middleware []MiddlewareFunc - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } + router Router + routers map[string]Router + routerCreator func(e *Echo) Router - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } + contextPool sync.Pool + // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info at context + // creation moment so we can allocate path parameter values slice with correct size. + contextPathParamAllocSize int - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(next HandlerFunc) HandlerFunc + // NewContextFunc allows using custom context implementations, instead of default *echo.context + NewContextFunc func(e *Echo, pathParamAllocSize int) ServableContext + Debug bool + HTTPErrorHandler HTTPErrorHandler + Binder Binder + JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger + IPExtractor IPExtractor - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(c Context) error + // Filesystem is file system used by Static and File handlers to access files. + // Defaults to os.DirFS(".") + // + // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary + // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths + // including `assets/images` as their prefix. + Filesystem fs.FS - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(error, Context) + // OnAddRoute is called when Echo adds new route to specific host router. Handler is called for every router + // and before route is added to the host router. + OnAddRoute func(host string, route Routable) error +} - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} - // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. - JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error - } +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(c Context, err error) - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc - // Common struct for Echo & Group. - common struct{} -) +// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} -// HTTP methods -// NOTE: Deprecated, please use the stdlib constants directly instead. -const ( - CONNECT = http.MethodConnect - DELETE = http.MethodDelete - GET = http.MethodGet - HEAD = http.MethodHead - OPTIONS = http.MethodOptions - PATCH = http.MethodPatch - POST = http.MethodPost - // PROPFIND = "PROPFIND" - PUT = http.MethodPut - TRACE = http.MethodTrace -) +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i interface{}) error +} + +// Renderer is the interface that wraps the Render function. +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} + +// Map defines a generic map of type `map[string]interface{}`. +type Map map[string]interface{} // MIME types const ( @@ -183,6 +158,8 @@ const ( PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" + // RouteNotFound is special method type for routes handling "route not found" (404) cases + RouteNotFound = "echo_route_not_found" ) // Headers @@ -246,297 +223,392 @@ const ( const ( // Version of Echo - Version = "4.7.2" - website = "https://echo.labstack.com" - // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo - banner = ` - ____ __ - / __/___/ / ___ - / _// __/ _ \/ _ \ -/___/\__/_//_/\___/ %s -High performance, minimalist Go web framework -%s -____________________________________O/_______ - O\ -` + Version = "5.0.0-alpha" ) -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) - -// Errors -var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrForbidden = NewHTTPError(http.StatusForbidden) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) - ErrBadRequest = NewHTTPError(http.StatusBadRequest) - ErrBadGateway = NewHTTPError(http.StatusBadGateway) - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") -) - -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } - - MethodNotAllowedHandler = func(c Context) error { - // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) - // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned - routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) - if ok && routerAllowMethods != "" { - c.Response().Header().Set(HeaderAllow, routerAllowMethods) - } - return ErrMethodNotAllowed - } -) +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} // New creates an instance of Echo. -func New() (e *Echo) { - e = &Echo{ - filesystem: createFilesystem(), - Server: new(http.Server), - TLSServer: new(http.Server), - AutoTLSManager: autocert.Manager{ - Prompt: autocert.AcceptTOS, +func New() *Echo { + logger := newJSONLogger(os.Stdout) + e := &Echo{ + Logger: logger, + Filesystem: newDefaultFS(), + Binder: &DefaultBinder{}, + JSONSerializer: &DefaultJSONSerializer{}, + + routers: make(map[string]Router), + routerCreator: func(ec *Echo) Router { + return NewRouter(RouterConfig{}) }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), - ListenerNetwork: "tcp", } - e.Server.Handler = e - e.TLSServer.Handler = e - e.HTTPErrorHandler = e.DefaultHTTPErrorHandler - e.Binder = &DefaultBinder{} - e.JSONSerializer = &DefaultJSONSerializer{} - e.Logger.SetLevel(log.ERROR) - e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) - e.pool.New = func() interface{} { + + e.router = NewRouter(RouterConfig{}) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) + e.contextPool.New = func() interface{} { return e.NewContext(nil, nil) } - e.router = NewRouter(e) - e.routers = map[string]*Router{} - return + return e } -// NewContext returns a Context instance. +// NewContext returns a new Context instance. +// +// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { - return &context{ - request: r, - response: NewResponse(w, e), - store: make(Map), - echo: e, - pvalues: make([]string, *e.maxParam), - handler: NotFoundHandler, + var c Context + if e.NewContextFunc != nil { + c = e.NewContextFunc(e, e.contextPathParamAllocSize) + } else { + c = NewDefaultContext(e, e.contextPathParamAllocSize) } + c.SetRequest(r) + c.SetResponse(NewResponse(w, e)) + return c } // Router returns the default router. -func (e *Echo) Router() *Router { +func (e *Echo) Router() Router { return e.router } -// Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { - return e.routers +// Routers returns the new map of host => router. +func (e *Echo) Routers() map[string]Router { + result := make(map[string]Router) + for host, r := range e.routers { + result[host] = r + } + return result +} + +// RouterFor returns Router for given host. When host is left empty the default router is returned. +func (e *Echo) RouterFor(host string) (Router, bool) { + if host == "" { + return e.router, true + } + router, ok := e.routers[host] + return router, ok +} + +// ResetRouterCreator resets callback for creating new router instances. +// Note: current (default) router is immediately replaced with router created with creator func and vhost routers are cleared. +func (e *Echo) ResetRouterCreator(creator func(e *Echo) Router) { + e.routerCreator = creator + e.router = creator(e) + e.routers = make(map[string]Router) } -// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response -// with status code. +// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response +// with status code. `exposeError` parameter decides if returned message will contain also error message or not // -// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) +// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "commited" the // response and status code header has been sent to the client. -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - - if c.Response().Committed { - return - } - - he, ok := err.(*HTTPError) - if ok { - if he.Internal != nil { - if herr, ok := he.Internal.(*HTTPError); ok { - he = herr - } +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { + return func(c Context, err error) { + if c.Response().Committed { + return } - } else { - he = &HTTPError{ + + he := &HTTPError{ Code: http.StatusInternalServerError, Message: http.StatusText(http.StatusInternalServerError), } - } + if errors.As(err, &he) { + if he.Internal != nil { // max 2 levels of checks even if internal could have also internal + errors.As(he.Internal, &he) + } + } - // Issue #1426 - code := he.Code - message := he.Message - if m, ok := he.Message.(string); ok { - if e.Debug { - message = Map{"message": m, "error": err.Error()} - } else { - message = Map{"message": m} + // Issue #1426 + code := he.Code + message := he.Message + switch m := he.Message.(type) { + case string: + if exposeError { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } + case json.Marshaler: + // do nothing - this type knows how to format itself to JSON + case error: + message = Map{"message": m.Error()} } - } - // Send response - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) + // Send response + var cErr error + if c.Request().Method == http.MethodHead { // Issue #608 + cErr = c.NoContent(he.Code) + } else { + cErr = c.JSON(code, message) + } + if cErr != nil { + c.Echo().Logger.Error(err) // truly rare case. ala client already disconnected + } } } -// Pre adds middleware to the chain which is run before router. +// Pre adds middleware to the chain which is run before router tries to find matching route. +// Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } -// Use adds middleware to the chain which is run after router. +// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler +// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) +// for current request URL. +// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with +// wildcard/match-any character (`/*`, `/download/*` etc). +// +// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return e.Add(RouteNotFound, path, h, m...) +} + +// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// +// Note: this method only adds specific set of supported HTTP methods as handler and is not true +// "catch-any-arbitrary-method" way of matching requests. +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// handler in the router with optional route-level middleware. Panics on error. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris +} + +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string) RouteInfo { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + ) } -func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, - m ...MiddlewareFunc) *Route { - return get(path, func(c Context) error { +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c Context) error { + p := c.PathParam("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) + fi, err := fs.Stat(fileSystem, name) + if err != nil { + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) + } + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c Context) error { + return fsFile(c, file, filesystem) + } +} + +// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { return c.File(file) - }, m...) + } + return e.Add(http.MethodGet, path, handler, middleware...) } -// File registers a new route with path to serve a static file with optional route-level middleware. -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { - return e.file(path, file, e.GET, m...) +// AddRoute registers a new Route with default host Router +func (e *Echo) AddRoute(route Routable) (RouteInfo, error) { + return e.add("", route) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) +func (e *Echo) add(host string, route Routable) (RouteInfo, error) { + if e.OnAddRoute != nil { + if err := e.OnAddRoute(host, route); err != nil { + return nil, err + } + } + router := e.findRouter(host) - router.Add(method, path, func(c Context) error { - h := applyMiddleware(handler, middleware...) - return h(c) - }) - r := &Route{ - Method: method, - Path: path, - Name: name, + ri, err := router.Add(route) + if err != nil { + return nil, err } - e.router.routes[method+path] = r - return r + + paramsCount := len(ri.Params()) + if paramsCount > e.contextPathParamAllocSize { + e.contextPathParamAllocSize = paramsCount + } + return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - return e.add("", method, path, handler, middleware...) +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := e.add( + "", + Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + Name: "", + }, + ) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } // Host creates a new router group for the provided host and optional host-level middleware. func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) + e.routers[name] = e.routerCreator(e) g = &Group{host: name, echo: e} g.Use(m...) return @@ -549,328 +621,85 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates a URI from handler. -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { - name := handlerName(handler) - return e.Reverse(name, params...) -} - -// URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) -} - -// Reverse generates an URL from route name and provided parameters. -func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() -} - -// Routes returns the registered routes. -func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes -} - // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. func (e *Echo) AcquireContext() Context { - return e.pool.Get().(Context) + return e.contextPool.Get().(Context) } // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. func (e *Echo) ReleaseContext(c Context) { - e.pool.Put(c) + e.contextPool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Acquire context - c := e.pool.Get().(*context) + var c ServableContext + if e.NewContextFunc != nil { + // NOTE: we are not casting always context to RoutableContext because casting to interface vs pointer to struct is + // "significantly" slower. Echo Context interface has way to many methods so these checks take time. + // These are benchmarks with 1.16: + // * interface extending another interface = +24% slower (3233 ns/op vs 2605 ns/op) + // * interface (not extending any, just methods)= +14% slower + // + // Quote from https://stackoverflow.com/a/31584377 + // "it's even worse with interface-to-interface assertion, because you also need to ensure that the type implements the interface." + // + // So most of the time we do not need custom context type and simple IF + cast to pointer to struct is fast enough. + c = e.contextPool.Get().(ServableContext) + } else { + c = e.contextPool.Get().(*DefaultContext) + } c.Reset(r, w) - var h func(Context) error + var h HandlerFunc if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h = c.Handler() - h = applyMiddleware(h, e.middleware...) + h = applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) } else { - h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h := c.Handler() - h = applyMiddleware(h, e.middleware...) - return h(c) + h = func(cc Context) error { + // NOTE: router will be executed after pre middlewares have been run. We assume here that context we receive after pre middlewares + // is the same we began with. If not - this is use-case we do not support and is probably abuse from developer. + h1 := applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) + return h1(cc) } h = applyMiddleware(h, e.premiddleware...) } // Execute chain if err := h(c); err != nil { - e.HTTPErrorHandler(err, c) + e.HTTPErrorHandler(c, err) } - // Release context - e.pool.Put(c) + e.contextPool.Put(c) } -// Start starts an HTTP server. +// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by +// sending os.Interrupt signal with `ctrl+c`. +// +// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration +// options. +// +// In need of customization use: +// +// sc := echo.StartConfig{Address: ":8080"} +// if err := sc.Start(e); err != http.ErrServerClosed { +// log.Fatal(err) +// } +// +// // or standard library `http.Server` +// +// s := http.Server{Addr: ":8080", Handler: e} +// if err := s.ListenAndServe(); err != http.ErrServerClosed { +// log.Fatal(err) +// } func (e *Echo) Start(address string) error { - e.startupMutex.Lock() - e.Server.Addr = address - if err := e.configureServer(e.Server); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return e.Server.Serve(e.Listener) -} - -// StartTLS starts an HTTPS server. -// If `certFile` or `keyFile` is `string` the values are treated as file paths. -// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - e.startupMutex.Lock() - var cert []byte - if cert, err = filepathOrContent(certFile); err != nil { - e.startupMutex.Unlock() - return - } - - var key []byte - if key, err = filepathOrContent(keyFile); err != nil { - e.startupMutex.Unlock() - return - } - - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.Certificates = make([]tls.Certificate, 1) - if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - e.startupMutex.Unlock() - return - } - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { - switch v := fileOrContent.(type) { - case string: - return ioutil.ReadFile(v) - case []byte: - return v, nil - default: - return nil, ErrInvalidCertOrKeyType - } -} - -// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. -func (e *Echo) StartAutoTLS(address string) error { - e.startupMutex.Lock() - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func (e *Echo) configureTLS(address string) { - s := e.TLSServer - s.Addr = address - if !e.DisableHTTP2 { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") - } -} - -// StartServer starts a custom http server. -func (e *Echo) StartServer(s *http.Server) (err error) { - e.startupMutex.Lock() - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - if s.TLSConfig != nil { - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -func (e *Echo) configureServer(s *http.Server) error { - // Setup - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = e - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } + sc := StartConfig{Address: address} + ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt) // start shutdown process on ctrl+c + defer cancel() + sc.GracefulContext = ctx - if s.TLSConfig == nil { - if e.Listener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.Listener = l - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - return nil - } - if e.TLSListener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.TLSListener = tls.NewListener(l, s.TLSConfig) - } - if !e.HidePort { - e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) - } - return nil -} - -// ListenerAddr returns net.Addr for Listener -func (e *Echo) ListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.Listener == nil { - return nil - } - return e.Listener.Addr() -} - -// TLSListenerAddr returns net.Addr for TLSListener -func (e *Echo) TLSListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.TLSListener == nil { - return nil - } - return e.TLSListener.Addr() -} - -// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error { - e.startupMutex.Lock() - // Setup - s := e.Server - s.Addr = address - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = h2c.NewHandler(e, h2s) - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if e.Listener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - e.startupMutex.Unlock() - return err - } - e.Listener = l - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -// Close immediately stops the server. -// It internally calls `http.Server#Close()`. -func (e *Echo) Close() error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Close(); err != nil { - return err - } - return e.Server.Close() -} - -// Shutdown stops the server gracefully. -// It internally calls `http.Server#Shutdown()`. -func (e *Echo) Shutdown(ctx stdContext.Context) error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Shutdown(ctx); err != nil { - return err - } - return e.Server.Shutdown(ctx) -} - -// NewHTTPError creates a new HTTPError instance. -func NewHTTPError(code int, message ...interface{}) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} - if len(message) > 0 { - he.Message = message[0] - } - return he -} - -// Error makes it compatible with `error` interface. -func (he *HTTPError) Error() string { - if he.Internal == nil { - return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) - } - return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) -} - -// SetInternal sets error to HTTPError.Internal -func (he *HTTPError) SetInternal(err error) *HTTPError { - he.Internal = err - return he -} - -// Unwrap satisfies the Go 1.13 error wrapper interface. -func (he *HTTPError) Unwrap() error { - return he.Internal + return sc.Start(e) } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. @@ -895,19 +724,7 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -// GetPath returns RawPath, if it's empty returns Path from URL -// Difference between RawPath and Path is: -// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// * RawPath is an optional field which only gets set if the default encoding is different from Path. -func GetPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - return path -} - -func (e *Echo) findRouter(host string) *Router { +func (e *Echo) findRouter(host string) Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { return r @@ -916,53 +733,74 @@ func (e *Echo) findRouter(host string) *Router { return e.router } -func handlerName(h HandlerFunc) string { - t := reflect.ValueOf(h).Type() - if t.Kind() == reflect.Func { - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() +func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) } - return t.String() + return h } -// // PathUnescape is wraps `url.PathUnescape` -// func PathUnescape(s string) (string, error) { -// return url.PathUnescape(s) -// } +// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` +// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` +// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break +// all old applications that rely on being able to traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + prefix string + fs fs.FS +} -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: nil, + } } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - if c, err = ln.AcceptTCP(); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { - return +func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) } - // Ignore error from setting the KeepAlivePeriod as some systems, such as - // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP - _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) - return + return fs.fs.Open(name) } -func newListener(address, network string) (*tcpKeepAliveListener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, ErrInvalidListenerNetwork +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if !filepath.IsAbs(root) { + root = filepath.Join(dFS.prefix, root) + } + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil } - l, err := net.Listen(network, address) + return fs.Sub(currentFs, root) +} + +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) if err != nil { - return nil, err + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) } - return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil + return subFs } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) } - return h + return uri } diff --git a/echo_fs.go b/echo_fs.go deleted file mode 100644 index c3790545a..000000000 --- a/echo_fs.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -import ( - "net/http" - "net/url" - "os" - "path/filepath" -) - -type filesystem struct { -} - -func createFilesystem() filesystem { - return filesystem{} -} - -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { - if root == "" { - root = "." // For security we want to restrict to CWD. - } - return e.static(prefix, root, e.GET) -} - -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) - if err != nil { - return err - } - - name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security - fi, err := os.Stat(name) - if err != nil { - // The access path does not exist - return NotFoundHandler(c) - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return c.File(name) - } - // Handle added routes based on trailing slash: - // /prefix => exact route "/prefix" + any route "/prefix/*" - // /prefix/ => only any route "/prefix/*" - if prefix != "" { - if prefix[len(prefix)-1] == '/' { - // Only add any route for intentional trailing slash - return get(prefix+"*", h) - } - get(prefix, h) - } - return get(prefix+"/*", h) -} diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go deleted file mode 100644 index eb17768ab..000000000 --- a/echo_fs_go1.16.go +++ /dev/null @@ -1,169 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "fmt" - "io/fs" - "net/http" - "net/url" - "os" - "path/filepath" - "runtime" - "strings" -) - -type filesystem struct { - // Filesystem is file system used by Static and File handlers to access files. - // Defaults to os.DirFS(".") - // - // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary - // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths - // including `assets/images` as their prefix. - Filesystem fs.FS -} - -func createFilesystem() filesystem { - return filesystem{ - Filesystem: newDefaultFS(), - } -} - -// Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, fsRoot string) *Route { - subFs := MustSubFS(e.Filesystem, fsRoot) - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(subFs, false), - ) -} - -// StaticFS registers a new route with path prefix to serve static files from the provided file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// StaticDirectoryHandler creates handler function to serve files from provided file system -// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. -func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { - return func(c Context) error { - p := c.Param("*") - if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice - tmpPath, err := url.PathUnescape(p) - if err != nil { - return fmt.Errorf("failed to unescape path variable: %w", err) - } - p = tmpPath - } - - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) - fi, err := fs.Stat(fileSystem, name) - if err != nil { - return ErrNotFound - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return fsFile(c, name, fileSystem) - } -} - -// FileFS registers a new route with path to serve file from the provided file system. -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return e.GET(path, StaticFileHandler(file, filesystem), m...) -} - -// StaticFileHandler creates handler function to serve file from provided file system -func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { - return func(c Context) error { - return fsFile(c, file, filesystem) - } -} - -// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. -// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. -// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` -// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not -// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to -// traverse up from current executable run path. -// NB: private because you really should use fs.FS implementation instances -type defaultFS struct { - prefix string - fs fs.FS -} - -func newDefaultFS() *defaultFS { - dir, _ := os.Getwd() - return &defaultFS{ - prefix: dir, - fs: nil, - } -} - -func (fs defaultFS) Open(name string) (fs.File, error) { - if fs.fs == nil { - return os.Open(name) - } - return fs.fs.Open(name) -} - -func subFS(currentFs fs.FS, root string) (fs.FS, error) { - root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows - if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. - // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we - // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if isRelativePath(root) { - root = filepath.Join(dFS.prefix, root) - } - return &defaultFS{ - prefix: root, - fs: os.DirFS(root), - }, nil - } - return fs.Sub(currentFs, root) -} - -func isRelativePath(path string) bool { - if path == "" { - return true - } - if path[0] == '/' { - return false - } - if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { - // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names - // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats - return false - } - return true -} - -// MustSubFS creates sub FS from current filesystem or panic on failure. -// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. -// -// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with -// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to -// create sub fs which uses necessary prefix for directory path. -func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { - subFs, err := subFS(currentFs, fsRoot) - if err != nil { - panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) - } - return subFs -} diff --git a/echo_fs_go1.16_test.go b/echo_fs_go1.16_test.go deleted file mode 100644 index 07e516555..000000000 --- a/echo_fs_go1.16_test.go +++ /dev/null @@ -1,265 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" -) - -func TestEcho_StaticFS(t *testing.T) { - var testCases = []struct { - name string - givenPrefix string - givenFs fs.FS - givenFsRoot string - whenURL string - expectStatus int - expectHeaderLocation string - expectBodyStartsWith string - }{ - { - name: "ok", - givenPrefix: "/images", - givenFs: os.DirFS("./_fixture/images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "ok, from sub fs", - givenPrefix: "/images", - givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "No file", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/scripts"), - whenURL: "/images/bolt.png", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/images"), - whenURL: "/images/", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory Redirect", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: "/folder", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory Redirect with non-root path", - givenPrefix: "/static", - givenFs: os.DirFS("_fixture"), - whenURL: "/static", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/static/", - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory 404 (request URL without slash)", - givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Prefixed directory redirect (without slash redirect to slash)", - givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending with slash)", - givenPrefix: "/assets/", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending without slash)", - givenPrefix: "/assets", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Sub-directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/folder/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "do not allow directory traversal (backslash - windows separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/..\\middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "do not allow directory traversal (slash - unix separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/../middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - tmpFs := tc.givenFs - if tc.givenFsRoot != "" { - tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) - } - e.StaticFS(tc.givenPrefix, tmpFs) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectStatus, rec.Code) - body := rec.Body.String() - if tc.expectBodyStartsWith != "" { - assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) - } else { - assert.Equal(t, "", body) - } - - if tc.expectHeaderLocation != "" { - assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) - } else { - _, ok := rec.Result().Header["Location"] - assert.False(t, ok) - } - }) - } -} - -func TestEcho_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestEcho_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - expectError string - }{ - { - name: "panics for ../", - givenRoot: "../assets", - expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", - }, - { - name: "panics for /", - givenRoot: "/assets", - expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - assert.PanicsWithError(t, tc.expectError, func() { - e.Static("../assets", tc.givenRoot) - }) - }) - } -} diff --git a/echo_test.go b/echo_test.go index 0e1e42be0..3adbd8c92 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,31 +3,26 @@ package echo import ( "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" - "io/ioutil" + "io/fs" "net" "net/http" "net/http/httptest" "net/url" "os" - "reflect" + "runtime" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/http2" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` @@ -61,16 +56,17 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - // DefaultHTTPErrorHandler - e.DefaultHTTPErrorHandler(errors.New("error"), c) + e.HTTPErrorHandler(c, errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } -func TestEchoStatic(t *testing.T) { +func TestEcho_StaticFS(t *testing.T) { var testCases = []struct { name string givenPrefix string - givenRoot string + givenFs fs.FS + givenFsRoot string whenURL string expectStatus int expectHeaderLocation string @@ -79,15 +75,15 @@ func TestEchoStatic(t *testing.T) { { name: "ok", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("./_fixture/images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { - name: "ok with relative path for root points to directory", + name: "ok, from sub fs", givenPrefix: "/images", - givenRoot: "./_fixture/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), @@ -95,7 +91,7 @@ func TestEchoStatic(t *testing.T) { { name: "No file", givenPrefix: "/images", - givenRoot: "_fixture/scripts", + givenFs: os.DirFS("_fixture/scripts"), whenURL: "/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -103,7 +99,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("_fixture/images"), whenURL: "/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -111,7 +107,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture/"), whenURL: "/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -120,7 +116,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect with non-root path", givenPrefix: "/static", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/static/", @@ -129,7 +125,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -137,7 +133,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -146,7 +142,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -154,7 +150,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -162,7 +158,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -170,7 +166,7 @@ func TestEchoStatic(t *testing.T) { { name: "Sub-directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -178,7 +174,7 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -186,20 +182,37 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, + { + name: "open redirect vulnerability", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/open.redirect.hackercom%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad + expectBodyStartsWith: "", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - e.Static(tc.givenPrefix, tc.givenRoot) + + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { @@ -218,39 +231,117 @@ func TestEchoStatic(t *testing.T) { } } -func TestEchoStaticRedirectIndex(t *testing.T) { - e := New() +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } - // HandlerFunc - e.Static("/static", "_fixture") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - errCh := make(chan error) + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() - go func() { - errCh <- e.Start(":0") - }() + e.ServeHTTP(rec, req) - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) + assert.Equal(t, tc.expectCode, rec.Code) - addr := e.ListenerAddr().String() - if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, true, strings.HasPrefix(string(body), "")) - } else { - assert.Fail(t, err.Error()) - } +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", + }, + { + name: "panics for /", + givenRoot: "/assets", + expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", + }, + } - } else { - assert.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + assert.PanicsWithError(t, tc.expectError, func() { + e.Static("../assets", tc.givenRoot) + }) + }) } +} - if err := e.Close(); err != nil { - t.Fatal(err) +func TestEchoStaticRedirectIndex(t *testing.T) { + e := New() + + // HandlerFunc + ri := e.Static("/static", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/static*", ri.Path()) + assert.Equal(t, "GET:/static*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + addr, err := startOnRandomPort(ctx, e) + if err != nil { + assert.Fail(t, err.Error()) } + + code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(body, "")) + assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { @@ -310,7 +401,8 @@ func TestEchoMiddleware(t *testing.T) { e.Pre(func(next HandlerFunc) HandlerFunc { return func(c Context) error { - assert.Empty(t, c.Path()) + // before route match is found RouteInfo does not exist + assert.Equal(t, nil, c.RouteInfo()) buf.WriteString("-1") return next(c) } @@ -355,7 +447,7 @@ func TestEchoMiddlewareError(t *testing.T) { return errors.New("error") } }) - e.GET("/", NotFoundHandler) + e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -410,128 +502,202 @@ func TestEchoWrapMiddleware(t *testing.T) { } } +func TestEchoGet_routeInfoIsImmutable(t *testing.T) { + e := New() + ri := e.GET("/test", handlerFunc) + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err := e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) + + rInfo := ri.(routeInfo) + rInfo.name = "changed" // this change should not change other returned values + + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err = e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) +} + func TestEchoConnect(t *testing.T) { e := New() - testMethod(t, http.MethodConnect, "/", e) + + ri := e.CONNECT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodConnect+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() - testMethod(t, http.MethodDelete, "/", e) + + ri := e.DELETE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodDelete+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() - testMethod(t, http.MethodGet, "/", e) + + ri := e.GET("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodGet+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodGet, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() - testMethod(t, http.MethodHead, "/", e) + + ri := e.HEAD("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodHead+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() - testMethod(t, http.MethodOptions, "/", e) + + ri := e.OPTIONS("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodOptions+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() - testMethod(t, http.MethodPatch, "/", e) + + ri := e.PATCH("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPatch+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() - testMethod(t, http.MethodPost, "/", e) + + ri := e.POST("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPost+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() - testMethod(t, http.MethodPut, "/", e) + + ri := e.PUT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPut+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() - testMethod(t, http.MethodTrace, "/", e) + + ri := e.TRACE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodTrace+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoAny(t *testing.T) { // JFC e := New() - e.Any("/", func(c Context) error { + ris := e.Any("/", func(c Context) error { return c.String(http.StatusOK, "Any") }) + assert.Len(t, ris, 11) } func TestEchoMatch(t *testing.T) { // JFC e := New() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { + ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { return c.String(http.StatusOK, "Match") }) + assert.Len(t, ris, 2) } -func TestEchoURL(t *testing.T) { - e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getAny := func(Context) error { return nil } - getFile := func(Context) error { return nil } - - e.GET("/static/file", static) - e.GET("/users/:id", getUser) - e.GET("/documents/*", getAny) - g := e.Group("/group") - g.GET("/users/:uid/files/:fid", getFile) - - assert := assert.New(t) - - assert.Equal("/static/file", e.URL(static)) - assert.Equal("/users/:id", e.URL(getUser)) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) - assert.Equal("/documents/*", e.URL(getAny)) - assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) -} - -func TestEchoRoutes(t *testing.T) { - e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } - } -} - -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEcho_Routers_HandleHostsProperly(t *testing.T) { e := New() h := e.Host("route.com") routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + {Method: http.MethodGet, Path: "/users/:user/events"}, + {Method: http.MethodGet, Path: "/users/:user/events/public"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/refs"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/tags"}, } for _, r := range routes { h.Add(r.Method, r.Path, func(c Context) error { @@ -539,17 +705,22 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { }) } - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { + routers := e.Routers() + + routeCom, ok := routers["route.com"] + assert.True(t, ok) + + if assert.Equal(t, len(routes), len(routeCom.Routes())) { + for _, r := range routeCom.Routes() { found := false for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { + if r.Method() == rr.Method && r.Path() == rr.Path { found = true break } } if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + t.Errorf("Route %s %s not found", r.Method(), r.Path()) } } } @@ -561,7 +732,7 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { return c.String(http.StatusOK, "/with/slash") }) e.GET("/:id", func(c Context) error { - return c.String(http.StatusOK, c.Param("id")) + return c.String(http.StatusOK, c.PathParam("id")) }) var testCases = []struct { @@ -598,8 +769,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } func TestEchoHost(t *testing.T) { - assert := assert.New(t) - okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } @@ -694,8 +863,8 @@ func TestEchoHost(t *testing.T) { e.ServeHTTP(rec, req) - assert.Equal(tc.expectStatus, rec.Code) - assert.Equal(tc.expectBody, rec.Body.String()) + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } @@ -758,737 +927,421 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "023", buf.String()) } -func TestEchoNotFound(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/files", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestEchoMethodNotAllowed(t *testing.T) { - e := New() - - e.GET("/", func(c Context) error { - return c.String(http.StatusOK, "Echo!") - }) - req := httptest.NewRequest(http.MethodPost, "/", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) - assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) -} - -func TestEchoContext(t *testing.T) { - e := New() - c := e.AcquireContext() - assert.IsType(t, new(context), c) - e.ReleaseContext(c) -} - -func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(5 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - var addr net.Addr - if isTLS { - addr = e.TLSListenerAddr() - } else { - addr = e.ListenerAddr() - } - if addr != nil && strings.Contains(addr.String(), ":") { - return nil // was started - } - case err := <-errChan: - if err == http.ErrServerClosed { - return nil - } - return err - } - } -} - -func TestEchoStart(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - err := e.Start(":0") - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - assert.NoError(t, e.Close()) -} - -func TestEcho_StartTLS(t *testing.T) { +func TestEcho_RouteNotFound(t *testing.T) { var testCases = []struct { name string - addr string - certFile string - keyFile string - expectError string + whenURL string + expectRoute interface{} + expectCode int }{ { - name: "ok", - addr: ":0", + name: "404, route to static not found handler /a/c/xx", + whenURL: "/a/c/xx", + expectRoute: "GET /a/c/xx", + expectCode: http.StatusNotFound, }, { - name: "nok, invalid certFile", - addr: ":0", - certFile: "not existing", - expectError: "open not existing: no such file or directory", + name: "404, route to path param not found handler /a/:file", + whenURL: "/a/echo.exe", + expectRoute: "GET /a/:file", + expectCode: http.StatusNotFound, }, { - name: "nok, invalid keyFile", - addr: ":0", - keyFile: "not existing", - expectError: "open not existing: no such file or directory", + name: "404, route to any not found handler /*", + whenURL: "/b/echo.exe", + expectRoute: "GET /*", + expectCode: http.StatusNotFound, }, { - name: "nok, failed to create cert out of certFile and keyFile", - addr: ":0", - keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key - expectError: "tls: found a certificate rather than a key in the PEM for the private key", - }, - { - name: "nok, invalid tls address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "200, route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "GET /a/c/df", + expectCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - errChan := make(chan error) - go func() { - certFile := "_fixture/certs/cert.pem" - if tc.certFile != "" { - certFile = tc.certFile - } - keyFile := "_fixture/certs/key.pem" - if tc.keyFile != "" { - keyFile = tc.keyFile - } + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } - err := e.StartTLS(tc.addr, certFile, keyFile) - if err != nil { - errChan <- err - } - }() + e.GET("/", okHandler) + e.GET("/a/c/df", okHandler) + e.GET("/a/b*", okHandler) + e.PUT("/*", okHandler) - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - if _, ok := err.(*os.PathError); ok { - assert.Error(t, err) // error messages for unix and windows are different. so test only error type here - } else { - assert.EqualError(t, err, tc.expectError) - } - } else { - assert.NoError(t, err) - } + e.RouteNotFound("/a/c/xx", notFoundHandler) // static + e.RouteNotFound("/a/:file", notFoundHandler) // param + e.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) - assert.NoError(t, e.Close()) + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) }) } } -func TestEchoStartTLSAndStart(t *testing.T) { - // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server +func TestEchoNotFound(t *testing.T) { e := New() - e.GET("/", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - errTLSChan := make(chan error) - go func() { - certFile := "_fixture/certs/cert.pem" - keyFile := "_fixture/certs/key.pem" - err := e.StartTLS("localhost:", certFile, keyFile) - if err != nil { - errTLSChan <- err - } - }() - - err := waitForServerStart(e, errTLSChan, true) - assert.NoError(t, err) - defer func() { - if err := e.Shutdown(stdContext.Background()); err != nil { - t.Error(err) - } - }() - - // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }} - res, err := client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + req := httptest.NewRequest(http.MethodGet, "/files", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) +} - errChan := make(chan error) - go func() { - err := e.Start("localhost:") - if err != nil { - errChan <- err - } - }() - err = waitForServerStart(e, errChan, false) - assert.NoError(t, err) +func TestEchoMethodNotAllowed(t *testing.T) { + e := New() - // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS - res, err = http.Get("http://" + e.ListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + e.GET("/", func(c Context) error { + return c.String(http.StatusOK, "Echo!") + }) + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) - // see if HTTPS works after HTTP listener is also added - res, err = client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) } -func TestEchoStartTLSByteString(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) +func TestEcho_OnAddRoute(t *testing.T) { + type rr struct { + host string + path string + } + exampleRoute := Route{ + Method: http.MethodGet, + Path: "/api/files/:id", + Handler: notFoundHandler, + Middlewares: nil, + Name: "x", + } - testCases := []struct { - cert interface{} - key interface{} - expectedErr error + var testCases = []struct { name string + whenHost string + whenRoute Routable + whenError error + expectLen int + expectAdded []rr + expectError string }{ { - cert: "_fixture/certs/cert.pem", - key: "_fixture/certs/key.pem", - expectedErr: nil, - name: `ValidCertAndKeyFilePath`, + name: "ok, for default host", + whenHost: "", + whenRoute: exampleRoute, + whenError: nil, + expectAdded: []rr{ + {host: "", path: "/static"}, + {host: "", path: "/api/files/:id"}, + }, + expectError: "", + expectLen: 2, }, { - cert: cert, - key: key, - expectedErr: nil, - name: `ValidCertAndKeyByteString`, + name: "ok, for specific host", + whenHost: "test.com", + whenRoute: exampleRoute, + whenError: nil, + expectAdded: []rr{ + {host: "", path: "/static"}, + {host: "test.com", path: "/api/files/:id"}, + }, + expectError: "", + expectLen: 1, }, { - cert: cert, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidKeyType`, - }, - { - cert: 0, - key: key, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertType`, - }, - { - cert: 0, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertAndKeyTypes`, + name: "nok, error is returned", + whenHost: "test.com", + whenRoute: exampleRoute, + whenError: errors.New("nope"), + expectAdded: []rr{ + {host: "", path: "/static"}, + }, + expectError: "nope", + expectLen: 0, }, } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { - for _, test := range testCases { - test := test - t.Run(test.name, func(t *testing.T) { e := New() - e.HideBanner = true - errChan := make(chan error) + added := make([]rr, 0) + cnt := 0 + e.OnAddRoute = func(host string, route Routable) error { + if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests + return tc.whenError + } + cnt++ + added = append(added, rr{ + host: host, + path: route.ToRoute().Path, + }) + return nil + } - go func() { - errChan <- e.StartTLS(":0", test.cert, test.key) - }() + e.GET("/static", notFoundHandler) - err := waitForServerStart(e, errChan, true) - if test.expectedErr != nil { - assert.EqualError(t, err, test.expectedErr.Error()) + var err error + if tc.whenHost != "" { + _, err = e.Host(tc.whenHost).AddRoute(tc.whenRoute) } else { - assert.NoError(t, err) + _, err = e.AddRoute(tc.whenRoute) } - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartAutoTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - errChan <- e.StartAutoTLS(tc.addr) - }() - - err := waitForServerStart(e, errChan, true) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } - assert.NoError(t, e.Close()) + r, _ := e.RouterFor(tc.whenHost) + assert.Len(t, r.Routes(), tc.expectLen) + assert.Equal(t, tc.expectAdded, added) }) } } -func TestEcho_StartH2CServer(t *testing.T) { +func TestEcho_RouterFor(t *testing.T) { var testCases = []struct { - name string - addr string - expectError string + name string + whenHost string + expectLen int + expectOk bool }{ { - name: "ok", - addr: ":0", + name: "ok, default host", + whenHost: "", + expectLen: 2, + expectOk: true, + }, + { + name: "ok, specific host with routes", + whenHost: "test.com", + expectLen: 1, + expectOk: true, }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "ok, non-existent host", + whenHost: "oups.com", + expectLen: 0, + expectOk: false, }, } - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - e.Debug = true - h2s := &http2.Server{} - - errChan := make(chan error) - go func() { - err := e.StartH2CServer(tc.addr, h2s) - if err != nil { - errChan <- err - } - }() - err := waitForServerStart(e, errChan, false) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) + e.GET("/1", notFoundHandler) + e.GET("/2", notFoundHandler) + e.Host("test.com").GET("/3", notFoundHandler) + + r, ok := e.RouterFor(tc.whenHost) + assert.Equal(t, tc.expectOk, ok) + if tc.expectLen > 0 { + assert.Len(t, r.Routes(), tc.expectLen) } else { - assert.NoError(t, err) + assert.Nil(t, r) } - - assert.NoError(t, e.Close()) }) } } -func testMethod(t *testing.T, method, path string, e *Echo) { - p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { - return c.String(http.StatusOK, method) - }) - i := interface{}(e) - reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) - _, body := request(method, path, e) - assert.Equal(t, method, body) -} - -func request(method, path string, e *Echo) (int, string) { - req := httptest.NewRequest(method, path, nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - return rec.Code, rec.Body.String() -} - -func TestHTTPError(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Equal(t, "code=400, message=map[code:12]", err.Error()) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) - }) -} - -func TestHTTPError_Unwrap(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Nil(t, errors.Unwrap(err)) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) -} - -func TestDefaultHTTPErrorHandler(t *testing.T) { +func TestEchoContext(t *testing.T) { e := New() - e.Debug = true - e.Any("/plain", func(c Context) error { - return errors.New("An error occurred") - }) - e.Any("/badrequest", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, "Invalid request") - }) - e.Any("/servererror", func(c Context) error { - return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ - "code": 33, - "message": "Something bad happened", - "error": "stackinfo", - }) - }) - e.Any("/early-return", func(c Context) error { - c.String(http.StatusOK, "OK") - return errors.New("ERROR") - }) - e.GET("/internal-error", func(c Context) error { - err := errors.New("internal error message body") - return NewHTTPError(http.StatusBadRequest).SetInternal(err) - }) - - // With Debug=true plain response contains error message - c, b := request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) - // and special handling for HTTPError - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) - // complex errors are serialized to pretty JSON - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) - // if the body is already set HTTPErrorHandler should not add anything to response body - c, b = request(http.MethodGet, "/early-return", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, "OK", b) - // internal error should be reflected in the message - c, b = request(http.MethodGet, "/internal-error", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) - - e.Debug = false - // With Debug=false the error response is shortened - c, b = request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) - // No difference for error response with non plain string errors - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) + c := e.AcquireContext() + assert.IsType(t, new(DefaultContext), c) + e.ReleaseContext(c) } -func TestEchoClose(t *testing.T) { +func TestEcho_Start(t *testing.T) { e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { + e.GET("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + rndPort, err := net.Listen("tcp", ":0") + if err != nil { t.Fatal(err) } - - assert.NoError(t, e.Close()) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -func TestEchoShutdown(t *testing.T) { - e := New() - errCh := make(chan error) - + defer rndPort.Close() + errChan := make(chan error, 1) go func() { - errCh <- e.Start(":0") + errChan <- e.Start(rndPort.Addr().String()) }() - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) - defer cancel() - assert.NoError(t, e.Shutdown(ctx)) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -var listenerNetworkTests = []struct { - test string - network string - address string -}{ - {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, - {"tcp ipv6 address", "tcp", "[::1]:1323"}, - {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, - {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, -} - -func supportsIPv6() bool { - addrs, _ := net.InterfaceAddrs() - for _, addr := range addrs { - // Check if any interface has local IPv6 assigned - if strings.Contains(addr.String(), "::1") { - return true + select { + case <-time.After(250 * time.Millisecond): + t.Fatal("start did not error out") + case err := <-errChan: + expectContains := "bind: address already in use" + if runtime.GOOS == "windows" { + expectContains = "bind: Only one usage of each socket address" } + assert.Contains(t, err.Error(), expectContains) } - return false } -func TestEchoListenerNetwork(t *testing.T) { - hasIPv6 := supportsIPv6() - for _, tt := range listenerNetworkTests { - if !hasIPv6 && strings.Contains(tt.address, "::") { - t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") - continue - } - t.Run(tt.test, func(t *testing.T) { - e := New() - e.ListenerNetwork = tt.network - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - errCh := make(chan error) - - go func() { - errCh <- e.Start(tt.address) - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, "OK", string(body)) - } else { - assert.Fail(t, err.Error()) - } - - } else { - assert.Fail(t, err.Error()) - } - - if err := e.Close(); err != nil { - t.Fatal(err) - } - }) - } +func request(method, path string, e *Echo) (int, string) { + req := httptest.NewRequest(method, path, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec.Code, rec.Body.String() } -func TestEchoListenerNetworkInvalid(t *testing.T) { - e := New() - e.ListenerNetwork = "unix" - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) +type customError struct { + s string } -func TestEchoReverse(t *testing.T) { - assert := assert.New(t) - - e := New() - dummyHandler := func(Context) error { return nil } - - e.GET("/static", dummyHandler).Name = "/static" - e.GET("/static/*", dummyHandler).Name = "/static/*" - e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" - e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" - e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) - - assert.Equal("/params/:foo", e.Reverse("/params/:foo")) - assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) - assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) - assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) - assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) - assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +func (ce *customError) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil } -func TestEchoReverseHandleHostProperly(t *testing.T) { - assert := assert.New(t) - - dummyHandler := func(Context) error { return nil } - - e := New() - h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) +func (ce *customError) Error() string { + return ce.s } -func TestEcho_ListenerAddr(t *testing.T) { - e := New() - - addr := e.ListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) -} - -func TestEcho_TLSListenerAddr(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - - e := New() - - addr := e.TLSListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.StartTLS(":0", cert, key) - }() - - err = waitForServerStart(e, errCh, true) - assert.NoError(t, err) -} - -func TestEcho_StartServer(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - certs, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) - +func TestDefaultHTTPErrorHandler(t *testing.T) { var testCases = []struct { - name string - addr string - TLSConfig *tls.Config - expectError string + name string + givenExposeError bool + givenLoggerFunc bool + whenMethod string + whenError error + expectBody string + expectStatus int + expectLogged string }{ { - name: "ok", - addr: ":0", + name: "ok, expose error = true, HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error","message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = true, HTTPError + internal error", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(errors.New("internal_error")), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error, internal=internal_error","message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = true, HTTPError + internal HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"error":"code=418, message=my_error, internal=code=425, message=early_error","message":"early_error"}` + "\n", + }, + { + name: "ok, expose error = false, HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = false, HTTPError + internal HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"message":"early_error"}` + "\n", + }, + { + name: "ok, expose error = true, Error", + givenExposeError: true, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", + }, + { + name: "ok, expose error = false, Error", + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"message":"Internal Server Error"}` + "\n", }, { - name: "ok, start with TLS", - addr: ":0", - TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + name: "ok, http.HEAD, expose error = true, Error", + givenExposeError: true, + whenMethod: http.MethodHead, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: ``, }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "ok, custom error implement MarshalJSON", + whenMethod: http.MethodGet, + whenError: NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}), + expectStatus: http.StatusBadRequest, + expectBody: "{\"x\":\"custom error msg\"}\n", }, { - name: "nok, invalid tls address", - addr: "nope", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - expectError: "listen tcp: address nope: missing port in address", + name: "with Debug=false when httpError contains an error", + whenError: NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")), + expectStatus: http.StatusBadRequest, + expectBody: "{\"message\":\"error in httperror\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) e := New() - e.Debug = true - - server := new(http.Server) - server.Addr = tc.addr - if tc.TLSConfig != nil { - server.TLSConfig = tc.TLSConfig - } + e.Logger = &jsonLogger{writer: buf} + e.Any("/path", func(c Context) error { + return tc.whenError + }) - errCh := make(chan error) - go func() { - errCh <- e.StartServer(server) - }() + e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) - err := waitForServerStart(e, errCh, tc.TLSConfig != nil) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod } - assert.NoError(t, e.Close()) + c, b := request(method, "/path", e) + + assert.Equal(t, tc.expectStatus, c) + assert.Equal(t, tc.expectBody, b) + assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func benchmarkEchoRoutes(b *testing.B, routes []*Route) { +type myCustomContext struct { + DefaultContext +} + +func (c *myCustomContext) QueryParam(name string) string { + return "prefix_" + c.DefaultContext.QueryParam(name) +} + +func TestEcho_customContext(t *testing.T) { + e := New() + e.NewContextFunc = func(ec *Echo, pathParamAllocSize int) ServableContext { + return &myCustomContext{ + DefaultContext: *NewDefaultContext(ec, pathParamAllocSize), + } + } + + e.GET("/info/:id/:file", func(c Context) error { + return c.String(http.StatusTeapot, c.QueryParam("param")) + }) + + status, body := request(http.MethodGet, "/info/1/a.csv?param=123", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "prefix_123", body) +} + +func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) u := req.URL w := httptest.NewRecorder() diff --git a/go.mod b/go.mod index 4de2bdde1..39a94922e 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,18 @@ -module github.com/labstack/echo/v4 +module github.com/labstack/echo/v5 -go 1.17 +go 1.18 require ( - github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.3.1 - github.com/stretchr/testify v1.7.0 - github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + github.com/stretchr/testify v1.8.2 + github.com/valyala/fasttemplate v1.2.2 + golang.org/x/net v0.7.0 + golang.org/x/time v0.3.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.11 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect - golang.org/x/text v0.3.7 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + golang.org/x/text v0.7.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f66734243..10db583bf 100644 --- a/go.sum +++ b/go.sum @@ -1,45 +1,27 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= -github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= -github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= -github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group.go b/group.go index bba470ce8..b9df5af9a 100644 --- a/group.go +++ b/group.go @@ -1,98 +1,121 @@ package echo import ( + "io/fs" "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + host string + prefix string + middleware []MiddlewareFunc + echo *Echo +} // Use implements `Echo#Use()` for sub-routes within the Group. +// Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) - if len(g.middleware) == 0 { - return - } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris } -// Match implements `Echo#Match()` for sub-routes within the Group. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. +// Important! Group middlewares are only executed in case there was exact route match and not +// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add +// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) @@ -102,18 +125,64 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { return } -// File implements `Echo#File()` for sub-routes within the Group. -func (g *Group) File(path, file string) { - g.file(path, file, g.GET) +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(pathPrefix, fsRoot string) RouteInfo { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + return g.StaticFS(pathPrefix, subFs) } -// Add implements `Echo#Add()` for sub-routes within the Group. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - // Combine into a new slice to avoid accidentally passing the same slice for +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo { + return g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return g.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// File implements `Echo#File()` for sub-routes within the Group. Panics on error. +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { + return c.File(file) + } + return g.Add(http.MethodGet, path, handler, middleware...) +} + +// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. +// +// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return g.Add(RouteNotFound, path, h, m...) +} + +// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := g.AddRoute(Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri +} + +// AddRoute registers a new Routable with Router +func (g *Group) AddRoute(route Routable) (RouteInfo, error) { + // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) - m = append(m, g.middleware...) - m = append(m, middleware...) - return g.echo.add(g.host, method, g.prefix+path, handler, m...) + groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) + return g.echo.add(g.host, groupRoute) } diff --git a/group_fs.go b/group_fs.go deleted file mode 100644 index 0a1ce4a94..000000000 --- a/group_fs.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) -} diff --git a/group_fs_go1.16.go b/group_fs_go1.16.go deleted file mode 100644 index 2ba52b5e2..000000000 --- a/group_fs_go1.16.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "io/fs" - "net/http" -) - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, fsRoot string) { - subFs := MustSubFS(g.echo.Filesystem, fsRoot) - g.StaticFS(pathPrefix, subFs) -} - -// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { - g.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// FileFS implements `Echo#FileFS()` for sub-routes within the Group. -func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return g.GET(path, StaticFileHandler(file, filesystem), m...) -} diff --git a/group_fs_go1.16_test.go b/group_fs_go1.16_test.go deleted file mode 100644 index d0caa33db..000000000 --- a/group_fs_go1.16_test.go +++ /dev/null @@ -1,106 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestGroup_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/assets/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - g := e.Group("/assets") - g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestGroup_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - expectError string - }{ - { - name: "panics for ../", - givenRoot: "../images", - expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", - }, - { - name: "panics for /", - givenRoot: "/images", - expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - g := e.Group("/assets") - - assert.PanicsWithError(t, tc.expectError, func() { - g.Static("/images", tc.givenRoot) - }) - }) - } -} diff --git a/group_test.go b/group_test.go index c51fd91eb..6c08b2b6f 100644 --- a/group_test.go +++ b/group_test.go @@ -1,38 +1,76 @@ package echo import ( - "io/ioutil" + "github.com/stretchr/testify/assert" + "io/fs" "net/http" "net/http/httptest" + "os" + "strings" "testing" - - "github.com/stretchr/testify/assert" ) -// TODO: Fix me -func TestGroup(t *testing.T) { - g := New().Group("/group") - h := func(Context) error { return nil } - g.CONNECT("/", h) - g.DELETE("/", h) - g.GET("/", h) - g.HEAD("/", h) - g.OPTIONS("/", h) - g.PATCH("/", h) - g.POST("/", h) - g.PUT("/", h) - g.TRACE("/", h) - g.Any("/", h) - g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) - g.Static("/static", "/tmp") - g.File("/walle", "_fixture/images//walle.png") +func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware it will not be executed when there are no routes under that group + _ = e.Group("/group", mw) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware and routes when we have no match on some route the middlewares for that + // group will not be executed + g := e.Group("/group", mw) + g.GET("/yes", handlerFunc) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_multiLevelGroup(t *testing.T) { + e := New() + + api := e.Group("/api") + users := api.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + status, body := request(http.MethodGet, "/api/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) } func TestGroupFile(t *testing.T) { e := New() g := e.Group("/group") g.File("/walle", "_fixture/images/walle.png") - expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") + expectedData, err := os.ReadFile("_fixture/images/walle.png") assert.Nil(t, err) req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) rec := httptest.NewRecorder() @@ -92,11 +130,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { } m2 := func(next HandlerFunc) HandlerFunc { return func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } } h := func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } g.Use(m1) g.GET("/help", h, m2) @@ -119,3 +157,674 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_CONNECT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.CONNECT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_DELETE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.DELETE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_HEAD(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.HEAD("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodHead+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_OPTIONS(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.OPTIONS("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PATCH(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PATCH("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_POST(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.POST("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPost+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PUT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PUT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPut+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_TRACE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.TRACE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /group/a/c/xx", + whenURL: "/group/a/c/xx", + expectRoute: "GET /group/a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /group/a/:file", + whenURL: "/group/a/echo.exe", + expectRoute: "GET /group/a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /group/*", + whenURL: "/group/b/echo.exe", + expectRoute: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /group/a/c/df to /group/a/c/df", + whenURL: "/group/a/c/df", + expectRoute: "GET /group/a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/group") + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + g.GET("/", okHandler) + g.GET("/a/c/df", okHandler) + g.GET("/a/b*", okHandler) + g.PUT("/*", okHandler) + + g.RouteNotFound("/a/c/xx", notFoundHandler) // static + g.RouteNotFound("/a/:file", notFoundHandler) // param + g.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + +func TestGroup_Any(t *testing.T) { + e := New() + + users := e.Group("/users") + ris := users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 11) + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_AnyWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Match(t *testing.T) { + e := New() + + myMethods := []string{http.MethodGet, http.MethodPost} + users := e.Group("/users") + ris := users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 2) + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_MatchWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + myMethods := []string{http.MethodGet, http.MethodPost} + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Static(t *testing.T) { + e := New() + + g := e.Group("/books") + ri := g.Static("/download", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/books/download*", ri.Path()) + assert.Equal(t, "GET:/books/download*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.True(t, strings.HasPrefix(body, "")) +} + +func TestGroup_StaticMultiTest(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "nok, without prefix does not serve dir index", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/test/", // `/test` + `*` creates route `/test*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/test/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/test/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + g := e.Group("/test") + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../images", + expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", + }, + { + name: "panics for /", + givenRoot: "/images", + expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.PanicsWithError(t, tc.expectError, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} + +func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { + var testCases = []struct { + name string + givenCustom404 bool + whenURL string + expectBody interface{} + expectCode int + expectMiddlewareCalled bool + }{ + { + name: "ok, custom 404 handler is called with middleware", + givenCustom404: true, + whenURL: "/group/test3", + expectBody: "404 GET /group/*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added + }, + { + name: "ok, default group 404 handler is not called with middleware", + givenCustom404: false, + whenURL: "/group/test3", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + }, + { + name: "ok, (no slash) default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path()) + } + + e := New() + e.GET("/test1", okHandler) + e.RouteNotFound("/*", notFoundHandler) + + g := e.Group("/group") + g.GET("/test1", okHandler) + + middlewareCalled := false + g.Use(func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + middlewareCalled = true + return next(c) + } + }) + if tc.givenCustom404 { + g.RouteNotFound("/*", notFoundHandler) + } + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled) + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) + }) + } +} diff --git a/httperror.go b/httperror.go new file mode 100644 index 000000000..5c217dac1 --- /dev/null +++ b/httperror.go @@ -0,0 +1,74 @@ +package echo + +import ( + "errors" + "fmt" + "net/http" +) + +// Errors +var ( + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) + ErrNotFound = NewHTTPError(http.StatusNotFound) + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) + ErrForbidden = NewHTTPError(http.StatusForbidden) + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) + ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) + ErrBadRequest = NewHTTPError(http.StatusBadRequest) + ErrBadGateway = NewHTTPError(http.StatusBadGateway) + ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) + ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) + ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Code int `json:"-"` + Message interface{} `json:"message"` + Internal error `json:"-"` // Stores the error returned by an external dependency +} + +// NewHTTPError creates a new HTTPError instance. +func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used? + he := &HTTPError{Code: code, Message: http.StatusText(code)} + if len(message) > 0 { + he.Message = message[0] + } + return he +} + +// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set. +func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError { + he := NewHTTPError(code, message...) + he.Internal = internalError + return he +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } + return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) +} + +// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field +func (he *HTTPError) WithInternal(err error) *HTTPError { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + Internal: err, + } +} + +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal +} diff --git a/httperror_test.go b/httperror_test.go new file mode 100644 index 000000000..f9d340f11 --- /dev/null +++ b/httperror_test.go @@ -0,0 +1,52 @@ +package echo + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestHTTPError(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) +} + +func TestNewHTTPErrorWithInternal(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message") + assert.Equal(t, "code=400, message=test message, internal=test", he.Error()) +} + +func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test")) + assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error()) +} + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} diff --git a/ip.go b/ip.go index 46d464cf9..1bcd756ae 100644 --- a/ip.go +++ b/ip.go @@ -227,6 +227,8 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { return func(req *http.Request) string { realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { + realIP = strings.TrimPrefix(realIP, "[") + realIP = strings.TrimSuffix(realIP, "]") if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } @@ -248,7 +250,10 @@ func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { } ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) for i := len(ips) - 1; i >= 0; i-- { - ip := net.ParseIP(strings.TrimSpace(ips[i])) + ips[i] = strings.TrimSpace(ips[i]) + ips[i] = strings.TrimPrefix(ips[i], "[") + ips[i] = strings.TrimSuffix(ips[i], "]") + ip := net.ParseIP(ips[i]) if ip == nil { // Unable to parse IP; cannot trust entire records return directIP diff --git a/ip_test.go b/ip_test.go index 755900d3d..38c4a1cac 100644 --- a/ip_test.go +++ b/ip_test.go @@ -459,6 +459,7 @@ func TestExtractIPDirect(t *testing.T) { func TestExtractIPFromRealIPHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { name string @@ -493,6 +494,16 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:1", + }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" @@ -506,6 +517,19 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.199", }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" @@ -520,6 +544,20 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.199", }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, } for _, tc := range testCases { @@ -532,6 +570,7 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { func TestExtractIPFromXFFHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { name string @@ -566,6 +605,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "127.0.0.3", }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", + }, + expectIP: "fe80::3", + }, { name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", whenRequest: http.Request{ @@ -576,6 +625,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::2]:8080", + }, + expectIP: "2001:db8::2", + }, { name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", givenTrustOptions: []TrustOption{ @@ -595,6 +654,25 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "203.0.100.100", // this is first trusted IP in XFF chain }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed) + // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs) + // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office) + // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP + }, + expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain + }, } for _, tc := range testCases { diff --git a/json.go b/json.go index 16b2d0577..16074fa24 100644 --- a/json.go +++ b/json.go @@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { err := json.NewDecoder(c.Request().Body).Decode(i) if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) + return NewHTTPErrorWithInternal( + http.StatusBadRequest, + err, + fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset), + ) } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, + err, + fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()), + ) } return err } diff --git a/json_test.go b/json_test.go index 27ee43e73..1d1483d22 100644 --- a/json_test.go +++ b/json_test.go @@ -1,7 +1,7 @@ package echo import ( - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "strings" @@ -14,18 +14,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -34,16 +32,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { enc := new(DefaultJSONSerializer) err := enc.Serialize(c, user{1, "Jon Snow"}, "") - if assert.NoError(err) { - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Serialize(c, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } @@ -53,18 +51,16 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -74,17 +70,17 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var u = user{} err := enc.Deserialize(c, &u) - if assert.NoError(err) { - assert.Equal(u, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) } var userUnmarshalSyntaxError = user{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalSyntaxError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` @@ -93,9 +89,9 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalTypeError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") } diff --git a/log.go b/log.go index 3f8de5904..ce351881c 100644 --- a/log.go +++ b/log.go @@ -1,41 +1,148 @@ package echo import ( + "bytes" "io" - - "github.com/labstack/gommon/log" + "strconv" + "sync" + "time" ) -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) +//----------------------------------------------------------------------------- +// Example for Zap (https://github.com/uber-go/zap) +//func main() { +// e := echo.New() +// logger, _ := zap.NewProduction() +// e.Logger = &ZapLogger{logger: logger} +//} +//type ZapLogger struct { +// logger *zap.Logger +//} +// +//func (l *ZapLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZapLogger) Error(err error) { +// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo")) +//} + +//----------------------------------------------------------------------------- +// Example for Zerolog (https://github.com/rs/zerolog) +//func main() { +// e := echo.New() +// logger := zerolog.New(os.Stdout) +// e.Logger = &ZeroLogger{logger: &logger} +//} +// +//type ZeroLogger struct { +// logger *zerolog.Logger +//} +// +//func (l *ZeroLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZeroLogger) Error(err error) { +// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error()) +//} + +//----------------------------------------------------------------------------- +// Example for Logrus (https://github.com/sirupsen/logrus) +//func main() { +// e := echo.New() +// e.Logger = &LogrusLogger{logger: logrus.New()} +//} +// +//type LogrusLogger struct { +// logger *logrus.Logger +//} +// +//func (l *LogrusLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *LogrusLogger) Error(err error) { +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err) +//} + +// Logger defines the logging interface that Echo uses internally in few places. +// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice. +type Logger interface { + // Write provides writer interface for http.Server `ErrorLog` and for logging startup messages. + // `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers, + // and underlying FileSystem errors. + // `logger` middleware will use this method to write its JSON payload. + Write(p []byte) (n int, err error) + // Error logs the error + Error(err error) +} + +// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only +// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting. +// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases. +type jsonLogger struct { + writer io.Writer + bufferPool sync.Pool + lock sync.Mutex + + timeNow func() time.Time +} + +func newJSONLogger(writer io.Writer) *jsonLogger { + return &jsonLogger{ + writer: writer, + bufferPool: sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 256)) + }, + }, + timeNow: time.Now, } -) +} + +func (l *jsonLogger) Write(p []byte) (n int, err error) { + pLen := len(p) + if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message + (p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') || + (p[0] == '{' && p[pLen-1] == '}') { + return l.write(p) + } + // we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is + // called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably + // deserves at least WARN level. + return l.printf("INFO", string(p)) +} + +func (l *jsonLogger) Error(err error) { + _, _ = l.printf("ERROR", err.Error()) +} + +func (l *jsonLogger) printf(level string, message string) (n int, err error) { + buf := l.bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer l.bufferPool.Put(buf) + + buf.WriteString(`{"time":"`) + buf.WriteString(l.timeNow().Format(time.RFC3339Nano)) + buf.WriteString(`","level":"`) + buf.WriteString(level) + buf.WriteString(`","prefix":"echo","message":`) + + buf.WriteString(strconv.Quote(message)) + buf.WriteString("}\n") + + return l.write(buf.Bytes()) +} + +func (l *jsonLogger) write(p []byte) (int, error) { + l.lock.Lock() + defer l.lock.Unlock() + return l.writer.Write(p) +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 000000000..c7b4674e9 --- /dev/null +++ b/log_test.go @@ -0,0 +1,87 @@ +package echo + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +type noOpLogger struct { +} + +func (l *noOpLogger) Write(p []byte) (n int, err error) { + return 0, err +} + +func (l *noOpLogger) Error(err error) { +} + +func TestJsonLogger_Write(t *testing.T) { + var testCases = []struct { + name string + when []byte + expect string + }{ + { + name: "ok, write non JSONlike message", + when: []byte("version: %v, build: %v"), + expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: %v, build: %v"}` + "\n", + }, + { + name: "ok, write quoted message", + when: []byte(`version: "%v"`), + expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: \"%v\""}` + "\n", + }, + { + name: "ok, write JSON", + when: []byte(`{"version": 123}` + "\n"), + expect: `{"version": 123}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0).UTC() + } + + _, err := logger.Write(tc.when) + + result := buf.String() + assert.Equal(t, tc.expect, result) + assert.NoError(t, err) + }) + } +} + +func TestJsonLogger_Error(t *testing.T) { + var testCases = []struct { + name string + whenError error + expect string + }{ + { + name: "ok", + whenError: ErrForbidden, + expect: `{"time":"2021-09-07T20:09:37Z","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0).UTC() + } + + logger.Error(tc.whenError) + + result := buf.String() + assert.Equal(t, tc.expect, result) + }) + } +} diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md new file mode 100644 index 000000000..77cb226dd --- /dev/null +++ b/middleware/DEVELOPMENT.md @@ -0,0 +1,11 @@ +# Development Guidelines for middlewares + +## Best practices: + +* Do not use `panic` in middleware creator functions in case of invalid configuration. +* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead + because previous middlewares up in call chain could have logic for dealing with returned errors. +* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they + want to create middleware with panics or with returning errors on configuration errors. +* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. + diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8cf1ed9fc..3071eedb3 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,64 +1,59 @@ package middleware import ( + "bytes" "encoding/base64" + "errors" + "net/http" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers + // this function would be called once for each header until first valid result is returned + // Required. + Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } + // Realm is a string to define realm attribute of BasicAuthWithConfig. + // Default value "Restricted". + Realm string +} - // BasicAuthValidator defines a function to validate BasicAuth credentials. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) +// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. +type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) - // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } if config.Realm == "" { config.Realm = defaultRealm @@ -70,29 +65,35 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return err + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) + continue } - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } } + if lastError != nil { + return lastError + } + realm := defaultRealm if config.Realm != defaultRealm { realm = strconv.Quote(config.Realm) @@ -102,5 +103,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0a..3d69ae84d 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -2,70 +2,157 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { + validatorFunc := func(c echo.Context, u, p string) (bool, error) { if u == "joe" && p == "secret" { return true, nil } + if u == "error" { + return false, errors.New(p) + } return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + defaultConfig := BasicAuthConfig{Validator: validatorFunc} - assert := assert.New(t) + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, multiple", + givenConfig: defaultConfig, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " NOT_BASE64", + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + config := tc.givenConfig - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + mw, err := config.ToMiddleware() + assert.NoError(t, err) - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + h := mw(func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err = h(c) + + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) +func TestBasicAuth_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuth(nil) + assert.NotNil(t, mw) + }) + + mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }) + assert.NotNil(t, mw) +} + +func TestBasicAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) + assert.NotNil(t, mw) + }) - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }}) + assert.NotNil(t, mw) } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index ebd0d0ab2..a26dd8e77 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,74 +3,67 @@ package middleware import ( "bufio" "bytes" + "errors" "io" - "io/ioutil" "net" "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } - - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) + // Handler receives request and response payload. + // Required. + Handler BodyDumpHandler +} - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte) -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - c := DefaultBodyDumpConfig - c.Handler = handler - return BodyDumpWithConfig(c) + return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration +func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { - panic("echo: body-dump middleware requires a handler function") + return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } // Request reqBody := []byte{} - if c.Request().Body != nil { // Read - reqBody, _ = ioutil.ReadAll(c.Request().Body) + if c.Request().Body != nil { + reqBody, _ = io.ReadAll(c.Request().Body) } - c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset + c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset // Response resBody := new(bytes.Buffer) @@ -78,16 +71,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} c.Response().Writer = writer - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) // Callback config.Handler(c, reqBody, resBody.Bytes()) - return + return err } - } + }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e6e00f726..fd608167c 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -2,13 +2,13 @@ package middleware import ( "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -19,7 +19,7 @@ func TestBodyDump(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } @@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) { requestBody = string(reqBody) responseBody = string(resBody) - }) - - assert := assert.New(t) + }}.ToMiddleware() + assert.NoError(t, err) - if assert.NoError(mw(h)(c)) { - assert.Equal(requestBody, hw) - assert.Equal(responseBody, hw) - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.String()) + if assert.NoError(t, mw(h)(c)) { + assert.Equal(t, requestBody, hw) + assert.Equal(t, responseBody, hw) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.String()) } - // Must set default skipper - BodyDumpWithConfig(BodyDumpConfig{ - Skipper: nil, +} + +func TestBodyDump_skipper(t *testing.T) { + e := echo.New() + + isCalled := false + mw, err := BodyDumpConfig{ + Skipper: func(c echo.Context) bool { + return true + }, Handler: func(c echo.Context, reqBody, resBody []byte) { - requestBody = string(reqBody) - responseBody = string(resBody) + isCalled = true }, - }) + }.ToMiddleware() + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + return errors.New("some error") + } + + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.False(t, isCalled) } -func TestBodyDumpFails(t *testing.T) { +func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) @@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) { return errors.New("some error") } - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware() + assert.NoError(t, err) - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ + mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) + assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ - Skipper: func(c echo.Context) bool { - return true - }, - Handler: func(c echo.Context, reqBody, resBody []byte) { - }, - }) + mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}) + assert.NotNil(t, mw) + }) +} - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } +func TestBodyDump_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BodyDump(nil) + assert.NotNil(t, mw) + }) + + assert.NotPanics(t, func() { + BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) }) } diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd595..f43556c71 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,98 +1,82 @@ package middleware import ( - "fmt" "io" "sync" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/bytes" + "github.com/labstack/echo/v5" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } - - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context - } -) + // LimitBytes is maximum allowed size in bytes for a request body + LimitBytes int64 +} -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 +} // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the -// size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it +// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, -// G, T or P. -func BodyLimit(limit string) echo.MiddlewareFunc { - c := DefaultBodyLimitConfig - c.Limit = limit - return BodyLimitWithConfig(c) +func BodyLimit(limitBytes int64) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } -// BodyLimitWithConfig returns a BodyLimit middleware with config. -// See: `BodyLimit()`. +// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for +// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. +// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which +// makes it super secure. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration +func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyLimitConfig.Skipper + config.Skipper = DefaultSkipper } - - limit, err := bytes.Parse(config.Limit) - if err != nil { - panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) + pool := sync.Pool{ + New: func() interface{} { + return &limitedReader{BodyLimitConfig: config} + }, } - config.limit = limit - pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() // Based on content length - if req.ContentLength > config.limit { + if req.ContentLength > config.LimitBytes { return echo.ErrStatusRequestEntityTooLarge } // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r.Reset(req.Body) defer pool.Put(r) req.Body = r return next(c) } - } + }, nil } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.limit { + if r.read > r.LimitBytes { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -102,16 +86,7 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader - r.context = context r.read = 0 } - -func limitedReaderPool(c BodyLimitConfig) sync.Pool { - return sync.Pool{ - New: func() interface{} { - return &limitedReader{BodyLimitConfig: c} - }, - } -} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 8981534d4..4981918aa 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -2,86 +2,162 @@ package middleware import ( "bytes" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestBodyLimit(t *testing.T) { +func TestBodyLimitConfig_ToMiddleware(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } return c.String(http.StatusOK, string(body)) } - assert := assert.New(t) - // Based on content length (within limit) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.Bytes()) + mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } - // Based on content length (overlimit) - he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + // Based on content read (overlimit) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he := mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, World!", rec.Body.String()) - } + + mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he = mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) } func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() config := BodyLimitConfig{ - Skipper: DefaultSkipper, - Limit: "2B", - limit: 2, + Skipper: DefaultSkipper, + LimitBytes: 2, } reader := &limitedReader{ BodyLimitConfig: config, - reader: ioutil.NopCloser(bytes.NewReader(hw)), - context: e.NewContext(req, rec), + reader: io.NopCloser(bytes.NewReader(hw)), } // read all should return ErrStatusRequestEntityTooLarge - _, err := ioutil.ReadAll(reader) + _, err := io.ReadAll(reader) he := err.(*echo.HTTPError) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(io.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) } + +func TestBodyLimit_skipper(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw, err := BodyLimitConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + LimitBytes: 2, + }.ToMiddleware() + assert.NoError(t, err) + + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimitWithConfig(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimit(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimit(2 * MB) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} diff --git a/middleware/compress.go b/middleware/compress.go index ac6672e9d..fb606aee1 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -2,65 +2,83 @@ package middleware import ( "bufio" + "bytes" "compress/gzip" + "errors" "io" - "io/ioutil" "net" "net/http" "strings" "sync" - "github.com/labstack/echo/v4" -) - -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - } - - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - wroteBody bool - } + "github.com/labstack/echo/v5" ) const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - } -) +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Gzip compression level. + // Optional. Default value -1. + Level int + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int +} -// Gzip returns a middleware which compresses HTTP response using gzip compression -// scheme. +// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(DefaultGzipConfig) + return GzipWithConfig(GzipConfig{}) } -// GzipWithConfig return Gzip middleware with config. -// See: `Gzip()`. +// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration +func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression + return nil, errors.New("invalid gzip level") } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = -1 + } + if config.MinLength < 0 { + config.MinLength = 0 } pool := gzipCompressPool(config) + bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -71,39 +89,60 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { - return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + return echo.NewHTTPErrorWithInternal(http.StatusInternalServerError, i.(error)) } rw := res.Writer w.Reset(rw) - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} + buf := bpool.Get().(*bytes.Buffer) + buf.Reset() + + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} defer func() { + // There are different reasons for cases when we have not yet written response to the client and now need to do so. + // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. + // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } + if grw.wroteHeader { + rw.WriteHeader(grw.code) + } // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. res.Writer = rw - w.Reset(ioutil.Discard) + w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + res.Writer = rw + if grw.wroteHeader { + grw.ResponseWriter.WriteHeader(grw.code) + } + grw.buffer.WriteTo(rw) + w.Reset(io.Discard) } w.Close() + bpool.Put(buf) pool.Put(w) }() res.Writer = grw } return next(c) } - } + }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del(echo.HeaderContentLength) // Issue #444 - w.ResponseWriter.WriteHeader(code) + + w.wroteHeader = true + + // Delay writing of the header until we know if we'll actually compress the response + w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { @@ -111,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } w.wroteBody = true + + if !w.minLengthExceeded { + n, err := w.buffer.Write(b) + + if w.buffer.Len() >= w.minLength { + w.minLengthExceeded = true + + // The minimum length is exceeded, add Content-Encoding header and write the header + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + return w.Writer.Write(w.buffer.Bytes()) + } + + return n, err + } + return w.Writer.Write(b) } func (w *gzipResponseWriter) Flush() { + if !w.minLengthExceeded { + // Enforce compression because we will not know how much more data will come + w.minLengthExceeded = true + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + w.Writer.Write(w.buffer.Bytes()) + } + w.Writer.(*gzip.Writer).Flush() if flusher, ok := w.ResponseWriter.(http.Flusher); ok { flusher.Flush() @@ -135,7 +204,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() interface{} { - w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + w, err := gzip.NewWriterLevel(io.Discard, config.Level) if err != nil { return err } @@ -143,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool { }, } } + +func bufferPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + b := &bytes.Buffer{} + return b + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index b62bffef5..551f18525 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -4,93 +4,127 @@ import ( "bytes" "compress/gzip" "io" - "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" + "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestGzip(t *testing.T) { +func TestGzip_NoAcceptEncodingHeader(t *testing.T) { + // Skip if no Accept-Encoding header + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - // Skip if no Accept-Encoding header + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) +} + +func TestMustGzipWithConfig_panics(t *testing.T) { + assert.Panics(t, func() { + GzipWithConfig(GzipConfig{Level: 999}) + }) +} + +func TestGzip_AcceptEncodingHeader(t *testing.T) { h := Gzip()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - assert.Equal("test", rec.Body.String()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - // Gzip - req = httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal("test", buf.String()) - } + err := h(c) + assert.NoError(t, err) - chunkBuf := make([]byte, 5) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - // Gzip chunked - req = httptest.NewRequest(http.MethodGet, "/", nil) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) +} + +func TestGzip_chunked(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - c = e.NewContext(req, rec) - Gzip()(func(c echo.Context) error { + chunkChan := make(chan struct{}) + waitChan := make(chan struct{}) + h := Gzip()(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("first\n")) c.Response().Flush() - // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - r.Reset(rec.Body) - - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write and flush the second part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("second\n")) c.Response().Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write the final part of the data and return - c.Response().Write([]byte("test")) + c.Response().Write([]byte("third")) + + chunkChan <- struct{}{} return nil - })(c) + }) + + go func() { + err := h(c) + chunkChan <- struct{}{} + assert.NoError(t, err) + }() + <-chunkChan // wait for first write + waitChan <- struct{}{} + + <-chunkChan // wait for second write + waitChan <- struct{}{} + + <-chunkChan // wait for final write in handler + <-chunkChan // wait for return from handler + time.Sleep(5 * time.Millisecond) // to have time for flushing + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) buf := new(bytes.Buffer) - defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "first\nsecond\nthird", buf.String()) } -func TestGzipNoContent(t *testing.T) { +func TestGzip_NoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) { } } -func TestGzipEmpty(t *testing.T) { +func TestGzip_Empty(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -127,7 +161,7 @@ func TestGzipEmpty(t *testing.T) { } } -func TestGzipErrorReturned(t *testing.T) { +func TestGzip_ErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) e.GET("/", func(c echo.Context) error { @@ -141,31 +175,25 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } -func TestGzipErrorReturnedInvalidConfig(t *testing.T) { - e := echo.New() - // Invalid level - e.Use(GzipWithConfig(GzipConfig{Level: 12})) - e.GET("/", func(c echo.Context) error { - c.Response().Write([]byte("test")) - return nil - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), "gzip") +func TestGzipWithConfig_invalidLevel(t *testing.T) { + mw, err := GzipConfig{Level: 12}.ToMiddleware() + assert.EqualError(t, err, "invalid gzip level") + assert.Nil(t, mw) } // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() + e.Filesystem = os.DirFS("../") + e.Use(Gzip()) - e.Static("/test", "../_fixture/images") + e.Static("/test", "_fixture/images") req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) // Data is written out in chunks when Content-Length == "", so only // validate the content length if it's not set. @@ -175,7 +203,7 @@ func TestGzipWithStatic(t *testing.T) { r, err := gzip.NewReader(rec.Body) if assert.NoError(t, err) { defer r.Close() - want, err := ioutil.ReadFile("../_fixture/images/walle.png") + want, err := os.ReadFile("../_fixture/images/walle.png") if assert.NoError(t, err) { buf := new(bytes.Buffer) buf.ReadFrom(r) @@ -184,6 +212,137 @@ func TestGzipWithStatic(t *testing.T) { } } +func TestGzipWithMinLength(t *testing.T) { + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("foobarfoobar")) + return nil + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "foobarfoobar", buf.String()) + } +} + +func TestGzipWithMinLengthTooShort(t *testing.T) { + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Body.String(), "test") +} + +func TestGzipWithResponseWithoutBody(t *testing.T) { + e := echo.New() + + e.Use(Gzip()) + e.GET("/", func(c echo.Context) error { + return c.Redirect(http.StatusMovedPermanently, "http://localhost") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipWithMinLengthChunked(t *testing.T) { + e := echo.New() + + // Gzip chunked + chunkBuf := make([]byte, 5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + var r *gzip.Reader = nil + + c := e.NewContext(req, rec) + next := func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(t, rec.Flushed) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + var err error + r, err = gzip.NewReader(rec.Body) + assert.NoError(t, err) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + } + err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c) + + assert.NoError(t, err) + assert.NotNil(t, r) + + buf := new(bytes.Buffer) + + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) + + r.Close() +} + +func TestGzipWithMinLengthNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + func BenchmarkGzip(b *testing.B) { e := echo.New() diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go new file mode 100644 index 000000000..98af08c82 --- /dev/null +++ b/middleware/context_timeout.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "context" + "errors" + "github.com/labstack/echo/v5" + "time" +) + +// ContextTimeoutConfig defines the config for ContextTimeout middleware. +type ContextTimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // ErrorHandler is a function when error aries in middeware execution. + ErrorHandler func(c echo.Context, err error) error + + // Timeout configures a timeout for the middleware + Timeout time.Duration +} + +// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client +// when underlying method returns context.DeadlineExceeded error. +func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { + return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) +} + +// ContextTimeoutWithConfig returns a Timeout middleware with config. +func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts Config to middleware. +func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Timeout == 0 { + return nil, errors.New("timeout must be set") + } + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c echo.Context, err error) error { + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return echo.ErrServiceUnavailable.WithInternal(err) + } + return err + } + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) + defer cancel() + + c.SetRequest(c.Request().WithContext(timeoutContext)) + + if err := next(c); err != nil { + return config.ErrorHandler(c, err) + } + return nil + } + }, nil +} diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go new file mode 100644 index 000000000..bd325bf62 --- /dev/null +++ b/middleware/context_timeout_test.go @@ -0,0 +1,225 @@ +package middleware + +import ( + "context" + "errors" + "github.com/labstack/echo/v5" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestContextTimeoutSkipper(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Skipper: func(context echo.Context) bool { + return true + }, + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + return err + } + + return errors.New("response from handler") + })(c) + + // if not skipped we would have not returned error due context timeout logic + assert.EqualError(t, err, "response from handler") +} + +func TestContextTimeoutWithTimeout0(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + ContextTimeout(time.Duration(0)) + }) +} + +func TestContextTimeoutErrorOutInHandler(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + rec.Code = 1 // we want to be sure that even 200 will not be sent + err := m(func(c echo.Context) error { + // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able + // to handle returned error and this can be done only then handler has not yet committed (written status code) + // the response. + return echo.NewHTTPError(http.StatusTeapot, "err") + })(c) + + assert.Error(t, err) + assert.EqualError(t, err, "code=418, message=err") + assert.Equal(t, 1, rec.Code) + assert.Equal(t, "", rec.Body.String()) +} + +func TestContextTimeoutSuccessfulRequest(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) + })(c) + + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String()) +} + +func TestContextTimeoutTestRequestClone(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 1 * time.Second, + }) + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // Cookie test + cookie, err := c.Request().Cookie("cookie") + if assert.NoError(t, err) { + assert.EqualValues(t, "cookie", cookie.Name) + assert.EqualValues(t, "value", cookie.Value) + } + + // Form values + if assert.NoError(t, c.Request().ParseForm()) { + assert.EqualValues(t, "value", c.Request().FormValue("form")) + } + + // Query string + assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { + t.Parallel() + + timeout := 10 * time.Millisecond + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Timeout: timeout, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil { + return err + } + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.IsType(t, &echo.HTTPError{}, err) + assert.Error(t, err) + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) +} + +func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { + t.Parallel() + + timeoutErrorHandler := func(c echo.Context, err error) error { + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return &echo.HTTPError{ + Code: http.StatusServiceUnavailable, + Message: "Timeout! change me", + } + } + return err + } + return nil + } + + timeout := 50 * time.Millisecond + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Timeout: timeout, + ErrorHandler: timeoutErrorHandler, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order + // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky. + + if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil { + return err + } + + // The Request Context should have a Deadline set by http.ContextTimeoutHandler + if _, ok := c.Request().Context().Deadline(); !ok { + assert.Fail(t, "No timeout set on Request Context") + } + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.IsType(t, &echo.HTTPError{}, err) + assert.Error(t, err) + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message) +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + + defer func() { + _ = timer.Stop() + }() + + select { + case <-ctx.Done(): + return context.DeadlineExceeded + case <-timer.C: + return nil + } +} diff --git a/middleware/cors.go b/middleware/cors.go index 16259512a..74ec56739 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -6,77 +6,135 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. - AllowOrigins []string `yaml:"allow_origins"` - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` - - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. - // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value - // from `Allow` header that echo.Router set into context. - AllowMethods []string `yaml:"allow_methods"` - - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. - // Optional. Default value []string{}. - AllowHeaders []string `yaml:"allow_headers"` - - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. - // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - AllowCredentials bool `yaml:"allow_credentials"` - - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. - ExposeHeaders []string `yaml:"expose_headers"` - - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. - MaxAge int `yaml:"max_age"` - } -) +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowOrigins []string + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. + AllowOriginFunc func(origin string) (bool, error) + + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowMethods []string + + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // + // Optional. Default value []string{}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowHeaders []string + + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials bool + + // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials + // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. + // + // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) + // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. + // + // Optional. Default value is false. + UnsafeWildcardOriginWithAllowCredentials bool + + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header + ExposeHeaders []string + + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // + // Optional. Default value 0. The header is set only if MaxAge > 0. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + MaxAge int +} + +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. -// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// See also [MDN: Cross-Origin Resource Sharing (CORS)]. +// +// Security: Poorly configured CORS can compromise security because it allows +// relaxation of the browser's Same-Origin policy. See [Exploiting CORS +// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin +// resource sharing (CORS)] for more details. +// +// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html +// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } -// CORSWithConfig returns a CORS middleware with config. -// See: `CORS()`. +// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. +// See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration +func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCORSConfig.Skipper @@ -93,8 +151,8 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { allowOriginPatterns := []string{} for _, origin := range config.AllowOrigins { pattern := regexp.QuoteMeta(origin) - pattern = strings.Replace(pattern, "\\*", ".*", -1) - pattern = strings.Replace(pattern, "\\?", ".", -1) + pattern = strings.ReplaceAll(pattern, "\\*", ".*") + pattern = strings.ReplaceAll(pattern, "\\?", ".") pattern = "^" + pattern + "$" allowOriginPatterns = append(allowOriginPatterns, pattern) } @@ -155,7 +213,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } else { // Check allowed origins for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { + if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { allowOrigin = origin break } @@ -172,7 +230,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { checkPatterns := false if allowOrigin == "" { // to avoid regex cost by invalid (long) domains (253 is domain name max limit) - if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { + if len(origin) <= (5+3+253) && strings.Contains(origin, "://") { checkPatterns = true } } @@ -230,5 +288,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } return c.NoContent(http.StatusNoContent) } - } + }, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index daadbab6e..424f16e6c 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -6,111 +6,195 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCORS(t *testing.T) { - e := echo.New() + var testCases = []struct { + name string + givenMW echo.MiddlewareFunc + whenMethod string + whenHeaders map[string]string + expectHeaders map[string]string + notExpectHeaders map[string]string + }{ + { + name: "ok, wildcard origin", + whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"}, + }, + { + name: "ok, wildcard AllowedOrigin with no Origin header in request", + notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, + }, + { + name: "ok, specific AllowOrigins and AllowCredentials", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }), + whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowCredentials: "true", + }, + }, + { + name: "ok, preflight request with matching origin for `AllowOrigins`", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*` + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: false, // important for this testcase + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlMaxAge: "3600", + }, + notExpectHeaders: map[string]string{ + echo.HeaderAccessControlAllowCredentials: "", + }, + }, + { + name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request with Access-Control-Request-Headers", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + echo.HeaderAccessControlRequestHeaders: "Special-Request-Header", + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", + echo.HeaderAccessControlAllowHeaders: "Special-Request-Header", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + }, + }, + { + name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://*.example.com"}, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, + }, + { + name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://*.example.com"}, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + mw := CORS() + if tc.givenMW != nil { + mw = tc.givenMW + } + h := mw(func(c echo.Context) error { + return nil + }) + + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + req := httptest.NewRequest(method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + for k, v := range tc.whenHeaders { + req.Header.Set(k, v) + } + + err := h(c) - // Wildcard origin - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := CORS()(echo.NotFoundHandler) - req.Header.Set(echo.HeaderOrigin, "localhost") - h(c) - assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - - // Wildcard AllowedOrigin with no Origin header in request - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = CORS()(echo.NotFoundHandler) - h(c) - assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) - - // Allow origins - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - AllowCredentials: true, - MaxAge: 3600, - })(echo.NotFoundHandler) - req.Header.Set(echo.HeaderOrigin, "localhost") - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - - // Preflight request - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - cors := CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - AllowCredentials: true, - MaxAge: 3600, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) - - // Preflight request with `AllowOrigins` * - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"*"}, - AllowCredentials: true, - MaxAge: 3600, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) - - // Preflight request with Access-Control-Request-Headers - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header") - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"*"}, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - - // Preflight request with `AllowOrigins` which allow all subdomains with * - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com") - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"http://*.example.com"}, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - - req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com") - h(c) - assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NoError(t, err) + header := rec.Header() + for k, v := range tc.expectHeaders { + assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v) + } + for k, v := range tc.notExpectHeaders { + if v == "" { + assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k) + } else { + assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v) + } + } + }) + } } func Test_allowOriginScheme(t *testing.T) { @@ -149,7 +233,7 @@ func Test_allowOriginScheme(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -240,7 +324,7 @@ func Test_allowOriginSubdomain(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -324,7 +408,9 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) } - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) h(c) assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) @@ -511,11 +597,11 @@ func Test_allowOriginFunc(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) - cors := CORSWithConfig(CORSConfig{ - AllowOriginFunc: allowOriginFunc, - }) - h := cors(echo.NotFoundHandler) - err := h(c) + cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware() + assert.NoError(t, err) + + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) + err = h(c) expected, expectedErr := allowOriginFunc(origin) if expectedErr != nil { diff --git a/middleware/csrf.go b/middleware/csrf.go index 61299f5ca..e5c1af70e 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -5,91 +5,96 @@ import ( "net/http" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" or "header::" - // - "query:" - // - "form:" - // Multiple sources example: - // - "header:X-CSRF-Token,query:csrf" - TokenLookup string `yaml:"token_lookup"` - - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` - - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` - - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` - - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` - - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` - - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` - - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` - - // Indicates SameSite mode of the CSRF cookie. - // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` - } -) +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // TokenLength is the length of the generated token. + TokenLength uint8 + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string `yaml:"token_lookup"` + + // Generator defines a function to generate token. + // Optional. Defaults tp randomString(TokenLength). + Generator func() string + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler func(c echo.Context, err error) error +} // ErrCSRFInvalid is returned when CSRF check fails var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - CookieSameSite: http.SameSiteDefaultMode, - } -) +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - c := DefaultCSRFConfig - return CSRFWithConfig(c) + return CSRFWithConfig(DefaultCSRFConfig) } -// CSRFWithConfig returns a CSRF middleware with config. -// See `CSRF()`. +// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration +func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper @@ -97,6 +102,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.Generator == nil { + config.Generator = createRandomStringGenerator(config.TokenLength) + } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -113,9 +121,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - extractors, err := createExtractors(config.TokenLookup, "") - if err != nil { - panic(err) + extractors, cErr := createExtractors(config.TokenLookup) + if cErr != nil { + return nil, cErr } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -126,7 +134,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = random.String(config.TokenLength) // Generate token + token = config.Generator() // Generate token } else { token = k.Value // Reuse token } @@ -139,7 +147,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { var lastTokenErr error outer: for _, extractor := range extractors { - clientTokens, err := extractor(c) + clientTokens, _, err := extractor(c) if err != nil { lastExtractorErr = err continue @@ -154,20 +162,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { lastTokenErr = ErrCSRFInvalid } } + var finalErr error if lastTokenErr != nil { - return lastTokenErr + finalErr = lastTokenErr } else if lastExtractorErr != nil { - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") - } else if lastExtractorErr == errFormExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") - } else { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) + finalErr = echo.ErrBadRequest.WithInternal(lastExtractorErr) + } + if finalErr != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(c, finalErr) } - return lastExtractorErr + return finalErr } } @@ -197,7 +202,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } func validateCSRFToken(token, clientToken string) bool { diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 9aff82a98..de97cd6c5 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -7,22 +7,22 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCSRF_tokenExtractors(t *testing.T) { var testCases = []struct { - name string - whenTokenLookup string - whenCookieName string - givenCSRFCookie string - givenMethod string - givenQueryTokens map[string][]string - givenFormTokens map[string][]string - givenHeaderTokens map[string][]string - expectError string + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + expectToMiddlewareError string }{ { name: "ok, multiple token lookups sources, succeeds on last one", @@ -70,7 +70,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the form parameter", + expectError: "code=400, message=Bad Request, internal=missing value in the form", }, { name: "ok, token from POST header", @@ -106,7 +106,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in request header", + expectError: "code=400, message=Bad Request, internal=missing value in request header", }, { name: "ok, token from PUT query param", @@ -142,7 +142,15 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the query string", + expectError: "code=400, message=Bad Request, internal=missing value in the query string", + }, + { + name: "nok, invalid TokenLookup", + whenTokenLookup: "q", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", }, } @@ -186,16 +194,23 @@ func TestCSRF_tokenExtractors(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + config := CSRFConfig{ TokenLookup: tc.whenTokenLookup, CookieName: tc.whenCookieName, - }) + } + csrf, err := config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) + return + } else if err != nil { + assert.NoError(t, err) + } h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) - err := h(c) + err = h(c) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -219,6 +234,24 @@ func TestCSRF(t *testing.T) { h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") +} + +func TestMustCSRFWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + csrf := CSRFWithConfig(CSRFConfig{ + TokenLength: 16, + }) + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Generate CSRF token + h(c) + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") + // Without CSRF cookie req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() @@ -233,7 +266,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(32) + token := randomString(16) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { @@ -302,9 +335,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + csrf, err := CSRFConfig{ CookieSameSite: http.SameSiteNoneMode, - }) + }.ToMiddleware() + assert.NoError(t, err) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -358,3 +392,25 @@ func TestCSRFConfig_skipper(t *testing.T) { }) } } + +func TestCSRFErrorHandling(t *testing.T) { + cfg := CSRFConfig{ + ErrorHandler: func(c echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") + }, + } + + e := echo.New() + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(CSRFWithConfig(cfg)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) +} diff --git a/middleware/decompress.go b/middleware/decompress.go index 88ec70982..dcf7172fa 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -6,21 +6,19 @@ import ( "net/http" "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // DecompressConfig defines the config for Decompress middleware. - DecompressConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers - GzipDecompressPool Decompressor - } -) + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor +} -//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers @@ -28,14 +26,6 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -var ( - //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, - } -) - // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } @@ -44,19 +34,23 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} } -//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { - return DecompressWithConfig(DefaultDecompressConfig) + return DecompressWithConfig(DecompressConfig{}) } -//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration +func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper } if config.GzipDecompressPool == nil { - config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + config.GzipDecompressPool = &DefaultGzipDecompressPool{} } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -95,5 +89,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 51fa6b0f1..53b51e24e 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -4,19 +4,44 @@ import ( "bytes" "compress/gzip" "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" "sync" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestDecompress(t *testing.T) { e := echo.New() + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + // Decompress request body + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) +} + +func TestDecompress_skippedIfNoHeader(t *testing.T) { + e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -26,52 +51,55 @@ func TestDecompress(t *testing.T) { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) - // Decompress - body := `{"name": "echo"}` - gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) } -func TestDecompressDefaultConfig(t *testing.T) { +func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil - }) - h(c) + })(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) + +} - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) +func TestDecompressWithConfig_DefaultConfig(t *testing.T) { + e := echo.New() + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { @@ -82,9 +110,11 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.NotEqual(t, b, body) assert.Equal(t, b, gz) @@ -99,7 +129,10 @@ func TestDecompressNoContent(t *testing.T) { h := Decompress()(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) - if assert.NoError(t, h(c)) { + + err := h(c) + + if assert.NoError(t, err) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) @@ -115,7 +148,9 @@ func TestDecompressErrorReturned(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } @@ -132,9 +167,11 @@ func TestDecompressSkipper(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) - reqBody, err := ioutil.ReadAll(c.Request().Body) + reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) } @@ -161,9 +198,11 @@ func TestDecompressPoolError(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - reqBody, err := ioutil.ReadAll(c.Request().Body) + reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) assert.Equal(t, rec.Code, http.StatusInternalServerError) diff --git a/middleware/extractor.go b/middleware/extractor.go index afdfd8195..a9343c1ba 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -1,9 +1,8 @@ package middleware import ( - "errors" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/textproto" "strings" ) @@ -14,17 +13,63 @@ const ( extractorLimit = 20 ) -var errHeaderExtractorValueMissing = errors.New("missing value in request header") -var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") -var errQueryExtractorValueMissing = errors.New("missing value in the query string") -var errParamExtractorValueMissing = errors.New("missing value in path params") -var errCookieExtractorValueMissing = errors.New("missing value in cookies") -var errFormExtractorValueMissing = errors.New("missing value in the form") +// ExtractorSource is type to indicate source for extracted value +type ExtractorSource string + +const ( + // ExtractorSourceHeader means value was extracted from request header + ExtractorSourceHeader ExtractorSource = "header" + // ExtractorSourceQuery means value was extracted from request query parameters + ExtractorSourceQuery ExtractorSource = "query" + // ExtractorSourcePathParam means value was extracted from route path parameters + ExtractorSourcePathParam ExtractorSource = "param" + // ExtractorSourceCookie means value was extracted from request cookies + ExtractorSourceCookie ExtractorSource = "cookie" + // ExtractorSourceForm means value was extracted from request form values + ExtractorSourceForm ExtractorSource = "form" +) + +// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups +type ValueExtractorError struct { + message string +} + +// Error returns errors text +func (e *ValueExtractorError) Error() string { + return e.message +} + +var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} +var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"} +var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} +var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} +var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} +var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. -type ValuesExtractor func(c echo.Context) ([]string, error) +type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error) -func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { +// CreateExtractors creates ValuesExtractors from given lookups. +// Lookups is a string in the form of ":" or ":,:" that is used +// to extract key from the request. +// Possible values: +// - "header:" or "header::" +// `` is argument value to cut/trim prefix of the extracted value. This is useful if header +// value has static prefix like `Authorization: ` where part that we +// want to cut is ` ` note the space at the end. +// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. +// - "query:" +// - "param:" +// - "form:" +// - "cookie:" +// +// Multiple sources example: +// - "header:Authorization,header:X-Api-Key" +func CreateExtractors(lookups string) ([]ValuesExtractor, error) { + return createExtractors(lookups) +} + +func createExtractors(lookups string) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil } @@ -49,15 +94,6 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err prefix := "" if len(parts) > 2 { prefix = parts[2] - } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { - // backwards compatibility for JWT and KeyAuth: - // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc - // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that - // behaviour for default values and Authorization header. - prefix = authScheme - if !strings.HasSuffix(prefix, " ") { - prefix += " " - } } extractors = append(extractors, valuesFromHeader(parts[1], prefix)) } @@ -75,10 +111,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { prefixLen := len(valuePrefix) // standard library parses http.Request header keys in canonical form but we may provide something else so fix this header = textproto.CanonicalMIMEHeaderKey(header) - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { values := c.Request().Header.Values(header) if len(values) == 0 { - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } result := make([]string, 0) @@ -100,53 +136,52 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { if len(result) == 0 { if prefixLen > 0 { - return nil, errHeaderExtractorValueInvalid + return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid } - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } - return result, nil + return result, ExtractorSourceHeader, nil } } // valuesFromQuery returns a function that extracts values from the query string. func valuesFromQuery(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { result := c.QueryParams()[param] if len(result) == 0 { - return nil, errQueryExtractorValueMissing + return nil, ExtractorSourceQuery, errQueryExtractorValueMissing } else if len(result) > extractorLimit-1 { result = result[:extractorLimit] } - return result, nil + return result, ExtractorSourceQuery, nil } } // valuesFromParam returns a function that extracts values from the url param string. func valuesFromParam(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { result := make([]string, 0) - paramVales := c.ParamValues() - for i, p := range c.ParamNames() { - if param == p { - result = append(result, paramVales[i]) + for i, p := range c.PathParams() { + if param == p.Name { + result = append(result, p.Value) if i >= extractorLimit-1 { break } } } if len(result) == 0 { - return nil, errParamExtractorValueMissing + return nil, ExtractorSourcePathParam, errParamExtractorValueMissing } - return result, nil + return result, ExtractorSourcePathParam, nil } } // valuesFromCookie returns a function that extracts values from the named cookie. func valuesFromCookie(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { cookies := c.Cookies() if len(cookies) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } result := make([]string, 0) @@ -159,26 +194,26 @@ func valuesFromCookie(name string) ValuesExtractor { } } if len(result) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } - return result, nil + return result, ExtractorSourceCookie, nil } } // valuesFromForm returns a function that extracts values from the form field. func valuesFromForm(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { if c.Request().Form == nil { _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does } values := c.Request().Form[name] if len(values) == 0 { - return nil, errFormExtractorValueMissing + return nil, ExtractorSourceForm, errFormExtractorValueMissing } if len(values) > extractorLimit-1 { values = values[:extractorLimit] } result := append([]string{}, values...) - return result, nil + return result, ExtractorSourceForm, nil } } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 2e898f541..7b8b3d4ff 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -3,7 +3,7 @@ package middleware import ( "bytes" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "mime/multipart" "net/http" @@ -13,29 +13,14 @@ import ( "testing" ) -type pathParam struct { - name string - value string -} - -func setPathParams(c echo.Context, params []pathParam) { - names := make([]string, 0, len(params)) - values := make([]string, 0, len(params)) - for _, pp := range params { - names = append(names, pp.name) - values = append(values, pp.value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) -} - func TestCreateExtractors(t *testing.T) { var testCases = []struct { name string givenRequest func() *http.Request - givenPathParams []pathParam + givenPathParams echo.PathParams whenLoopups string expectValues []string + expectSource ExtractorSource expectCreateError string expectError string }{ @@ -48,6 +33,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "header:Authorization:Bearer ", expectValues: []string{"token"}, + expectSource: ExtractorSourceHeader, }, { name: "ok, form", @@ -61,6 +47,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "form:name", expectValues: []string{"Jon Snow"}, + expectSource: ExtractorSourceForm, }, { name: "ok, cookie", @@ -71,14 +58,16 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "cookie:_csrf", expectValues: []string{"token"}, + expectSource: ExtractorSourceCookie, }, { name: "ok, param", - givenPathParams: []pathParam{ - {name: "id", value: "123"}, + givenPathParams: echo.PathParams{ + {Name: "id", Value: "123"}, }, whenLoopups: "param:id", expectValues: []string{"123"}, + expectSource: ExtractorSourcePathParam, }, { name: "ok, query", @@ -88,6 +77,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "query:id", expectValues: []string{"999"}, + expectSource: ExtractorSourceQuery, }, { name: "nok, invalid lookup", @@ -105,12 +95,12 @@ func TestCreateExtractors(t *testing.T) { req = tc.givenRequest() } rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + c.SetRawPathParams(&tc.givenPathParams) } - extractors, err := createExtractors(tc.whenLoopups, "") + extractors, err := CreateExtractors(tc.whenLoopups) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return @@ -118,8 +108,9 @@ func TestCreateExtractors(t *testing.T) { assert.NoError(t, err) for _, e := range extractors { - values, eErr := e(c) + values, source, eErr := e(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, tc.expectSource, source) if tc.expectError != "" { assert.EqualError(t, eErr, tc.expectError) return @@ -244,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) { extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceHeader, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -305,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) { extractor := valuesFromQuery(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceQuery, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -317,19 +310,19 @@ func TestValuesFromQuery(t *testing.T) { } func TestValuesFromParam(t *testing.T) { - examplePathParams := []pathParam{ - {name: "id", value: "123"}, - {name: "gid", value: "456"}, - {name: "gid", value: "789"}, + examplePathParams := echo.PathParams{ + {Name: "id", Value: "123"}, + {Name: "gid", Value: "456"}, + {Name: "gid", Value: "789"}, } - examplePathParams20 := make([]pathParam, 0) + examplePathParams20 := make(echo.PathParams, 0) for i := 1; i < 25; i++ { - examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) + examplePathParams20 = append(examplePathParams20, echo.PathParam{Name: "id", Value: fmt.Sprintf("%v", i)}) } var testCases = []struct { name string - givenPathParams []pathParam + givenPathParams echo.PathParams whenName string expectValues []string expectError string @@ -377,15 +370,16 @@ func TestValuesFromParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + c.SetRawPathParams(&tc.givenPathParams) } extractor := valuesFromParam(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourcePathParam, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -464,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) { extractor := valuesFromCookie(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceCookie, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -594,8 +589,9 @@ func TestValuesFromForm(t *testing.T) { extractor := valuesFromForm(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceForm, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/middleware/jwt.go b/middleware/jwt.go deleted file mode 100644 index bec5167e2..000000000 --- a/middleware/jwt.go +++ /dev/null @@ -1,300 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package middleware - -import ( - "errors" - "fmt" - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" - "net/http" - "reflect" -) - -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next - // middleware or handler. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. - ContinueOnIgnoredError bool - - // Signing key to validate token. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} - - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. - // Not used if custom ParseTokenFunc is set. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. - // If prefix is left empty the whole value is returned. - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiple sources example: - // - "header:Authorization,cookie:myowncookie" - TokenLookup string - - // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. - // This is one of the two options to provide a token extractor. - // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. - // You can also provide both if you want. - TokenLookupFuncs []ValuesExtractor - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the public key for a token validation. - // The function shall take care of verifying the signing algorithm and selecting the proper key. - // A user-defined KeyFunc can be useful if tokens are issued by an external party. - // Used by default ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither SigningKeys nor SigningKey is provided. - // Not used if custom ParseTokenFunc is set. - // Default to an internal implementation verifying the signing algorithm and selecting the proper key. - KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token - // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) - } - - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(c echo.Context) - - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(err error) error - - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(err error, c echo.Context) error -) - -// Algorithms -const ( - AlgorithmHS256 = "HS256" -) - -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") -) - -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - TokenLookupFuncs: nil, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, - } -) - -// JWT returns a JSON Web Token (JWT) auth middleware. -// -// For valid token, it sets the user in context and calls next handler. -// For invalid token, it returns "401 - Unauthorized" error. -// For missing token, it returns "400 - Bad Request" error. -// -// See: https://jwt.io/introduction -// See `JWTConfig.TokenLookup` -func JWT(key interface{}) echo.MiddlewareFunc { - c := DefaultJWTConfig - c.SigningKey = key - return JWTWithConfig(c) -} - -// JWTWithConfig returns a JWT auth middleware with config. -// See: `JWT()`. -func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultJWTConfig.Skipper - } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { - panic("echo: jwt middleware requires signing key") - } - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod - } - if config.ContextKey == "" { - config.ContextKey = DefaultJWTConfig.ContextKey - } - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } - if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 { - config.TokenLookup = DefaultJWTConfig.TokenLookup - } - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } - if config.KeyFunc == nil { - config.KeyFunc = config.defaultKeyFunc - } - if config.ParseTokenFunc == nil { - config.ParseTokenFunc = config.defaultParseToken - } - - extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) - if err != nil { - panic(err) - } - if len(config.TokenLookupFuncs) > 0 { - extractors = append(config.TokenLookupFuncs, extractors...) - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - if config.BeforeFunc != nil { - config.BeforeFunc(c) - } - - var lastExtractorErr error - var lastTokenErr error - for _, extractor := range extractors { - auths, err := extractor(c) - if err != nil { - lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth) - continue - } - for _, auth := range auths { - token, err := config.ParseTokenFunc(auth, c) - if err != nil { - lastTokenErr = err - continue - } - // Store user information from token into context. - c.Set(config.ContextKey, token) - if config.SuccessHandler != nil { - config.SuccessHandler(c) - } - return next(c) - } - } - // we are here only when we did not successfully extract or parse any of the tokens - err := lastTokenErr - if err == nil { // prioritize token errors over extracting errors - err = lastExtractorErr - } - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - tmpErr := config.ErrorHandlerWithContext(err, c) - if config.ContinueOnIgnoredError && tmpErr == nil { - return next(c) - } - return tmpErr - } - - // backwards compatible errors codes - if lastTokenErr != nil { - return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, - Internal: err, - } - } - return err // this is lastExtractorErr value - } - } -} - -func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { - token := new(jwt.Token) - var err error - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil -} - -// defaultKeyFunc returns a signing key of the given token. -func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil - } - } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil -} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go deleted file mode 100644 index eee9df966..000000000 --- a/middleware/jwt_test.go +++ /dev/null @@ -1,779 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package middleware - -import ( - "errors" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -// jwtCustomInfo defines some custom types we're going to use within our tokens. -type jwtCustomInfo struct { - Name string `json:"name"` - Admin bool `json:"admin"` -} - -// jwtCustomClaims are custom claims expanding default ones. -type jwtCustomClaims struct { - *jwt.StandardClaims - jwtCustomInfo -} - -func TestJWT(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - e.Use(JWT([]byte("secret"))) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) -} - -func TestJWTRace(t *testing.T) { - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss" - validKey := []byte("secret") - - h := JWTWithConfig(JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: validKey, - })(handler) - - makeReq := func(token string) echo.Context { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token) - c := e.NewContext(req, res) - assert.NoError(t, h(c)) - return c - } - - c := makeReq(initialToken) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(*jwtCustomClaims) - assert.Equal(t, claims.Name, "John Doe") - - makeReq(raceToken) - user = c.Get("user").(*jwt.Token) - claims = user.Claims.(*jwtCustomClaims) - // Initial context should still be "John Doe", not "Race Condition" - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) -} - -func TestJWTConfig(t *testing.T) { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - validKey := []byte("secret") - invalidKey := []byte("invalid-key") - validAuth := DefaultJWTConfig.AuthScheme + " " + token - - testCases := []struct { - name string - expPanic bool - expErrCode int // 0 for Success - config JWTConfig - reqURL string // "/" if empty - hdrAuth string - hdrCookie string // test.Request doesn't provide SetCookie(); use name=val - formValues map[string]string - }{ - { - name: "No signing key provided", - expPanic: true, - }, - { - name: "Unexpected signing method", - expErrCode: http.StatusBadRequest, - config: JWTConfig{ - SigningKey: validKey, - SigningMethod: "RS256", - }, - }, - { - name: "Invalid key", - expErrCode: http.StatusUnauthorized, - hdrAuth: validAuth, - config: JWTConfig{SigningKey: invalidKey}, - }, - { - name: "Valid JWT", - hdrAuth: validAuth, - config: JWTConfig{SigningKey: validKey}, - }, - { - name: "Valid JWT with custom AuthScheme", - hdrAuth: "Token" + " " + token, - config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, - }, - { - name: "Valid JWT with custom claims", - hdrAuth: validAuth, - config: JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: []byte("secret"), - }, - }, - { - name: "Invalid Authorization header", - hdrAuth: "invalid-auth", - expErrCode: http.StatusBadRequest, - config: JWTConfig{SigningKey: validKey}, - }, - { - name: "Empty header auth field", - config: JWTConfig{SigningKey: validKey}, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid query method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwt=" + token, - }, - { - name: "Invalid query param name", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwtxyz=" + token, - expErrCode: http.StatusBadRequest, - }, - { - name: "Invalid query param value", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwt=invalid-token", - expErrCode: http.StatusUnauthorized, - }, - { - name: "Empty query", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b", - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid param method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "param:jwt", - }, - reqURL: "/" + token, - }, - { - name: "Valid cookie method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - hdrCookie: "jwt=" + token, - }, - { - name: "Multiple jwt lookuop", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt,cookie:jwt", - }, - hdrCookie: "jwt=" + token, - }, - { - name: "Invalid token with cookie method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - expErrCode: http.StatusUnauthorized, - hdrCookie: "jwt=invalid", - }, - { - name: "Empty cookie", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid form method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - formValues: map[string]string{"jwt": token}, - }, - { - name: "Invalid token with form method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusUnauthorized, - formValues: map[string]string{"jwt": "invalid"}, - }, - { - name: "Empty form field", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid JWT with a valid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return validKey, nil - }, - }, - }, - { - name: "Valid JWT with an invalid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return invalidKey, nil - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Token verification does not pass using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return nil, errors.New("faulty KeyFunc") - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Valid JWT with lower case AuthScheme", - hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, - config: JWTConfig{SigningKey: validKey}, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - if tc.reqURL == "" { - tc.reqURL = "/" - } - - var req *http.Request - if len(tc.formValues) > 0 { - form := url.Values{} - for k, v := range tc.formValues { - form.Set(k, v) - } - req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) - req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") - req.ParseForm() - } else { - req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) - } - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header.Set(echo.HeaderCookie, tc.hdrCookie) - c := e.NewContext(req, res) - - if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) - } - - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.name) - return - } - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.name) - return - } - - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.name) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.name) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.name) - assert.Equal(t, claims.Admin, true, tc.name) - default: - panic("unexpected type of claims") - } - } - }) - } -} - -func TestJWTwithKID(t *testing.T) { - test := assert.New(t) - - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" - secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" - wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" - staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" - validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} - invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} - staticSecret := []byte("static_secret") - invalidStaticSecret := []byte("invalid_secret") - - for _, tc := range []struct { - expErrCode int // 0 for Success - config JWTConfig - hdrAuth string - info string - }{ - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "First token valid", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Second token valid", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Wrong key id token", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: staticSecret}, - info: "Valid static secret token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: invalidStaticSecret}, - info: "Invalid static secret", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys first token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys second token", - }, - } { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - c := e.NewContext(req, res) - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - test.Equal(tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if test.NoError(h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - test.Equal(claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - test.Equal(claims.Name, "John Doe", tc.info) - test.Equal(claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") - } - } - } -} - -func TestJWTConfig_skipper(t *testing.T) { - e := echo.New() - - e.Use(JWTWithConfig(JWTConfig{ - Skipper: func(context echo.Context) bool { - return true // skip everything - }, - SigningKey: []byte("secret"), - })) - - isCalled := false - e.GET("/", func(c echo.Context) error { - isCalled = true - return c.String(http.StatusTeapot, "test") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.True(t, isCalled) -} - -func TestJWTConfig_BeforeFunc(t *testing.T) { - e := echo.New() - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusTeapot, "test") - }) - - isCalled := false - e.Use(JWTWithConfig(JWTConfig{ - BeforeFunc: func(context echo.Context) { - isCalled = true - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.True(t, isCalled) -} - -func TestJWTConfig_extractorErrorHandling(t *testing.T) { - var testCases = []struct { - name string - given JWTConfig - expectStatusCode int - }{ - { - name: "ok, ErrorHandler is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusNotImplemented, "should not end up here") - }) - - e.Use(JWTWithConfig(tc.given)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectStatusCode, res.Code) - }) - } -} - -func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { - var testCases = []struct { - name string - given JWTConfig - expectErr string - }{ - { - name: "ok, ErrorHandler is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - //e.Debug = true - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusNotImplemented, "should not end up here") - }) - - config := tc.given - parseTokenCalled := false - config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { - parseTokenCalled = true - return nil, errors.New("parsing failed") - } - e.Use(JWTWithConfig(config)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.Equal(t, tc.expectErr, res.Body.String()) - assert.True(t, parseTokenCalled) - }) - } -} - -func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { - e := echo.New() - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusTeapot, "test") - }) - - // example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/golang-jwt/jwt` - // with current JWT middleware - signingKey := []byte("secret") - - config := JWTConfig{ - ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { - keyFunc := func(t *jwt.Token) (interface{}, error) { - if t.Method.Alg() != "HS256" { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - return signingKey, nil - } - - // claims are of type `jwt.MapClaims` when token is created with `jwt.Parse` - token, err := jwt.Parse(auth, keyFunc) - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil - }, - } - - e.Use(JWTWithConfig(config)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) -} - -func TestJWTConfig_TokenLookupFuncs(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - e.Use(JWTWithConfig(JWTConfig{ - TokenLookupFuncs: []ValuesExtractor{ - func(c echo.Context) ([]string, error) { - return []string{c.Request().Header.Get("X-API-Key")}, nil - }, - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) -} - -func TestJWTConfig_SuccessHandler(t *testing.T) { - var testCases = []struct { - name string - givenToken string - expectCalled bool - expectStatus int - }{ - { - name: "ok, success handler is called", - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectCalled: true, - expectStatus: http.StatusOK, - }, - { - name: "nok, success handler is not called", - givenToken: "x.x.x", - expectCalled: false, - expectStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - wasCalled := false - e.Use(JWTWithConfig(JWTConfig{ - SuccessHandler: func(c echo.Context) { - wasCalled = true - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectCalled, wasCalled) - assert.Equal(t, tc.expectStatus, res.Code) - }) - } -} - -func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { - var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenToken string - expectStatus int - expectBody string - }{ - { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectStatus: http.StatusTeapot, - expectBody: "", - }, - { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenToken: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", - }, - { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenToken: "", - expectStatus: http.StatusTeapot, - expectBody: "public-token", - }, - { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenToken: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) - }) - - e.Use(JWTWithConfig(JWTConfig{ - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, c echo.Context) error { - if err == ErrJWTMissing { - c.Set("test", "public-token") - return nil - } - return echo.ErrUnauthorized - }, - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenToken != "" { - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) - } - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) - }) - } -} diff --git a/middleware/key_auth.go b/middleware/key_auth.go index e8a6b0853..3964381ac 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,81 +2,69 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" + "fmt" + "github.com/labstack/echo/v5" "net/http" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // KeyLookup is a string in the form of ":" or ":,:" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. - // - "query:" - // - "form:" - // - "cookie:" - // Multiple sources example: - // - "header:Authorization,header:X-Api-Key" - KeyLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator - - // ErrorHandler defines a function which is executed for an invalid key. - // It may be used to define a custom error. - ErrorHandler KeyAuthErrorHandler - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandler to set a default public key auth value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. - ContinueOnIgnoredError bool - } +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(auth string, c echo.Context) (bool, error) +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator func(c echo.Context, key string, source ExtractorSource) (bool, error) - // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(err error, c echo.Context) error -) +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(c echo.Context, err error) error -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// ErrKeyMissing denotes an error raised when key value could not be extracted from request +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") -// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups -type ErrKeyAuthMissing struct { - Err error -} +// ErrInvalidKey denotes an error raised when key value is invalid by validator +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") -// Error returns errors text -func (e *ErrKeyAuthMissing) Error() string { - return e.Err.Error() -} - -// Unwrap unwraps error -func (e *ErrKeyAuthMissing) Unwrap() error { - return e.Err +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", } // KeyAuth returns an KeyAuth middleware. @@ -90,27 +78,33 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware with config. -// See `KeyAuth()`. +// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. +// +// For first valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration +func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } - // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultKeyAuthConfig.AuthScheme - } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - panic("echo: key-auth middleware requires a validator function") + return nil, errors.New("echo key-auth middleware requires a validator function") } - extractors, err := createExtractors(config.KeyLookup, config.AuthScheme) - if err != nil { - panic(err) + extractors, cErr := createExtractors(config.KeyLookup) + if cErr != nil { + return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr) + } + if len(extractors) == 0 { + return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -122,59 +116,41 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { - keys, err := extractor(c) - if err != nil { - lastExtractorErr = err + keys, source, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr continue } for _, key := range keys { - valid, err := config.Validator(key, c) + valid, err := config.Validator(c, key, source) if err != nil { lastValidatorErr = err continue } - if valid { - return next(c) + if !valid { + lastValidatorErr = ErrInvalidKey + continue } - lastValidatorErr = errors.New("invalid key") + return next(c) } } - // we are here only when we did not successfully extract and validate any of keys + // prioritize validator errors over extracting errors err := lastValidatorErr - if err == nil { // prioritize validator errors over extracting errors - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - err = errors.New("missing key in the query string") - } else if lastExtractorErr == errCookieExtractorValueMissing { - err = errors.New("missing key in cookies") - } else if lastExtractorErr == errFormExtractorValueMissing { - err = errors.New("missing key in the form") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - err = errors.New("missing key in request header") - } else if lastExtractorErr == errHeaderExtractorValueInvalid { - err = errors.New("invalid key in the request header") - } else { - err = lastExtractorErr - } - err = &ErrKeyAuthMissing{Err: err} + if err == nil { + err = lastExtractorErr } - if config.ErrorHandler != nil { - tmpErr := config.ErrorHandler(err, c) + tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - if lastValidatorErr != nil { // prioritize validator errors over extracting errors - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "Unauthorized", - Internal: lastValidatorErr, - } + if lastValidatorErr == nil { + return ErrKeyMissing.WithInternal(err) } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + return echo.ErrUnauthorized.WithInternal(err) } - } + }, nil } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index ff8968c38..fa182e6c3 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -7,11 +7,11 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func testKeyValidator(key string, c echo.Context) (bool, error) { +func testKeyValidator(c echo.Context, key string, source ExtractorSource) (bool, error) { switch key { case "valid-key": return true, nil @@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, internal=invalid key", + expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -84,24 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, - expectError: "code=400, message=invalid key in the request header", + expectError: "code=401, message=missing key, internal=invalid value in request header", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", - }, - { - name: "ok, custom key lookup from multiple places, query and header", - givenRequest: func(req *http.Request) { - req.URL.RawQuery = "key=invalid-key" - req.Header.Set("API-Key", "valid-key") - }, - whenConfig: func(conf *KeyAuthConfig) { - conf.KeyLookup = "query:key,header:API-Key" - }, - expectHandlerCalled: true, + expectError: "code=401, message=missing key, internal=missing value in request header", }, { name: "ok, custom key lookup, header", @@ -121,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", + expectError: "code=401, message=missing key, internal=missing value in request header", }, { name: "ok, custom key lookup, query", @@ -141,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the query string", + expectError: "code=401, message=missing key, internal=missing value in the query string", }, { name: "ok, custom key lookup, form", @@ -166,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the form", + expectError: "code=401, message=missing key, internal=missing value in the form", }, { name: "ok, custom key lookup, cookie", @@ -190,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies", + expectError: "code=401, message=missing key, internal=missing value in cookies", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, internal=missing key in request header", + expectError: "code=418, message=custom, internal=missing value in request header", }, { name: "nok, custom errorHandler, error from validator", @@ -211,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError @@ -229,6 +218,25 @@ func TestKeyAuthWithConfig(t *testing.T) { expectHandlerCalled: false, expectError: "code=401, message=Unauthorized, internal=some user defined error", }, + { + name: "ok, custom validator checks source", + givenRequest: func(req *http.Request) { + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + conf.Validator = func(c echo.Context, key string, source ExtractorSource) (bool, error) { + if source == ExtractorSourceQuery { + return true, nil + } + return false, errors.New("invalid source") + } + + }, + expectHandlerCalled: true, + }, } for _, tc := range testCases { @@ -269,108 +277,96 @@ func TestKeyAuthWithConfig(t *testing.T) { } } -func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { - assert.PanicsWithError( - t, - "extractor source for lookup could not be split into needed parts: a", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - KeyLookup: "a", - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { - assert.PanicsWithValue( - t, - "echo: key-auth middleware requires a validator function", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: nil, - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { +func TestKeyAuthWithConfig_errors(t *testing.T) { var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenKey string - expectStatus int - expectBody string + name string + whenConfig KeyAuthConfig + expectError string }{ { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenKey: "valid-key", - expectStatus: http.StatusTeapot, - expectBody: "", + name: "ok, no error", + whenConfig: KeyAuthConfig{ + Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, }, { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenKey: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", + name: "ok, missing validator func", + whenConfig: KeyAuthConfig{ + Validator: nil, + }, + expectError: "echo key-auth middleware requires a validator function", }, { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenKey: "", - expectStatus: http.StatusTeapot, - expectBody: "public-auth", + name: "ok, extractor source can not be split", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope", + Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", }, { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenKey: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", + name: "ok, no extractors", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope:nope", + Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create extractors from KeyLookup string", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := echo.New() + mw, err := tc.whenConfig.ToMiddleware() + if tc.expectError != "" { + assert.Nil(t, mw) + assert.EqualError(t, err, tc.expectError) + } else { + assert.NotNil(t, mw) + assert.NoError(t, err) + } + }) + } +} - e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) - }) +func TestMustKeyAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + KeyAuthWithConfig(KeyAuthConfig{}) + }) +} - e.Use(KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - ErrorHandler: func(err error, c echo.Context) error { - if _, ok := err.(*ErrKeyAuthMissing); ok { - c.Set("test", "public-auth") - return nil - } - return echo.ErrUnauthorized - }, - KeyLookup: "header:X-API-Key", - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - })) +func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { + handlerCalled := false + var authValue string + handler := func(c echo.Context) error { + handlerCalled = true + authValue = c.Get("auth").(string) + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(c echo.Context, err error) error { + // could check error to decide if we can swallow the error + c.Set("auth", "public") + return nil + }, + ContinueOnIgnoredError: true, + })(handler) - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenKey != "" { - req.Header.Set("X-API-Key", tc.givenKey) - } - res := httptest.NewRecorder() + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // no auth header this time + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - e.ServeHTTP(res, req) + err := middlewareChain(c) - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) - }) - } + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, "public", authValue) } diff --git a/middleware/logger.go b/middleware/logger.go index 9baac4769..6039003a5 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,88 +3,96 @@ package middleware import ( "bytes" "encoding/json" + "errors" + "fmt" "io" "strconv" "strings" "sync" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/color" + "github.com/labstack/echo/v5" "github.com/valyala/fasttemplate" ) -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` - - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` - - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) +// LoggerConfig defines the config for Logger middleware. +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) + // Tags to construct the logger format. + // + // - time_unix + // - time_unix_milli + // - time_unix_micro + // - time_unix_nano + // - time_rfc3339 + // - time_rfc3339_nano + // - time_custom + // - id (Request ID) + // - remote_ip + // - uri + // - host + // - method + // - path + // - route + // - protocol + // - referer + // - user_agent + // - status + // - error + // - latency (In nanoseconds) + // - latency_human (Human readable) + // - bytes_in (Bytes received) + // - bytes_out (Bytes sent) + // - header: + // - query: + // - form: + // - custom (see CustomTagFunc field) + // + // Example "${remote_ip} ${status}" + // + // Optional. Default value DefaultLoggerConfig.Format. + Format string + + // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. + CustomTimeFormat string + + // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. + // Make sure that outputted text creates valid JSON string with other logged tags. + // Optional. + CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) + + // Output is a writer where logs in JSON format are written. + // Optional. Default destination `echo.Logger.Infof()` + Output io.Writer + + template *fasttemplate.Template + pool *sync.Pool +} + +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","level":"INFO","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", +} // Logger returns a middleware that logs HTTP requests. func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } -// LoggerWithConfig returns a Logger middleware with config. -// See: `Logger()`. +// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration +func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultLoggerConfig.Skipper @@ -92,13 +100,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if config.Format == "" { config.Format = DefaultLoggerConfig.Format } - if config.Output == nil { - config.Output = DefaultLoggerConfig.Output - } config.template = fasttemplate.New(config.Format, "${", "}") - config.colorer = color.New() - config.colorer.SetOutput(config.Output) config.pool = &sync.Pool{ New: func() interface{} { return bytes.NewBuffer(make([]byte, 256)) @@ -106,34 +109,48 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() + start := time.Now() - if err = next(c); err != nil { + err := next(c) + if err != nil { + // When global error handler writes the error to the client the Response gets "committed". This state can be + // checked with `c.Response().Committed` field. c.Error(err) } stop := time.Now() + buf := config.pool.Get().(*bytes.Buffer) buf.Reset() defer config.pool.Put(buf) - if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { + _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { switch tag { + case "custom": + if config.CustomTagFunc == nil { + return 0, nil + } + return config.CustomTagFunc(c, buf) case "time_unix": - return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) + return buf.WriteString(strconv.FormatInt(stop.Unix(), 10)) + case "time_unix_milli": + return buf.WriteString(strconv.FormatInt(stop.UnixMilli(), 10)) + case "time_unix_micro": + return buf.WriteString(strconv.FormatInt(stop.UnixMicro(), 10)) case "time_unix_nano": - return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) + return buf.WriteString(strconv.FormatInt(stop.UnixNano(), 10)) case "time_rfc3339": - return buf.WriteString(time.Now().Format(time.RFC3339)) + return buf.WriteString(stop.Format(time.RFC3339)) case "time_rfc3339_nano": - return buf.WriteString(time.Now().Format(time.RFC3339Nano)) + return buf.WriteString(stop.Format(time.RFC3339Nano)) case "time_custom": - return buf.WriteString(time.Now().Format(config.CustomTimeFormat)) + return buf.WriteString(stop.Format(config.CustomTimeFormat)) case "id": id := req.Header.Get(echo.HeaderXRequestID) if id == "" { @@ -154,6 +171,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { p = "/" } return buf.WriteString(p) + case "route": + return buf.WriteString(c.Path()) case "protocol": return buf.WriteString(req.Proto) case "referer": @@ -161,17 +180,14 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "user_agent": return buf.WriteString(req.UserAgent()) case "status": - n := res.Status - s := config.colorer.Green(n) - switch { - case n >= 500: - s = config.colorer.Red(n) - case n >= 400: - s = config.colorer.Yellow(n) - case n >= 300: - s = config.colorer.Cyan(n) + status := res.Status + if err != nil { + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { + status = httpErr.Code + } } - return buf.WriteString(s) + return buf.WriteString(strconv.Itoa(status)) case "error": if err != nil { // Error may contain invalid JSON e.g. `"` @@ -201,23 +217,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case strings.HasPrefix(tag, "form:"): return buf.Write([]byte(c.FormValue(tag[5:]))) case strings.HasPrefix(tag, "cookie:"): - cookie, err := c.Cookie(tag[7:]) - if err == nil { + cookie, cookieErr := c.Cookie(tag[7:]) + if cookieErr == nil { return buf.Write([]byte(cookie.Value)) } } } return 0, nil - }); err != nil { - return + }) + if tmplErr != nil { + if err != nil { + return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err) + } + return fmt.Errorf("failed to create log from template: %w", tmplErr) } - if config.Output == nil { - _, err = c.Logger().Output().Write(buf.Bytes()) - return + if config.Output != nil { + if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil { + return lErr + } + } else { + if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil { + return lErr + } } - _, err = config.Output.Write(buf.Bytes()) - return + return err } - } + }, nil } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 394f62712..d311da15f 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -7,12 +7,13 @@ import ( "net/http" "net/http/httptest" "net/url" + "strconv" "strings" "testing" "time" "unsafe" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -61,7 +62,7 @@ func TestLoggerIPAddress(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} ip := "127.0.0.1" h := Logger()(func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -91,17 +92,17 @@ func TestLoggerTemplate(t *testing.T) { e.Use(LoggerWithConfig(LoggerConfig{ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", Output: buf, })) - e.GET("/", func(c echo.Context) error { + e.GET("/users/:id", func(c echo.Context) error { return c.String(http.StatusOK, "Header Logged") }) - req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil) + req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil) req.RequestURI = "/" req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") req.Header.Add("Referer", "google.com") @@ -126,7 +127,8 @@ func TestLoggerTemplate(t *testing.T) { "hexvalue": false, "GET": true, "127.0.0.1": true, - "\"path\":\"/\"": true, + "\"path\":\"/users/1\"": true, + "\"route\":\"/users/:id\"": true, "\"uri\":\"/\"": true, "\"status\":200": true, "\"bytes_in\":0": true, @@ -172,6 +174,52 @@ func TestLoggerCustomTimestamp(t *testing.T) { assert.Error(t, err) } +func TestLoggerTemplateWithTimeUnixMilli(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `${time_unix_milli}`, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + unixMillis, err := strconv.ParseInt(buf.String(), 10, 64) + assert.NoError(t, err) + assert.WithinDuration(t, time.Unix(unixMillis/1000, 0), time.Now(), 3*time.Second) +} + +func TestLoggerTemplateWithTimeUnixMicro(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `${time_unix_micro}`, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + unixMicros, err := strconv.ParseInt(buf.String(), 10, 64) + assert.NoError(t, err) + assert.WithinDuration(t, time.Unix(unixMicros/1000000, 0), time.Now(), 3*time.Second) +} + func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) { e := echo.New() @@ -244,3 +292,25 @@ func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { buf.Reset() } } + +func TestLoggerCustomTagFunc(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `{"method":"${method}",${custom}}` + "\n", + CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { + return buf.WriteString(`"tag":"my-value"`) + }, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "custom time stamp test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String()) +} diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2ed..202862f3b 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -3,31 +3,27 @@ package middleware import ( "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and @@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a MethodOverride middleware with config. -// See: `MethodOverride()`. +// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration +func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b1581..4ca10b84e 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -22,28 +22,68 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - m(h)(c) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_formParam(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + // Override with form parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec = httptest.NewRecorder() + m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - m(h)(c) + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_queryParam(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Override with query parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - m(h)(c) + m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_ignoreGet(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Ignore `GET` - req = httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index f250ca49a..0f99d6d6c 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -6,17 +6,14 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(c echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing the middleware. +type Skipper func(c echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(c echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -38,9 +35,9 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { rulesRegex := map[*regexp.Regexp]string{} for k, v := range rewrite { k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*?)", -1) + k = strings.ReplaceAll(k, `\*`, "(.*?)") if strings.HasPrefix(k, `\^`) { - k = strings.Replace(k, `\^`, "^", -1) + k = strings.ReplaceAll(k, `\^`, "^") } k = k + "$" rulesRegex[regexp.MustCompile(k)] = v @@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error func DefaultSkipper(echo.Context) bool { return false } + +func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} diff --git a/middleware/proxy.go b/middleware/proxy.go index 6cfd6731e..d1183d6f4 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "math/rand" @@ -12,104 +13,127 @@ import ( "regexp" "strings" "sync" - "sync/atomic" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer - - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string - - // RegexRewrite defines rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRewrite map[*regexp.Regexp]string - - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string - - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper - - // ModifyResponse defines function to modify response from ProxyTarget. - ModifyResponse func(*http.Response) error - } +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer + + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c echo.Context, err error) error + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string + + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper + + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta echo.Map +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) (*ProxyTarget, error) +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.RWMutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.Mutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - *commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - *commonBalancer - i uint32 - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + commonBalancer + // tracking the index on `targets` slice for the next `*ProxyTarget` to be used + i int +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", - } -) +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { +func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { - c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() out, err := net.Dial("tcp", t.URL.Host) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } defer out.Close() @@ -117,7 +141,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { // Write header err = r.Write(out) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL))) return } @@ -131,39 +155,44 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { go cp(in, out) err = <-errCh if err != nil && err != io.EOF { - c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL)) } }) } // NewRandomBalancer returns a random proxy balancer. func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer { - b := &randomBalancer{commonBalancer: new(commonBalancer)} + b := randomBalancer{} b.targets = targets - return b + b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + return &b } // NewRoundRobinBalancer returns a round-robin proxy balancer. func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer { - b := &roundRobinBalancer{commonBalancer: new(commonBalancer)} + b := roundRobinBalancer{} b.targets = targets - return b + return &b } -// AddTarget adds an upstream target to the list. +// AddTarget adds an upstream target to the list and returns `true`. +// +// However, if a target with the same name already exists then the operation is aborted returning `false`. func (b *commonBalancer) AddTarget(target *ProxyTarget) bool { + b.mutex.Lock() + defer b.mutex.Unlock() for _, t := range b.targets { if t.Name == target.Name { return false } } - b.mutex.Lock() - defer b.mutex.Unlock() b.targets = append(b.targets, target) return true } -// RemoveTarget removes an upstream target from the list. +// RemoveTarget removes an upstream target from the list by name. +// +// Returns `true` on success, `false` if no target with the name is found. func (b *commonBalancer) RemoveTarget(name string) bool { b.mutex.Lock() defer b.mutex.Unlock() @@ -177,21 +206,57 @@ func (b *commonBalancer) RemoveTarget(name string) bool { } // Next randomly returns an upstream target. -func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { - if b.random == nil { - b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) +// +// Note: `nil` is returned in case upstream target list is empty. +func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) { + b.mutex.Lock() + defer b.mutex.Unlock() + if len(b.targets) == 0 { + return nil, nil + } else if len(b.targets) == 1 { + return b.targets[0], nil } - b.mutex.RLock() - defer b.mutex.RUnlock() - return b.targets[b.random.Intn(len(b.targets))] + return b.targets[b.random.Intn(len(b.targets))], nil } -// Next returns an upstream target using round-robin technique. -func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { - b.i = b.i % uint32(len(b.targets)) - t := b.targets[b.i] - atomic.AddUint32(&b.i, 1) - return t +// Next returns an upstream target using round-robin technique. In the case +// where a previously failed request is being retried, the round-robin +// balancer will attempt to use the next target relative to the original +// request. If the list of targets held by the balancer is modified while a +// failed request is being retried, it is possible that the balancer will +// return the original failed target. +// +// Note: `nil` is returned in case upstream target list is empty. +func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) { + b.mutex.Lock() + defer b.mutex.Unlock() + if len(b.targets) == 0 { + return nil, nil + } else if len(b.targets) == 1 { + return b.targets[0], nil + } + + var i int + const lastIdxKey = "_round_robin_last_index" + // This request is a retry, start from the index of the previous + // target to ensure we don't attempt to retry the request with + // the same failed target + if c.Get(lastIdxKey) != nil { + i = c.Get(lastIdxKey).(int) + i++ + if i >= len(b.targets) { + i = 0 + } + } else { + // This is a first time request, use the global index + if b.i >= len(b.targets) { + b.i = 0 + } + i = b.i + b.i++ + } + c.Set(lastIdxKey, i) + return b.targets[i], nil } // Proxy returns a Proxy middleware. @@ -203,15 +268,36 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware with config. -// See: `Proxy()` +// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. +// +// Proxy middleware forwards the request to upstream server using a configured load balancing technique. func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration +func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } + if config.ContextKey == "" { + config.ContextKey = DefaultProxyConfig.ContextKey + } if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + return nil, errors.New("echo proxy middleware requires balancer") + } + if config.RetryFilter == nil { + config.RetryFilter = func(c echo.Context, e error) bool { + if httpErr, ok := e.(*echo.HTTPError); ok { + return httpErr.Code == http.StatusBadGateway + } + return false + } + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c echo.Context, err error) error { + return err + } } if config.Rewrite != nil { @@ -231,11 +317,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() - tgt := config.Balancer.Next(c) - c.Set(config.ContextKey, tgt) - if err := rewriteURL(config.RegexRewrite, req); err != nil { - return err + return config.ErrorHandler(c, err) } // Fix header @@ -251,21 +334,45 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } - // Proxy - switch { - case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) - case req.Header.Get(echo.HeaderAccept) == "text/event-stream": - default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) - } - if e, ok := c.Get("_error").(error); ok { - err = e + retries := config.RetryCount + for { + tgt, err := config.Balancer.Next(c) + if err != nil { + return config.ErrorHandler(c, err) + } + + c.Set(config.ContextKey, tgt) + + //If retrying a failed request, clear any previous errors from + //context here so that balancers have the option to check for + //errors that occurred using previous target + if retries < config.RetryCount { + c.Set("_error", nil) + } + + // Proxy + switch { + case c.IsWebSocket(): + proxyRaw(c, tgt).ServeHTTP(res, req) + case req.Header.Get(echo.HeaderAccept) == "text/event-stream": + default: + proxyHTTP(c, tgt, config).ServeHTTP(res, req) + } + + err, hasError := c.Get("_error").(error) + if !hasError { + return nil + } + + retry := retries > 0 && config.RetryFilter(c, err) + if !retry { + return config.ErrorHandler(c, err) + } + + retries-- } - - return } - } + }, nil } // StatusCodeContextCanceled is a custom HTTP status code for situations @@ -275,7 +382,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { +func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 7939fc5c2..c17328408 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -3,8 +3,9 @@ package middleware import ( "bytes" "context" + "errors" "fmt" - "io/ioutil" + "io" "net" "net/http" "net/http/httptest" @@ -14,11 +15,11 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -//Assert expected with url.EscapedPath method to obtain the path. +// Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -55,7 +56,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -77,7 +78,7 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -94,7 +95,7 @@ func TestProxy(t *testing.T) { e.Use(ProxyWithConfig(ProxyConfig{ Balancer: rrb, ModifyResponse: func(res *http.Response) error { - res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified"))) + res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified"))) res.Header.Set("X-Modified", "1") return nil }, @@ -113,15 +114,20 @@ func TestProxy(t *testing.T) { return nil } } - rrb1 := NewRoundRobinBalancer(targets) e = echo.New() e.Use(contextObserver) - e.Use(Proxy(rrb1)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } +func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { + assert.Panics(t, func() { + ProxyWithConfig(ProxyConfig{Balancer: nil}) + }) +} + func TestProxyRealIPHeader(t *testing.T) { // Setup upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -129,7 +135,7 @@ func TestProxyRealIPHeader(t *testing.T) { url, _ := url.Parse(upstream.URL) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -334,12 +340,11 @@ func TestProxyError(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() // Remote unreachable - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() req.URL.Path = "/api/users" e.ServeHTTP(rec, req) assert.Equal(t, "/api/users", req.URL.Path) @@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { rb := NewRandomBalancer(nil) assert.True(t, rb.AddTarget(target)) e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(req.Context()) @@ -375,3 +380,389 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { timeoutStop.Done() assert.Equal(t, 499, rec.Code) } + +type testProvider struct { + commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +} + +func TestRandomBalancerWithNoTargets(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Assert balancer with empty targets does return `nil` on `Next()` + rb := NewRandomBalancer(nil) + target, err := rb.Next(c) + assert.Nil(t, target) + assert.NoError(t, err) +} + +func TestRoundRobinBalancerWithNoTargets(t *testing.T) { + // Assert balancer with empty targets does return `nil` on `Next()` + rrb := NewRoundRobinBalancer([]*ProxyTarget{}) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + target, err := rrb.Next(c) + assert.Nil(t, target) + assert.NoError(t, err) +} + +func TestProxyRetries(t *testing.T) { + newServer := func(res int) (*url.URL, *httptest.Server) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(res) + }), + ) + targetURL, _ := url.Parse(server.URL) + return targetURL, server + } + + targetURL, server := newServer(http.StatusOK) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: targetURL, + } + + targetURL, server = newServer(http.StatusBadRequest) + defer server.Close() + goodTargetWith40X := &ProxyTarget{ + Name: "Good with 40X", + URL: targetURL, + } + + targetURL, _ = url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: targetURL, + } + + alwaysRetryFilter := func(c echo.Context, e error) bool { return true } + neverRetryFilter := func(c echo.Context, e error) bool { return false } + + testCases := []struct { + name string + retryCount int + retryFilters []func(c echo.Context, e error) bool + targets []*ProxyTarget + expectedResponse int + }{ + { + name: "retry count 0 does not attempt retry on fail", + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 1 does not attempt retry on success", + retryCount: 1, + targets: []*ProxyTarget{ + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does retry on handler return true", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does not retry on handler return false", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when no more retries left", + retryCount: 2, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as only 2 retries + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when retries left but handler returns false", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as retry handler returns false on 2nd check + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 3 succeeds", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "40x responses are not retried", + retryCount: 1, + targets: []*ProxyTarget{ + goodTargetWith40X, + goodTarget, + }, + expectedResponse: http.StatusBadRequest, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + retryFilterCall := 0 + retryFilter := func(c echo.Context, e error) bool { + if len(tc.retryFilters) == 0 { + assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) + } + + retryFilterCall++ + + nextRetryFilter := tc.retryFilters[0] + tc.retryFilters = tc.retryFilters[1:] + + return nextRetryFilter(c, e) + } + + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer(tc.targets), + RetryCount: tc.retryCount, + RetryFilter: retryFilter, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedResponse, rec.Code) + if len(tc.retryFilters) > 0 { + assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters))) + } + }) + } +} + +func TestProxyRetryWithBackendTimeout(t *testing.T) { + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.ResponseHeaderTimeout = time.Millisecond * 500 + + timeoutBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(404) + }), + ) + defer timeoutBackend.Close() + + timeoutTargetURL, _ := url.Parse(timeoutBackend.URL) + goodBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }), + ) + defer goodBackend.Close() + + goodTargetURL, _ := url.Parse(goodBackend.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Transport: transport, + Balancer: NewRoundRobinBalancer([]*ProxyTarget{ + { + Name: "Timeout", + URL: timeoutTargetURL, + }, + { + Name: "Good", + URL: goodTargetURL, + }, + }), + RetryCount: 1, + }, + )) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, 200, rec.Code) + }() + } + + wg.Wait() + +} + +func TestProxyErrorHandler(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + goodURL, _ := url.Parse(server.URL) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: goodURL, + } + + badURL, _ := url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: badURL, + } + + transformedError := errors.New("a new error") + + testCases := []struct { + name string + target *ProxyTarget + errorHandler func(c echo.Context, e error) error + expectFinalError func(t *testing.T, err error) + }{ + { + name: "Error handler not invoked when request success", + target: goodTarget, + errorHandler: func(c echo.Context, e error) error { + assert.FailNow(t, "error handler should not be invoked") + return e + }, + }, + { + name: "Error handler invoked when request fails", + target: badTarget, + errorHandler: func(c echo.Context, e error) error { + httpErr, ok := e.(*echo.HTTPError) + assert.True(t, ok, "expected http error to be passed to handler") + assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") + return transformedError + }, + expectFinalError: func(t *testing.T, err error) { + assert.Equal(t, transformedError, err, "transformed error not returned from proxy") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}), + ErrorHandler: tc.errorHandler, + }, + )) + + errorHandlerCalled := false + dheh := echo.DefaultHTTPErrorHandler(false) + e.HTTPErrorHandler = func(c echo.Context, err error) { + errorHandlerCalled = true + tc.expectFinalError(t, err) + dheh(c, err) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + if !errorHandlerCalled && tc.expectFinalError != nil { + t.Fatalf("error handler was not called") + } + + }) + } +} diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index be2b348db..5b30b6123 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -1,47 +1,42 @@ package middleware import ( + "errors" "net/http" "sync" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "golang.org/x/time/rate" ) -type ( - // RateLimiterStore is the interface to be implemented by custom stores. - RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method - Allow(identifier string) (bool, error) - } -) +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + Allow(identifier string) (bool, error) +} -type ( - // RateLimiterConfig defines the configuration for the rate limiter - RateLimiterConfig struct { - Skipper Skipper - BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor - IdentifierExtractor Extractor - // Store defines a store for the rate limiter - Store RateLimiterStore - // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error - // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error - } - // Extractor is used to extract data from echo.Context - Extractor func(context echo.Context) (string, error) -) +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(c echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(c echo.Context, identifier string, err error) error +} -// errors -var ( - // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded - ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") - // ErrExtractorError denotes an error raised when extractor function is unsuccessful - ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") -) +// Extractor is used to extract data from echo.Context +type Extractor func(c echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ @@ -50,14 +45,14 @@ var DefaultRateLimiterConfig = RateLimiterConfig{ id := ctx.RealIP() return id, nil }, - ErrorHandler: func(context echo.Context, err error) error { + ErrorHandler: func(c echo.Context, err error) error { return &echo.HTTPError{ Code: ErrExtractorError.Code, Message: ErrExtractorError.Message, Internal: err, } }, - DenyHandler: func(context echo.Context, identifier string, err error) error { + DenyHandler: func(c echo.Context, identifier string, err error) error { return &echo.HTTPError{ Code: ErrRateLimitExceeded.Code, Message: ErrRateLimitExceeded.Message, @@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration +func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } @@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { - panic("Store configuration must be provided") + return nil, errors.New("echo rate limiter store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -137,58 +137,57 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { identifier, err := config.IdentifierExtractor(c) if err != nil { - c.Error(config.ErrorHandler(c, err)) - return nil + return config.ErrorHandler(c, err) } - if allow, err := config.Store.Allow(identifier); !allow { - c.Error(config.DenyHandler(c, identifier, err)) - return nil + if allow, allowErr := config.Store.Allow(identifier); !allow { + return config.DenyHandler(c, identifier, allowErr) } return next(c) } - } + }, nil } -type ( - // RateLimiterMemoryStore is the built-in store implementation for RateLimiter - RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit //for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit + burst int + expiresIn time.Duration + lastCleanup time.Time - burst int - expiresIn time.Duration - lastCleanup time.Time - } - // Visitor signifies a unique user's limiter details - Visitor struct { - *rate.Limiter - lastSeen time.Time - } -) + timeNow func() time.Time +} + +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with -the provided rate (as req/s). The provided rate less than 1 will be treated as zero. +the provided rate (as req/s). for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst and ExpiresIn will be set to default values. +Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. + Example (with 20 requests/sec): limiterStore := middleware.NewRateLimiterMemoryStore(20) - */ -func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { +func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) { return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ - Rate: rate, + Rate: rateLimit, }) } /* NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore -with the provided configuration. Rate must be provided. Burst will be set to the value of +with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of the configured rate if not provided or set to 0. The build-in memory store is usually capable for modest loads. For higher loads other @@ -218,14 +217,15 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s store.burst = int(config.Rate) } store.visitors = make(map[string]*Visitor) - store.lastCleanup = now() + store.timeNow = time.Now + store.lastCleanup = store.timeNow() return } // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore type RateLimiterMemoryStoreConfig struct { - Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. - Burst int // Burst additionally allows a number of requests to pass when rate limit is reached + Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up } @@ -240,15 +240,16 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { limiter, exists := store.visitors[identifier] if !exists { limiter = new(Visitor) - limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst) store.visitors[identifier] = limiter } - limiter.lastSeen = now() - if now().Sub(store.lastCleanup) > store.expiresIn { + now := store.timeNow() + limiter.lastSeen = now + if now.Sub(store.lastCleanup) > store.expiresIn { store.cleanupStaleVisitors() } store.mutex.Unlock() - return limiter.AllowN(now(), 1), nil + return limiter.AllowN(store.timeNow(), 1), nil } /* @@ -257,14 +258,9 @@ of users who haven't visited again after the configured expiry time has elapsed */ func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { for id, visitor := range store.visitors { - if now().Sub(visitor.lastSeen) > store.expiresIn { + if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { delete(store.visitors, id) } } - store.lastCleanup = now() + store.lastCleanup = store.timeNow() } - -/* -actual time method which is mocked in test file -*/ -var now = time.Now diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 89d9a6edc..7a63fb262 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -2,7 +2,6 @@ package middleware import ( "errors" - "fmt" "math/rand" "net/http" "net/http/httptest" @@ -10,8 +9,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -25,19 +23,19 @@ func TestRateLimiter(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - mw := RateLimiter(inMemoryStore) + mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -47,20 +45,25 @@ func TestRateLimiter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - _ = mw(handler)(c) - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } -func TestRateLimiter_panicBehaviour(t *testing.T) { +func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { - RateLimiter(nil) + RateLimiterWithConfig(RateLimiterConfig{}) }) assert.NotPanics(t, func() { - RateLimiter(inMemoryStore) + RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) }) } @@ -73,7 +76,7 @@ func TestRateLimiterWithConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -88,7 +91,8 @@ func TestRateLimiterWithConfig(t *testing.T) { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { id string @@ -111,8 +115,9 @@ func TestRateLimiterWithConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) + err := mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } @@ -126,7 +131,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -135,19 +140,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return id, nil }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"", http.StatusForbidden}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -158,9 +164,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } @@ -174,21 +184,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -199,9 +210,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } } @@ -222,7 +237,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return true }, @@ -233,10 +248,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } @@ -256,7 +273,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return false }, @@ -267,7 +284,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) _ = mw(handler)(c) @@ -291,7 +309,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ BeforeFunc: func(c echo.Context) { beforeRan = true }, @@ -299,10 +317,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, true, beforeRan) } @@ -340,7 +360,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { for i, tc := range testCases { t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) - now = func() time.Time { + inMemoryStore.timeNow = func() time.Time { return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) } allowed, _ := inMemoryStore.Allow(tc.id) @@ -350,24 +370,22 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - now = time.Now - fmt.Println(now()) inMemoryStore.visitors = map[string]*Visitor{ "A": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now(), + lastSeen: time.Now(), }, "B": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-1 * time.Minute), + lastSeen: time.Now().Add(-1 * time.Minute), }, "C": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-5 * time.Minute), + lastSeen: time.Now().Add(-5 * time.Minute), }, "D": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-10 * time.Minute), + lastSeen: time.Now().Add(-10 * time.Minute), }, } @@ -391,7 +409,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { func TestNewRateLimiterMemoryStore(t *testing.T) { testCases := []struct { - rate rate.Limit + rate float64 burst int expiresIn time.Duration expectedExpiresIn time.Duration @@ -413,7 +431,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/recover.go b/middleware/recover.go index 7b6128533..7e46ccd7b 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,53 +5,35 @@ import ( "net/http" "runtime" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" ) -type ( - // LogErrorFunc defines a function for custom logging in the middleware. - LogErrorFunc func(c echo.Context, err error, stack []byte) error +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` - - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - - // LogErrorFunc defines a function for custom logging in the middleware. - // If it's set you don't need to provide LogLevel for config. - LogErrorFunc LogErrorFunc - } -) + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool +} -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. @@ -59,9 +41,13 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recover middleware with config. -// See: `Recover()`. +// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration +func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -71,7 +57,7 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } @@ -81,42 +67,19 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if r == http.ErrAbortHandler { panic(r) } - err, ok := r.(error) + tmpErr, ok := r.(error) if !ok { - err = fmt.Errorf("%v", r) + tmpErr = fmt.Errorf("%v", r) } - var stack []byte - var length int - if !config.DisablePrintStack { - stack = make([]byte, config.StackSize) - length = runtime.Stack(stack, !config.DisableStackAll) - stack = stack[:length] + stack := make([]byte, config.StackSize) + length := runtime.Stack(stack, !config.DisableStackAll) + tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length]) } - - if config.LogErrorFunc != nil { - err = config.LogErrorFunc(c, err, stack) - } else if !config.DisablePrintStack { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch config.LogLevel { - case log.DEBUG: - c.Logger().Debug(msg) - case log.INFO: - c.Logger().Info(msg) - case log.WARN: - c.Logger().Warn(msg) - case log.ERROR: - c.Logger().Error(msg) - case log.OFF: - // None. - default: - c.Logger().Print(msg) - } - } - c.Error(err) + err = tmpErr } }() return next(c) } - } + }, nil } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index b27f3b41c..f8d0db5e2 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,42 +2,63 @@ package middleware import ( "bytes" - "errors" - "fmt" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c echo.Context) error { panic("test") - })) - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") + }) + err := h(c) + assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Contains(t, buf.String(), "") // nothing is logged +} + +func TestRecover_skipper(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := RecoverConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + } + h := RecoverWithConfig(config)(func(c echo.Context) error { + panic("testPANIC") + }) + + var err error + assert.Panics(t, func() { + err = h(c) + }) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain } func TestRecoverErrAbortHandler(t *testing.T) { e := echo.New() - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c echo.Context) error { panic(http.ErrAbortHandler) - })) + }) defer func() { r := recover() if r == nil { @@ -51,115 +72,66 @@ func TestRecoverErrAbortHandler(t *testing.T) { } }() - h(c) + hErr := h(c) assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.NotContains(t, buf.String(), "PANIC RECOVER") + assert.NotContains(t, hErr.Error(), "PANIC RECOVER") } -func TestRecoverWithConfig_LogLevel(t *testing.T) { - tests := []struct { - logLevel log.Lvl - levelName string - }{{ - logLevel: log.DEBUG, - levelName: "DEBUG", - }, { - logLevel: log.INFO, - levelName: "INFO", - }, { - logLevel: log.WARN, - levelName: "WARN", - }, { - logLevel: log.ERROR, - levelName: "ERROR", - }, { - logLevel: log.OFF, - levelName: "OFF", - }} - - for _, tt := range tests { - tt := tt - t.Run(tt.levelName, func(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) +func TestRecoverWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenNoPanic bool + whenConfig RecoverConfig + expectErrContain string + expectErr string + }{ + { + name: "ok, default config", + whenConfig: DefaultRecoverConfig, + expectErrContain: "[PANIC RECOVER] testPANIC goroutine", + }, + { + name: "ok, no panic", + givenNoPanic: true, + whenConfig: DefaultRecoverConfig, + expectErrContain: "", + }, + { + name: "ok, DisablePrintStack", + whenConfig: RecoverConfig{ + DisablePrintStack: true, + }, + expectErr: "testPANIC", + }, + } - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := DefaultRecoverConfig - config.LogLevel = tt.logLevel - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) - - h(c) + config := tc.whenConfig + h := RecoverWithConfig(config)(func(c echo.Context) error { + if tc.givenNoPanic { + return nil + } + panic("testPANIC") + }) - assert.Equal(t, http.StatusInternalServerError, rec.Code) + err := h(c) - output := buf.String() - if tt.logLevel == log.OFF { - assert.Empty(t, output) + if tc.expectErrContain != "" { + assert.Contains(t, err.Error(), tc.expectErrContain) + } else if tc.expectErr != "" { + assert.Contains(t, err.Error(), tc.expectErr) } else { - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) + assert.NoError(t, err) } + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain }) } } - -func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - testError := errors.New("test") - config := DefaultRecoverConfig - config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) - if errors.Is(err, testError) { - c.Logger().Debug(msg) - } else { - c.Logger().Error(msg) - } - return err - } - - t.Run("first branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic(testError) - })) - - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"DEBUG"`) - }) - - t.Run("else branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("other") - })) - - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"ERROR"`) - }) -} diff --git a/middleware/redirect.go b/middleware/redirect.go index 13877db38..bda5ac204 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,10 +1,11 @@ package middleware import ( + "errors" "net/http" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RedirectConfig defines the config for Redirect middleware. @@ -14,7 +15,9 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` + Code int + + redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri @@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// DefaultRedirectConfig is the default Redirect middleware config. -var DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, -} +// RedirectHTTPSConfig is the HTTPS Redirect middleware config. +var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} + +// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. +var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} + +// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. +var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} + +// RedirectWWWConfig is the WWW Redirect middleware config. +var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} + +// RedirectNonWWWConfig is the non WWW Redirect middleware config. +var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(DefaultRedirectConfig) + return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } -// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSRedirect()`. +// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" { - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPS + return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. @@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } -// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSWWWRedirect()`. +// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" && !strings.HasPrefix(host, www) { - return true, "https://www." + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPSWWW + return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } -// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSNonWWWRedirect()`. +// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if scheme != "https" { - host = strings.TrimPrefix(host, www) - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectNonHTTPSWWW + return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. @@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(DefaultRedirectConfig) + return WWWRedirectWithConfig(RedirectWWWConfig) } -// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `WWWRedirect()`. +// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if !strings.HasPrefix(host, www) { - return true, scheme + "://www." + host + uri - } - return false, "" - }) + config.redirect = redirectWWW + return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. @@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(DefaultRedirectConfig) + return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } -// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `NonWWWRedirect()`. +// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if strings.HasPrefix(host, www) { - return true, scheme + "://" + host[4:] + uri - } - return false, "" - }) + config.redirect = redirectNonWWW + return toMiddlewareOrPanic(config) } -func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { +// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration +func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRedirectConfig.Skipper + config.Skipper = DefaultSkipper } if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code + config.Code = http.StatusMovedPermanently + } + if config.redirect == nil { + return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := cb(scheme, host, req.RequestURI); ok { + if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } + }, nil +} + +var redirectHTTPS = func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" +} + +var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri + } + return false, "" +} + +var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" +} + +var redirectWWW = func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" +} + +var redirectNonWWW = func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } + return false, "" } diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 9d1b56205..9484bdf20 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) diff --git a/middleware/request_id.go b/middleware/request_id.go index 8c5ff6605..b553321ec 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -1,50 +1,42 @@ package middleware import ( - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). - Generator func() string + // Generator defines a function to generate an ID. + // Optional. Default value random.String(32). + Generator func() string - // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(c echo.Context, requestID string) - // TargetHeader defines what header to look for to populate the id - TargetHeader string - } -) - -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - TargetHeader: echo.HeaderXRequestID, - } -) + // TargetHeader defines what header to look for to populate the id + TargetHeader string +} // RequestID returns a X-Request-ID middleware. func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(DefaultRequestIDConfig) + return RequestIDWithConfig(RequestIDConfig{}) } -// RequestIDWithConfig returns a X-Request-ID middleware with config. +// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration. func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration +func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRequestIDConfig.Skipper + config.Skipper = DefaultSkipper } if config.Generator == nil { - config.Generator = generator + config.Generator = createRandomStringGenerator(32) } if config.TargetHeader == "" { config.TargetHeader = echo.HeaderXRequestID @@ -69,9 +61,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { return next(c) } - } -} - -func generator() string { - return random.String(32) + }, nil } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 21b777826..fd0ef5d56 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) { return c.String(http.StatusOK, "test") } - rid := RequestIDWithConfig(RequestIDConfig{}) + rid := RequestID() + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) +} + +func TestMustRequestIDWithConfig_skipper(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + generatorCalled := false + e.Use(RequestIDWithConfig(RequestIDConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Generator: func() string { + generatorCalled = true + return "customGenerator" + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "test", res.Body.String()) + + assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") + assert.False(t, generatorCalled) +} + +func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") +} + +func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + called := false + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + RequestIDHandler: func(c echo.Context, s string) { + called = true + }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, called) +} + +func TestRequestIDWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid, err := RequestIDConfig{}.ToMiddleware() + assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator and handler - customID := "customGenerator" - calledHandler := false + // Custom generator rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return customID }, - RequestIDHandler: func(_ echo.Context, id string) { - calledHandler = true - assert.Equal(t, customID, id) - }, + Generator: func() string { return "customGenerator" }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") - assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 1b3e3eaad..46539e6a9 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -2,17 +2,23 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/http" "time" ) // Example for `fmt.Printf` // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogStatus: true, -// LogURI: true, +// LogStatus: true, +// LogURI: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// if v.Error == nil { +// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// } else { +// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error) +// } // return nil // }, // })) @@ -20,14 +26,23 @@ import ( // Example for Zerolog (https://github.com/rs/zerolog) // logger := zerolog.New(os.Stdout) // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// logger.Info(). -// Str("URI", v.URI). -// Int("status", v.Status). -// Msg("request") -// +// if v.Error == nil { +// logger.Info(). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request") +// } else { +// logger.Error(). +// Err(v.Error). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request error") +// } // return nil // }, // })) @@ -35,29 +50,47 @@ import ( // Example for Zap (https://github.com/uber-go/zap) // logger, _ := zap.NewProduction() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// logger.Info("request", -// zap.String("URI", v.URI), -// zap.Int("status", v.Status), -// ) -// +// if v.Error == nil { +// logger.Info("request", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// ) +// } else { +// logger.Error("request error", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// zap.Error(v.Error), +// ) +// } // return nil // }, // })) // // Example for Logrus (https://github.com/sirupsen/logrus) -// log := logrus.New() +// log := logrus.New() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, -// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { -// log.WithFields(logrus.Fields{ -// "URI": values.URI, -// "status": values.Status, -// }).Info("request") -// +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// }).Info("request") +// } else { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// "error": v.Error, +// }).Error("request error") +// } // return nil // }, // })) @@ -73,6 +106,13 @@ type RequestLoggerConfig struct { // Mandatory. LogValuesFunc func(c echo.Context, v RequestLoggerValues) error + // HandleError instructs logger to call global error handler when next middleware/handler returns an error. + // This is useful when you have custom error handler that can decide to use different status codes. + // + // A side-effect of calling global error handler is that now Response has been committed and sent to the client + // and middlewares up in chain can not change Response status code or response body. + HandleError bool + // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). LogLatency bool // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) @@ -158,15 +198,15 @@ type RequestLoggerValues struct { // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice - // of values is been logger for each given header. + // of values is what will be returned/logged for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter - // with same name so slice of values is been logger for each given query param name. + // with same name so slice of values is what will be returned/logged for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with - // same name so slice of values is been logger for each given form value name. + // same name so slice of values is what will be returned/logged for each given form value name. FormValues map[string][]string } @@ -184,7 +224,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } - now = time.Now + now := time.Now if config.timeNow != nil { now = config.timeNow } @@ -216,6 +256,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.BeforeNextFunc(c) } err := next(c) + if err != nil && config.HandleError { + // When global error handler writes the error to the client the Response gets "committed". This state can be + // checked with `c.Response().Committed` field. + c.Error(err) + } v := RequestLoggerValues{ StartTime: start, @@ -263,8 +308,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } if config.LogStatus { v.Status = res.Status - if err != nil { - if httpErr, ok := err.(*echo.HTTPError); ok { + if err != nil && !config.HandleError { + // this block should not be executed in case of HandleError=true as the global error handler will decide + // the status code. In that case status code could be different from what err contains. + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { v.Status = httpErr.Code } } @@ -307,7 +355,10 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { return errOnLog } - + // in case of HandleError=true we are returning the error that we already have handled with global error handler + // this is deliberate as this error could be useful for upstream middlewares and default global error handler + // will ignore that error when it bubbles up in middleware chain. + // Committed response can be checked in custom error handler with `c.Response().Committed` field return err } }, nil diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 5118b1216..3049d5923 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -1,7 +1,7 @@ package middleware import ( - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -103,12 +103,12 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) { func TestRequestLogger_logError(t *testing.T) { e := echo.New() - var expect RequestLoggerValues + var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogError: true, LogStatus: true, LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { - expect = values + actual = values return nil }, })) @@ -123,8 +123,52 @@ func TestRequestLogger_logError(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotAcceptable, rec.Code) - assert.Equal(t, http.StatusNotAcceptable, expect.Status) - assert.EqualError(t, expect.Error, "code=406, message=nope") + assert.Equal(t, http.StatusNotAcceptable, actual.Status) + assert.EqualError(t, actual.Error, "code=406, message=nope") +} + +func TestRequestLogger_HandleError(t *testing.T) { + e := echo.New() + + var actual RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + timeNow: func() time.Time { + return time.Unix(1631045377, 0).UTC() + }, + HandleError: true, + LogError: true, + LogStatus: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + actual = values + return nil + }, + })) + + // to see if "HandleError" works we create custom error handler that uses its own status codes + e.HTTPErrorHandler = func(c echo.Context, err error) { + if c.Response().Committed { + return + } + c.JSON(http.StatusTeapot, "custom error handler") + } + + e.GET("/test", func(c echo.Context) error { + return echo.NewHTTPError(http.StatusForbidden, "nope") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + + expect := RequestLoggerValues{ + StartTime: time.Unix(1631045377, 0).UTC(), + Status: http.StatusTeapot, + Error: echo.NewHTTPError(http.StatusForbidden, "nope"), + } + assert.Equal(t, expect, actual) } func TestRequestLogger_LogValuesFuncError(t *testing.T) { @@ -289,7 +333,7 @@ func TestRequestLogger_allFields(t *testing.T) { req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) c.SetPath("/test*") diff --git a/middleware/rewrite.go b/middleware/rewrite.go index e5b0a6b56..16677263f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,62 +1,58 @@ package middleware import ( + "errors" "regexp" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string - // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` - } -) - -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string +} // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := DefaultRewriteConfig + c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware with config. -// See: `Rewrite()`. +// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. +// +// Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - // Defaults - if config.Rules == nil && config.RegexRules == nil { - panic("echo: rewrite middleware requires url path rewrite rules or regex rules") - } + return toMiddlewareOrPanic(config) +} +// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration +func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Rules == nil && config.RegexRules == nil { + return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } if config.RegexRules == nil { @@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 0ac04bb2f..c4044dcc4 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -1,14 +1,14 @@ package middleware import ( - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" "regexp" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) { }, })) e.GET("/public/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) e.GET("/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) var testCases = []struct { @@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) { } } +func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { + assert.Panics(t, func() { + RewriteWithConfig(RewriteConfig{}) + }) +} + +func TestMustRewriteWithConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + givenSkipper func(c echo.Context) bool + whenURL string + expectURL string + expectStatus int + }{ + { + name: "not skipped", + whenURL: "/old", + expectURL: "/new", + expectStatus: http.StatusOK, + }, + { + name: "skipped", + givenSkipper: func(c echo.Context) bool { + return true + }, + whenURL: "/old", + expectURL: "/old", + expectStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig( + RewriteConfig{ + Skipper: tc.givenSkipper, + Rules: map[string]string{"/old": "/new"}}, + )) + + e.GET("/new", func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) + assert.Equal(t, tc.expectStatus, rec.Code) + }) + } +} + // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches - e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, - )) + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{"/old": "/new"}}), + ) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { + e.Add(http.MethodGet, "/new", func(c echo.Context) error { return c.NoContent(http.StatusOK) }) @@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ @@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { return c.String(http.StatusOK, "eng") }) @@ -142,7 +195,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() - bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) + bodyBytes, _ := io.ReadAll(rec.Result().Body) assert.Equal(t, "hosts", string(bodyBytes)) } } diff --git a/middleware/secure.go b/middleware/secure.go index 6c4051723..571b35877 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -3,87 +3,83 @@ package middleware import ( "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` - - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` - - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,