diff --git a/tgb/filter.go b/tgb/filter.go index 1f6e6fa..d4a8391 100644 --- a/tgb/filter.go +++ b/tgb/filter.go @@ -21,6 +21,7 @@ type Filter interface { // Filter that calls f. type FilterFunc func(ctx context.Context, update *Update) (bool, error) +// Allow implements Filter interface. func (filter FilterFunc) Allow(ctx context.Context, update *Update) (bool, error) { return filter(ctx, update) } @@ -127,16 +128,28 @@ func Command(command string, opts ...CommandFilterOption) *CommandFilter { return filter } +// getUpdateMessage returns first not nil message from update fields. +func getUpdateMessage(update *Update) *tg.Message { + return firstNotNil( + update.Message, + update.EditedMessage, + update.ChannelPost, + update.EditedChannelPost, + ) +} + // Allow checks if update is allowed by filter. func (filter *CommandFilter) Allow(ctx context.Context, update *Update) (bool, error) { - if update.Message == nil { + msg := getUpdateMessage(update) + + if msg == nil { return false, nil } - text := update.Message.Text + text := msg.Text if text == "" && !filter.ignoreCaption { - text = update.Message.Caption + text = msg.Caption } if text == "" { @@ -184,16 +197,18 @@ func Regexp(re *regexp.Regexp) Filter { return FilterFunc(func(ctx context.Context, update *Update) (bool, error) { var text string + msg := getUpdateMessage(update) + switch { - case update.Message != nil: - text = update.Message.Text + case msg != nil: + text = msg.Text - if text == "" && update.Message.Caption != "" { - text = update.Message.Caption + if text == "" && msg.Caption != "" { + text = msg.Caption } - if text == "" && update.Message.Poll != nil { - text = update.Message.Poll.Question + if text == "" && msg.Poll != nil { + text = msg.Poll.Question } case update.CallbackQuery != nil && update.CallbackQuery.Data != "": text = update.CallbackQuery.Data @@ -215,15 +230,11 @@ func ChatType(types ...tg.ChatType) Filter { return FilterFunc(func(ctx context.Context, update *Update) (bool, error) { var typ tg.ChatType + msg := getUpdateMessage(update) + switch { - case update.Message != nil: - typ = update.Message.Chat.Type - case update.EditedMessage != nil: - typ = update.EditedMessage.Chat.Type - case update.ChannelPost != nil: - typ = update.ChannelPost.Chat.Type - case update.EditedChannelPost != nil: - typ = update.EditedChannelPost.Chat.Type + case msg != nil: + typ = msg.Chat.Type case update.CallbackQuery != nil && update.CallbackQuery.Message != nil: typ = update.CallbackQuery.Message.Chat.Type case update.InlineQuery != nil: diff --git a/tgb/filter_test.go b/tgb/filter_test.go index bbe4cc3..8b1dae0 100644 --- a/tgb/filter_test.go +++ b/tgb/filter_test.go @@ -110,6 +110,16 @@ func TestCommandFilter(t *testing.T) { Update: &tg.Update{}, Allow: false, }, + { + Name: "ChannelPost", + Command: Command("start"), + Update: &tg.Update{ + ChannelPost: &tg.Message{ + Text: "/start azcv 5678", + }, + }, + Allow: true, + }, { Name: "InCaption", Command: Command("start",