Skip to content

Commit

Permalink
Introduce limits middleware
Browse files Browse the repository at this point in the history
1. Replace original `maxrequestbody` directive.
2. Add request header limit.

fix issue caddyserver#1587

Signed-off-by: Tw <[email protected]>
  • Loading branch information
tw4452852 committed May 8, 2017
1 parent 90efff6 commit ae645ef
Show file tree
Hide file tree
Showing 11 changed files with 363 additions and 143 deletions.
2 changes: 1 addition & 1 deletion caddyhttp/caddyhttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ import (
_ "github.com/mholt/caddy/caddyhttp/header"
_ "github.com/mholt/caddy/caddyhttp/index"
_ "github.com/mholt/caddy/caddyhttp/internalsrv"
_ "github.com/mholt/caddy/caddyhttp/limits"
_ "github.com/mholt/caddy/caddyhttp/log"
_ "github.com/mholt/caddy/caddyhttp/markdown"
_ "github.com/mholt/caddy/caddyhttp/maxrequestbody"
_ "github.com/mholt/caddy/caddyhttp/mime"
_ "github.com/mholt/caddy/caddyhttp/pprof"
_ "github.com/mholt/caddy/caddyhttp/proxy"
Expand Down
2 changes: 1 addition & 1 deletion caddyhttp/httpserver/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ var directives = []string{
"root",
"index",
"bind",
"maxrequestbody", // TODO: 'limits'
"limits",
"timeouts",
"tls",

Expand Down
2 changes: 1 addition & 1 deletion caddyhttp/httpserver/replacer.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func (r *replacer) getSubstitution(key string) string {
}
_, err := ioutil.ReadAll(r.request.Body)
if err != nil {
if _, ok := err.(MaxBytesExceeded); ok {
if err == MaxBytesExceededErr {
return r.emptyValue
}
}
Expand Down
111 changes: 30 additions & 81 deletions caddyhttp/httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package httpserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -66,6 +66,7 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
sites: group,
connTimeout: GracefulTimeout,
}
s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
s.Server.Handler = s // this is weird, but whatever

// extract TLS settings from each site config to build
Expand Down Expand Up @@ -127,6 +128,32 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
return s, nil
}

// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
var min int64
for _, cfg := range group {
limit := cfg.Limits.MaxRequestHeaderSize
if limit == 0 {
continue
}

// not set yet
if min == 0 {
min = limit
}

// find a better one
if limit < min {
min = limit
}
}

if min > 0 {
s.MaxHeaderBytes = int(min)
}
return s
}

// makeHTTPServerWithTimeouts makes an http.Server from the group of
// configs in a way that configures timeouts (or, if not set, it uses
// the default timeouts) by combining the configuration of each
Expand Down Expand Up @@ -359,20 +386,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
}
}

// Apply the path-based request body size limit
// The error returned by MaxBytesReader is meant to be handled
// by whichever middleware/plugin that receives it when calling
// .Read() or a similar method on the request body
// TODO: Make this middleware instead?
if r.Body != nil {
for _, pathlimit := range vhost.MaxRequestBodySizes {
if Path(r.URL.Path).Matches(pathlimit.Path) {
r.Body = MaxBytesReader(w, r.Body, pathlimit.Limit)
break
}
}
}

return vhost.middlewareChain.ServeHTTP(w, r)
}

Expand Down Expand Up @@ -465,73 +478,9 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File()
}

// MaxBytesExceeded is the error type returned by MaxBytesReader
// MaxBytesExceeded is the error returned by MaxBytesReader
// when the request body exceeds the limit imposed
type MaxBytesExceeded struct{}

func (err MaxBytesExceeded) Error() string {
return "http: request body too large"
}

// MaxBytesReader and its associated methods are borrowed from the
// Go Standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error instead of a generic error message
// when the request body has exceeded the requested limit
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, n: n}
}

type maxBytesReader struct {
w http.ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}

func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)

if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}

n = int(l.n)
l.n = 0

// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = MaxBytesExceeded{}
return n, l.err
}

func (l *maxBytesReader) Close() error {
return l.r.Close()
}
var MaxBytesExceededErr = errors.New("http: request body too large")

// DefaultErrorFunc responds to an HTTP request with a simple description
// of the specified HTTP status code.
Expand Down
35 changes: 34 additions & 1 deletion caddyhttp/httpserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAddress(t *testing.T) {
}
}

func TestMakeHTTPServer(t *testing.T) {
func TestMakeHTTPServerWithTimeouts(t *testing.T) {
for i, tc := range []struct {
group []*SiteConfig
expected Timeouts
Expand Down Expand Up @@ -111,3 +111,36 @@ func TestMakeHTTPServer(t *testing.T) {
}
}
}

func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
for name, c := range map[string]struct {
group []*SiteConfig
expect int
}{
"disable": {
group: []*SiteConfig{{}},
expect: 0,
},
"oneSite": {
group: []*SiteConfig{{Limits: Limits{
MaxRequestHeaderSize: 100,
}}},
expect: 100,
},
"multiSites": {
group: []*SiteConfig{
{Limits: Limits{MaxRequestHeaderSize: 100}},
{Limits: Limits{MaxRequestHeaderSize: 50}},
},
expect: 50,
},
} {
c := c
t.Run(name, func(t *testing.T) {
actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
if got := actual.MaxHeaderBytes; got != c.expect {
t.Errorf("Expect %d, but got %d", c.expect, got)
}
})
}
}
10 changes: 8 additions & 2 deletions caddyhttp/httpserver/siteconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ type SiteConfig struct {
// for a request.
HiddenFiles []string

// Max amount of bytes a request can send on a given path
MaxRequestBodySizes []PathLimit
// Max request's header/body size
Limits Limits

// The path to the Caddyfile used to generate this site config
originCaddyfile string
Expand Down Expand Up @@ -71,6 +71,12 @@ type Timeouts struct {
IdleTimeoutSet bool
}

// Limits specify size limit of request's header and body.
type Limits struct {
MaxRequestHeaderSize int64
MaxRequestBodySizes []PathLimit
}

// PathLimit is a mapping from a site's path to its corresponding
// maximum request body size (in bytes)
type PathLimit struct {
Expand Down
90 changes: 90 additions & 0 deletions caddyhttp/limits/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package limits

import (
"io"
"net/http"

"github.com/mholt/caddy/caddyhttp/httpserver"
)

// Limit is a middleware to control request body size
type Limit struct {
Next httpserver.Handler
BodyLimits []httpserver.PathLimit
}

func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if r.Body == nil {
return l.Next.ServeHTTP(w, r)
}

// apply the path-based request body size limit.
for _, bl := range l.BodyLimits {
if httpserver.Path(r.URL.Path).Matches(bl.Path) {
r.Body = MaxBytesReader(w, r.Body, bl.Limit)
break
}
}

return l.Next.ServeHTTP(w, r)
}

// MaxBytesReader and its associated methods are borrowed from the
// Go Standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error instead of a generic error message
// when the request body has exceeded the requested limit
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, n: n}
}

type maxBytesReader struct {
w http.ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}

func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)

if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}

n = int(l.n)
l.n = 0

// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = httpserver.MaxBytesExceededErr
return n, l.err
}

func (l *maxBytesReader) Close() error {
return l.r.Close()
}
35 changes: 35 additions & 0 deletions caddyhttp/limits/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package limits

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/mholt/caddy/caddyhttp/httpserver"
)

func TestBodySizeLimit(t *testing.T) {
var (
gotContent []byte
gotError error
expectContent = "hello"
)
l := Limit{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
gotContent, gotError = ioutil.ReadAll(r.Body)
return 0, nil
}),
BodyLimits: []httpserver.PathLimit{{Path: "/", Limit: int64(len(expectContent))}},
}

r := httptest.NewRequest("GET", "/", strings.NewReader(expectContent+expectContent))
l.ServeHTTP(httptest.NewRecorder(), r)
if got := string(gotContent); got != expectContent {
t.Errorf("expected content[%s], got[%s]", expectContent, got)
}
if gotError != httpserver.MaxBytesExceededErr {
t.Errorf("expect error %v, got %v", httpserver.MaxBytesExceededErr, gotError)
}
}
Loading

0 comments on commit ae645ef

Please sign in to comment.