Skip to content

Commit

Permalink
middleware: minor improvements to Compress mw handler
Browse files Browse the repository at this point in the history
  • Loading branch information
pkieltyka committed Jan 9, 2019
1 parent bb7ee27 commit 243bfce
Showing 1 changed file with 46 additions and 38 deletions.
84 changes: 46 additions & 38 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ import (
"io"
"net"
"net/http"
"regexp"
"sort"
"strings"
)

var encoders = map[string]EncoderFunc{}

var acceptEncodingAlgorithmsRe = regexp.MustCompile(`([a-z]{2,}|\*)`)
var encodingPrecedence = []string{"br", "gzip", "deflate"}

func init() {
// TODO:
// lzma: Opera.
// sdch: Chrome, Android. Gzip output + dictionary header.
// br: Brotli.
// br: Brotli, see https://github.com/go-chi/chi/pull/326

// TODO: Exception for old MSIE browsers that can't handle non-HTML?
// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
Expand Down Expand Up @@ -73,13 +71,25 @@ type EncoderFunc func(w http.ResponseWriter, level int) io.Writer
// return brotli_enc.NewBrotliWriter(params, w)
// })
func SetEncoder(encoding string, fn EncoderFunc) {
encoding = strings.ToLower(encoding)
if encoding == "" {
panic("the encoding can not be empty")
}
if fn == nil {
panic("attempted to set a nil encoder function")
}
encoders[encoding] = fn

var e string
for _, v := range encodingPrecedence {
if v == encoding {
e = v
}
}

if e == "" {
encodingPrecedence = append([]string{e}, encodingPrecedence...)
}
}

var defaultContentTypes = map[string]struct{}{
Expand Down Expand Up @@ -107,6 +117,11 @@ func DefaultCompress(next http.Handler) http.Handler {
// body of a given content types to a data format based
// on Accept-Encoding request header. It uses a given
// compression level.
//
// NOTE: make sure to set the Content-Type header on your response
// otherwise this middleware will not compress the response body. For ex, in
// your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody))
// or set it manually.
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
contentTypes := defaultContentTypes
if len(types) > 0 {
Expand All @@ -119,17 +134,18 @@ func Compress(level int, types ...string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
encoder, encoding := selectEncoder(r.Header)
mcw := &maybeCompressResponseWriter{

cw := &compressResponseWriter{
ResponseWriter: w,
w: w,
contentTypes: contentTypes,
encoder: encoder,
encoding: encoding,
level: level,
}
defer mcw.Close()
defer cw.Close()

next.ServeHTTP(mcw, r)
next.ServeHTTP(cw, r)
}

return http.HandlerFunc(fn)
Expand All @@ -140,38 +156,29 @@ func selectEncoder(h http.Header) (EncoderFunc, string) {
header := h.Get("Accept-Encoding")

// Parse the names of all accepted algorithms from the header.
var accepted []string
for _, m := range acceptEncodingAlgorithmsRe.FindAllStringSubmatch(header, -1) {
accepted = append(accepted, m[1])
}

sort.Sort(byPerformance(accepted))
accepted := strings.Split(strings.ToLower(header), ",")

// Select the first mutually supported algorithm.
for _, name := range accepted {
if fn, ok := encoders[name]; ok {
// Find supported encoder by accepted list by precedence
for _, name := range encodingPrecedence {
if fn, ok := encoders[name]; ok && matchAcceptEncoding(accepted, name) {
return fn, name
}
}

// No encoder found to match the accepted encoding
return nil, ""
}

type byPerformance []string

func (l byPerformance) Len() int { return len(l) }
func (l byPerformance) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func (l byPerformance) Less(i, j int) bool {
// Higher number = higher preference. This causes unknown names, which map
// to 0, to always be less prefered.
scores := map[string]int{
"br": 3,
"gzip": 2,
"deflate": 1,
func matchAcceptEncoding(accepted []string, encoding string) bool {
for _, v := range accepted {
if strings.Index(v, encoding) >= 0 {
return true
}
}
return scores[l[i]] > scores[l[j]]
return false
}

type maybeCompressResponseWriter struct {
type compressResponseWriter struct {
http.ResponseWriter
w io.Writer
encoder EncoderFunc
Expand All @@ -181,21 +188,21 @@ type maybeCompressResponseWriter struct {
wroteHeader bool
}

func (w *maybeCompressResponseWriter) WriteHeader(code int) {
func (w *compressResponseWriter) WriteHeader(code int) {
if w.wroteHeader {
return
}
w.wroteHeader = true
defer w.ResponseWriter.WriteHeader(code)
defer w.WriteHeader(code)

// Already compressed data?
if w.ResponseWriter.Header().Get("Content-Encoding") != "" {
if w.Header().Get("Content-Encoding") != "" {
return
}

// Parse the first part of the Content-Type response header.
contentType := ""
parts := strings.Split(w.ResponseWriter.Header().Get("Content-Type"), ";")
parts := strings.Split(w.Header().Get("Content-Type"), ";")
if len(parts) > 0 {
contentType = parts[0]
}
Expand All @@ -209,41 +216,42 @@ func (w *maybeCompressResponseWriter) WriteHeader(code int) {
if wr := w.encoder(w.ResponseWriter, w.level); wr != nil {
w.w = wr
w.Header().Set("Content-Encoding", w.encoding)

// The content-length after compression is unknown
w.Header().Del("Content-Length")
}
}
}

func (w *maybeCompressResponseWriter) Write(p []byte) (int, error) {
func (w *compressResponseWriter) Write(p []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}

return w.w.Write(p)
}

func (w *maybeCompressResponseWriter) Flush() {
func (w *compressResponseWriter) Flush() {
if f, ok := w.w.(http.Flusher); ok {
f.Flush()
}
}

func (w *maybeCompressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := w.w.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer")
}

func (w *maybeCompressResponseWriter) Push(target string, opts *http.PushOptions) error {
func (w *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
if ps, ok := w.w.(http.Pusher); ok {
return ps.Push(target, opts)
}
return errors.New("chi/middleware: http.Pusher is unavailable on the writer")
}

func (w *maybeCompressResponseWriter) Close() error {
func (w *compressResponseWriter) Close() error {
if c, ok := w.w.(io.WriteCloser); ok {
return c.Close()
}
Expand Down

0 comments on commit 243bfce

Please sign in to comment.