Skip to content

Commit

Permalink
Chaining functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
justinas committed May 25, 2014
1 parent f197ae7 commit 41455ca
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
22 changes: 22 additions & 0 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,25 @@ func New(constructors ...Constructor) Chain {

return c
}

// Chains the middleware and returns the final http.Handler
// New(m1, m2, m3).Then(h)
// is equivalent to:
// m1(m2(m3(h)))
// When the request comes in, it will be passed to m1, then m2, then m3
// and finally, the given handler
// (assuming every middleware calls the following one)
func (c Chain) Then(h http.Handler) http.Handler {
var final http.Handler
if h != nil {
final = h
} else {
final = http.DefaultServeMux
}

for i := len(c.constructors) - 1; i >= 0; i-- {
final = c.constructors[i](final)
}

return final
}
34 changes: 34 additions & 0 deletions chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,24 @@ package alice

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

// A constructor for middleware
// that writes its own "tag" into the RW and does nothing else.
// Useful in checking if a chain is behaving in the right order.
func tagMiddleware(tag string) Constructor {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tag))
h.ServeHTTP(w, r)
})
}
}

// Tests creating a new chain
func TestNew(t *testing.T) {
c1 := func(h http.Handler) http.Handler {
Expand All @@ -22,3 +35,24 @@ func TestNew(t *testing.T) {
assert.Equal(t, chain.constructors[0], slice[0])
assert.Equal(t, chain.constructors[1], slice[1])
}

func TestThen(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")
app := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("app\n"))
})

chained := New(t1, t2, t3).Then(app)

w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.ServeHTTP(w, r)

assert.Equal(t, w.Body.String(), "t1\nt2\nt3\napp\n")
}

0 comments on commit 41455ca

Please sign in to comment.