Skip to content

Commit

Permalink
Convert Filter type to a func
Browse files Browse the repository at this point in the history
- This halves the stack trace -- every interface call adds 2 stack
frames.  Since the filters are all nested, it got pretty overwhelming.
- Reduces the boilerplate for all the built-in filters (net -50 lines).
 I think that package level variables are always fine (in fact
preferred) for the filter, so the struct type gives no benefit.
- Reduces the ways to get into OnAppStart back to 1.

Downsides:
- Now the filters do not have a name that can be referenced in error
messages , etc.
- It is common for a filter to have some initialization, so it's a
little too bad we can't syntactically sugar that a little more than the
usual OnAppStart declaration.

Also, remove the FilterChain type as it was not adding anything.
  • Loading branch information
robfig committed May 25, 2013
1 parent 0fc5313 commit b916640
Show file tree
Hide file tree
Showing 20 changed files with 78 additions and 135 deletions.
19 changes: 3 additions & 16 deletions filter.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
package revel

type FilterChain []Filter

type Filter interface {
Call(c *Controller, chain FilterChain)
}

type InitializingFilter interface {
Filter
OnAppStart()
}
type Filter func(c *Controller, filterChain []Filter)

// Filters is the default set of global filters.
// It may be set by the application on initialization.
Expand All @@ -28,10 +19,6 @@ var Filters = []Filter{

// NilFilter and NilChain are helpful in writing filter tests.
var (
NilFilter nilFilter
NilChain = FilterChain{NilFilter}
NilFilter = func(_ *Controller, _ []Filter) {}
NilChain = []Filter{NilFilter}
)

type nilFilter struct{}

func (f nilFilter) Call(_ *Controller, _ FilterChain) {}
41 changes: 18 additions & 23 deletions filterconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package revel

import "reflect"

// Map from "Controller" or "Controller.Method" to FilterChain
var filterOverrides = make(map[string]FilterChain)
// Map from "Controller" or "Controller.Method" to the Filter chain
var filterOverrides = make(map[string][]Filter)

// FilterConfigurator allows the developer configure the filter chain on a
// per-controller or per-action basis. The filter configuration is applied by
Expand Down Expand Up @@ -96,17 +96,14 @@ func (conf FilterConfigurator) Add(f Filter) FilterConfigurator {

// Remove a filter from the filter chain.
func (conf FilterConfigurator) Remove(target Filter) FilterConfigurator {
var (
targetType = reflect.TypeOf(target)
filters = conf.getOverrideFilters()
)
filters := conf.getOverrideFilters()
for i, f := range filters {
if reflect.TypeOf(f) == targetType {
if FilterEq(f, target) {
filterOverrides[conf.key] = append(filters[:i], filters[i+1:]...)
return conf
}
}
panic("Did not find target filter: " + targetType.Name())
panic("Did not find target filter to remove")
}

// Insert a filter into the filter chain before or after another.
Expand All @@ -118,17 +115,14 @@ func (conf FilterConfigurator) Insert(insert Filter, where When, target Filter)
if where != BEFORE && where != AFTER {
panic("where must be BEFORE or AFTER")
}
var (
targetType = reflect.TypeOf(target)
filters = conf.getOverrideFilters()
)
filters := conf.getOverrideFilters()
for i, f := range filters {
if reflect.TypeOf(f) == targetType {
if FilterEq(f, target) {
filterOverrides[conf.key] = append(filters[:i], append([]Filter{insert}, filters[i:]...)...)
return conf
}
}
panic("Did not find target filter: " + targetType.Name())
panic("Did not find target filter for insert")
}

// getOverrideFilters returns the filter chain that applies to the given
Expand All @@ -145,7 +139,7 @@ func (conf FilterConfigurator) getOverrideFilters() []Filter {
if !ok {
// The override starts with all filters after FilterConfiguringFilter
for i, f := range Filters {
if f == FilterConfiguringFilter {
if FilterEq(f, FilterConfiguringFilter) {
filters = make([]Filter, len(Filters)-i-1)
copy(filters, Filters[i+1:])
break
Expand All @@ -159,22 +153,23 @@ func (conf FilterConfigurator) getOverrideFilters() []Filter {
return filters
}

// FilterEq returns true if the two filters reference the same filter.
func FilterEq(a, b Filter) bool {
return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
}

// FilterConfiguringFilter is a filter stage that customizes the remaining
// filter chain for the action being invoked.
var FilterConfiguringFilter filterConfiguringFilter

type filterConfiguringFilter struct{}

func (f filterConfiguringFilter) Call(c *Controller, fc FilterChain) {
var FilterConfiguringFilter = func(c *Controller, fc []Filter) {
if newChain, ok := filterOverrides[c.Name+"."+c.Action]; ok {
newChain[0].Call(c, newChain[1:])
newChain[0](c, newChain[1:])
return
}

if newChain, ok := filterOverrides[c.Name]; ok {
newChain[0].Call(c, newChain[1:])
newChain[0](c, newChain[1:])
return
}

fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])
}
18 changes: 12 additions & 6 deletions filterconfig_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package revel

import (
"reflect"
"testing"
)
import "testing"

type FakeController struct{}

Expand Down Expand Up @@ -57,7 +54,7 @@ func TestFilterConfiguratorOps(t *testing.T) {
ActionInvoker,
}
actual := conf.getOverrideFilters()
if len(actual) != len(expected) || !reflect.DeepEqual(actual, expected) {
if len(actual) != len(expected) || !filterSliceEqual(actual, expected) {
t.Errorf("getOverrideFilter failed.\nActual: %#v\nExpect: %#v", actual, expected)
}

Expand All @@ -72,7 +69,16 @@ func TestFilterConfiguratorOps(t *testing.T) {
ActionInvoker,
}
actual = filterOverrides[conf.key]
if len(actual) != len(expected) || !reflect.DeepEqual(actual, expected) {
if len(actual) != len(expected) || !filterSliceEqual(actual, expected) {
t.Errorf("Ops failed.\nActual: %#v\nExpect: %#v", actual, expected)
}
}

func filterSliceEqual(a, e []Filter) bool {
for i, f := range a {
if !FilterEq(f, e[i]) {
return false
}
}
return true
}
8 changes: 2 additions & 6 deletions flash.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@ func (f Flash) Success(msg string, args ...interface{}) {
}
}

var FlashFilter flashFilter

type flashFilter struct{}

func (p flashFilter) Call(c *Controller, fc FilterChain) {
var FlashFilter = func(c *Controller, fc []Filter) {
c.Flash = restoreFlash(c.Request.Request)
c.RenderArgs["flash"] = c.Flash.Data

fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])

// Store the flash.
var flashValue string
Expand Down
14 changes: 6 additions & 8 deletions i18n.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,13 @@ func parseLocaleFromFileName(file string) string {
return strings.ToLower(extension)
}

var I18nFilter i18nFilter

type i18nFilter struct{}

func (p i18nFilter) OnAppStart() {
loadMessages(filepath.Join(BasePath, messageFilesDirectory))
func init() {
OnAppStart(func() {
loadMessages(filepath.Join(BasePath, messageFilesDirectory))
})
}

func (p i18nFilter) Call(c *Controller, fc FilterChain) {
var I18nFilter = func(c *Controller, fc []Filter) {
if foundCookie, cookieValue := hasLocaleCookie(c.Request); foundCookie {
TRACE.Printf("Found locale cookie value: %s", cookieValue)
setCurrentLocaleControllerArguments(c, cookieValue)
Expand All @@ -155,7 +153,7 @@ func (p i18nFilter) Call(c *Controller, fc FilterChain) {
TRACE.Println("Unable to find locale in cookie or header, using empty string")
setCurrentLocaleControllerArguments(c, "")
}
fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])
}

// Set the current locale controller argument (CurrentLocaleControllerArg) with the given locale.
Expand Down
6 changes: 3 additions & 3 deletions i18n_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,17 @@ func TestBeforeRequest(t *testing.T) {
loadTestI18nConfig(t)

c := NewController(buildEmptyRequest(), nil)
if I18nFilter.Call(c, NilChain); c.Request.Locale != "" {
if I18nFilter(c, NilChain); c.Request.Locale != "" {
t.Errorf("Expected to find current language '%s' in controller, found '%s' instead", "", c.Request.Locale)
}

c = NewController(buildRequestWithCookie("APP_LANG", "en-US"), nil)
if I18nFilter.Call(c, NilChain); c.Request.Locale != "en-US" {
if I18nFilter(c, NilChain); c.Request.Locale != "en-US" {
t.Errorf("Expected to find current language '%s' in controller, found '%s' instead", "en-US", c.Request.Locale)
}

c = NewController(buildRequestWithAcceptLanguages("en-GB", "en-US"), nil)
if I18nFilter.Call(c, NilChain); c.Request.Locale != "en-GB" {
if I18nFilter(c, NilChain); c.Request.Locale != "en-GB" {
t.Errorf("Expected to find current language '%s' in controller, found '%s' instead", "en-GB", c.Request.Locale)
}
}
Expand Down
8 changes: 2 additions & 6 deletions intercept.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ func (i Interception) Invoke(val reflect.Value) reflect.Value {
return vals[0]
}

var InterceptorFilter interceptorFilter

type interceptorFilter struct{}

func (p interceptorFilter) Call(c *Controller, fc FilterChain) {
var InterceptorFilter = func(c *Controller, fc []Filter) {
defer invokeInterceptors(FINALLY, c)
defer func() {
if err := recover(); err != nil {
Expand All @@ -104,7 +100,7 @@ func (p interceptorFilter) Call(c *Controller, fc FilterChain) {
return
}

fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])
invokeInterceptors(AFTER, c)
}

Expand Down
7 changes: 1 addition & 6 deletions invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@ var (
websocketType = reflect.TypeOf((*websocket.Conn)(nil))
)

var ActionInvoker actionInvoker

type actionInvoker struct{}

// Instantiate, bind params, and invoke the given action.
func (f actionInvoker) Call(c *Controller, _ FilterChain) {
var ActionInvoker = func(c *Controller, _ []Filter) {
// Instantiate the method.
methodValue := reflect.ValueOf(c.AppController).MethodByName(c.MethodType.Name)

Expand Down
2 changes: 1 addition & 1 deletion invoker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,6 @@ func BenchmarkInvoker(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
ActionInvoker.Call(&c, nil)
ActionInvoker(&c, nil)
}
}
10 changes: 3 additions & 7 deletions modules/db/app/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ var (
Spec string
)

var DbFilter dbFilter

type dbFilter struct{}

func (p dbFilter) OnAppStart() {
func Init() {
// Read configuration.
var found bool
if Driver, found = revel.Config.String("db.driver"); !found {
Expand All @@ -37,7 +33,7 @@ func (p dbFilter) OnAppStart() {
}
}

func (p dbFilter) Call(c *revel.Controller, fc revel.FilterChain) {
var DbFilter = func(c *revel.Controller, fc []revel.Filter) {
// Begin transaction
txn, err := Db.Begin()
if err != nil {
Expand All @@ -54,7 +50,7 @@ func (p dbFilter) Call(c *revel.Controller, fc revel.FilterChain) {
}
}()

fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])

// Commit
if err := c.Txn.Commit(); err != nil {
Expand Down
8 changes: 2 additions & 6 deletions panic.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ import (

// PanicFilter wraps the action invocation in a protective defer blanket that
// converts panics into 500 error pages.
var PanicFilter panicFilter

type panicFilter struct{}

func (f panicFilter) Call(c *Controller, fc FilterChain) {
var PanicFilter = func(c *Controller, fc []Filter) {
defer func() {
if err := recover(); err != nil {
handleInvocationPanic(c, err)
}
}()
fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])
}

// This function handles a panic in an action invocation.
Expand Down
8 changes: 2 additions & 6 deletions params.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,7 @@ func (p *Params) calcValues() url.Values {
return values
}

var ParamsFilter paramsFilter

type paramsFilter struct{}

func (f paramsFilter) Call(c *Controller, fc FilterChain) {
var ParamsFilter = func(c *Controller, fc []Filter) {
ParseParams(c.Params, c.Request)

// Clean up from the request.
Expand All @@ -130,5 +126,5 @@ func (f paramsFilter) Call(c *Controller, fc FilterChain) {
}
}()

fc[0].Call(c, fc[1:])
fc[0](c, fc[1:])
}
4 changes: 2 additions & 2 deletions params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func BenchmarkParams(b *testing.B) {
Params: &Params{},
}
for i := 0; i < b.N; i++ {
ParamsFilter.Call(&c, NilChain)
ParamsFilter(&c, NilChain)
}
}

Expand All @@ -98,7 +98,7 @@ func TestMultipartForm(t *testing.T) {
Request: NewRequest(getMultipartRequest()),
Params: &Params{},
}
ParamsFilter.Call(&c, NilChain)
ParamsFilter(&c, NilChain)

if !reflect.DeepEqual(expectedValues, map[string][]string(c.Params.Values)) {
t.Errorf("Param values: (expected) %v != %v (actual)",
Expand Down
Loading

0 comments on commit b916640

Please sign in to comment.