Skip to content

Commit

Permalink
middleware: fix group handling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
demget committed Oct 1, 2020
1 parent 09f2572 commit 90fad0b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 28 deletions.
29 changes: 23 additions & 6 deletions bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ func NewBot(pref Settings) (*Bot, error) {
Poller: pref.Poller,
OnError: pref.OnError,

Updates: make(chan Update, pref.Updates),
stop: make(chan struct{}),
Updates: make(chan Update, pref.Updates),
handlers: make(map[string]HandlerFunc),
stop: make(chan struct{}),

synchronous: pref.Synchronous,
verbose: pref.Verbose,
Expand Down Expand Up @@ -71,6 +72,7 @@ type Bot struct {
OnError func(error, Context)

group *Group
handlers map[string]HandlerFunc
synchronous bool
verbose bool
parseMode ParseMode
Expand Down Expand Up @@ -144,7 +146,7 @@ type Command struct {

// Group returns a new group.
func (b *Bot) Group() *Group {
return &Group{handlers: make(map[string]HandlerFunc)}
return &Group{b: b}
}

// Use adds middleware to the global bot chain.
Expand Down Expand Up @@ -172,7 +174,22 @@ func (b *Bot) Use(middleware ...MiddlewareFunc) {
// b.Handle("/ban", onBan, protected)
//
func (b *Bot) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) {
b.group.Handle(endpoint, h, m...)
if len(b.group.middleware) > 0 {
m = append(b.group.middleware, m...)
}

handler := func(c Context) error {
return applyMiddleware(h, m...)(c)
}

switch end := endpoint.(type) {
case string:
b.handlers[end] = handler
case CallbackEndpoint:
b.handlers[end.CallbackUnique()] = handler
default:
panic("telebot: unsupported endpoint")
}
}

var (
Expand Down Expand Up @@ -365,7 +382,7 @@ func (b *Bot) ProcessUpdate(upd Update) {
if match != nil {
unique, payload := match[0][1], match[0][3]

if handler, ok := b.group.handlers["\f"+unique]; ok {
if handler, ok := b.handlers["\f"+unique]; ok {
upd.Callback.Data = payload
b.runHandler(handler, c)
return
Expand Down Expand Up @@ -410,7 +427,7 @@ func (b *Bot) ProcessUpdate(upd Update) {
}

func (b *Bot) handle(end string, c Context) bool {
if handler, ok := b.group.handlers[end]; ok {
if handler, ok := b.handlers[end]; ok {
b.runHandler(handler, c)
return true
}
Expand Down
10 changes: 5 additions & 5 deletions bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestBotHandle(t *testing.T) {
}

b.Handle("/start", func(c Context) error { return nil })
assert.Contains(t, b.group.handlers, "/start")
assert.Contains(t, b.handlers, "/start")

reply := ReplyButton{Text: "reply"}
b.Handle(&reply, func(c Context) error { return nil })
Expand All @@ -97,10 +97,10 @@ func TestBotHandle(t *testing.T) {
btnInline := (&ReplyMarkup{}).Data("", "btnInline")
b.Handle(&btnInline, func(c Context) error { return nil })

assert.Contains(t, b.group.handlers, btnReply.CallbackUnique())
assert.Contains(t, b.group.handlers, btnInline.CallbackUnique())
assert.Contains(t, b.group.handlers, reply.CallbackUnique())
assert.Contains(t, b.group.handlers, inline.CallbackUnique())
assert.Contains(t, b.handlers, btnReply.CallbackUnique())
assert.Contains(t, b.handlers, btnInline.CallbackUnique())
assert.Contains(t, b.handlers, reply.CallbackUnique())
assert.Contains(t, b.handlers, inline.CallbackUnique())
}

func TestBotStart(t *testing.T) {
Expand Down
19 changes: 2 additions & 17 deletions group.go → middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ type MiddlewareFunc func(HandlerFunc) HandlerFunc

// Group is a separated group of handlers, united by the general middleware.
type Group struct {
b *Bot
middleware []MiddlewareFunc
handlers map[string]HandlerFunc
}

// Use adds middleware to the chain.
Expand All @@ -18,20 +18,5 @@ func (g *Group) Use(middleware ...MiddlewareFunc) {
// Handle adds endpoint handler to the bot, combining group's middleware
// with the optional given middleware.
func (g *Group) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) {
if len(g.middleware) > 0 {
m = append(g.middleware, m...)
}

handler := func(c Context) error {
return applyMiddleware(h, m...)(c)
}

switch end := endpoint.(type) {
case string:
g.handlers[end] = handler
case CallbackEndpoint:
g.handlers[end.CallbackUnique()] = handler
default:
panic("telebot: unsupported endpoint")
}
g.b.Handle(endpoint, h, append(g.middleware, m...)...)
}

0 comments on commit 90fad0b

Please sign in to comment.