Skip to content

Commit

Permalink
Guess the scheme if r.URL.Scheme is unset (gorilla#474)
Browse files Browse the repository at this point in the history
* Guess the scheme if r.URL.Scheme is unset
It's not expected that the request's URL is fully populated when used on
the server-side (it's more of a client-side field), so we shouldn't
expect it to be present.

In practice, it's only rarely set at all on the server, making mux's
`Schemes` matcher tricky to use as it is.

This commit adds a test which would have failed before demonstrating the
problem, as well as a fix which I think makes `.Schemes` match what
users expect.

* [doc] Add more detail to Schemes and URL godocs

* Add route url test for schemes

* Make httpserver test use more specific scheme matchers

* Update test to have different responses per route
  • Loading branch information
euank authored and elithrar committed Oct 18, 2019
1 parent 884b5ff commit ff4e71f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 19 deletions.
49 changes: 49 additions & 0 deletions mux_httpserver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// +build go1.9

package mux

import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)

func TestSchemeMatchers(t *testing.T) {
router := NewRouter()
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello http world"))
}).Schemes("http")
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello https world"))
}).Schemes("https")

assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
resp, err := s.Client().Get(s.URL)
if err != nil {
t.Fatalf("unexpected error getting from server: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected error reading body: %v", err)
}
if !bytes.Equal(body, []byte(expectedBody)) {
t.Fatalf("response should be hello world, was: %q", string(body))
}
}

t.Run("httpServer", func(t *testing.T) {
s := httptest.NewServer(router)
defer s.Close()
assertResponseBody(t, s, "hello http world")
})
t.Run("httpsServer", func(t *testing.T) {
s := httptest.NewTLSServer(router)
defer s.Close()
assertResponseBody(t, s, "hello https world")
})
}
6 changes: 2 additions & 4 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
Expand Down Expand Up @@ -2895,10 +2896,7 @@ func newRequestWithHeaders(method, url string, headers ...string) *http.Request

// newRequestHost a new request with a method, url, and host header
func newRequestHost(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req := httptest.NewRequest(method, url, nil)
req.Host = host
return req
}
17 changes: 5 additions & 12 deletions old_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ var urlBuildingTests = []urlBuildingTest{
vars: []string{"subdomain", "foo", "category", "technology", "id", "42"},
url: "http://foo.domain.com/articles/technology/42",
},
{
route: new(Route).Host("example.com").Schemes("https", "http"),
vars: []string{},
url: "https://example.com",
},
}

func TestHeaderMatcher(t *testing.T) {
Expand Down Expand Up @@ -502,18 +507,6 @@ func TestUrlBuilding(t *testing.T) {
url := u.String()
if url != v.url {
t.Errorf("expected %v, got %v", v.url, url)
/*
reversePath := ""
reverseHost := ""
if v.route.pathTemplate != nil {
reversePath = v.route.pathTemplate.Reverse
}
if v.route.hostTemplate != nil {
reverseHost = v.route.hostTemplate.Reverse
}
t.Errorf("%#v:\nexpected: %q\ngot: %q\nreverse path: %q\nreverse host: %q", v.route, v.url, url, reversePath, reverseHost)
*/
}
}

Expand Down
32 changes: 29 additions & 3 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,30 @@ func (r *Route) Queries(pairs ...string) *Route {
type schemeMatcher []string

func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.URL.Scheme)
scheme := r.URL.Scheme
// https://golang.org/pkg/net/http/#Request
// "For [most] server requests, fields other than Path and RawQuery will be
// empty."
// Since we're an http muxer, the scheme is either going to be http or https
// though, so we can just set it based on the tls termination state.
if scheme == "" {
if r.TLS == nil {
scheme = "http"
} else {
scheme = "https"
}
}
return matchInArray(m, scheme)
}

// Schemes adds a matcher for URL schemes.
// It accepts a sequence of schemes to be matched, e.g.: "http", "https".
// If the request's URL has a scheme set, it will be matched against.
// Generally, the URL scheme will only be set if a previous handler set it,
// such as the ProxyHeaders handler from gorilla/handlers.
// If unset, the scheme will be determined based on the request's TLS
// termination state.
// The first argument to Schemes will be used when constructing a route URL.
func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes {
schemes[k] = strings.ToLower(v)
Expand Down Expand Up @@ -493,15 +512,22 @@ func (r *Route) Subrouter() *Router {
// This also works for host variables:
//
// r := mux.NewRouter()
// r.Host("{subdomain}.domain.com").
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// Host("{subdomain}.domain.com").
// Name("article")
//
// // url.String() will be "http://news.domain.com/articles/technology/42"
// url, err := r.Get("article").URL("subdomain", "news",
// "category", "technology",
// "id", "42")
//
// The scheme of the resulting url will be the first argument that was passed to Schemes:
//
// // url.String() will be "https://example.com"
// r := mux.NewRouter()
// url, err := r.Host("example.com")
// .Schemes("https", "http").URL()
//
// All variables defined in the route are required, and their values must
// conform to the corresponding patterns.
func (r *Route) URL(pairs ...string) (*url.URL, error) {
Expand Down

0 comments on commit ff4e71f

Please sign in to comment.