diff --git a/examples/README.md b/examples/README.md index 41a844d..07832aa 100644 --- a/examples/README.md +++ b/examples/README.md @@ -15,3 +15,5 @@ Listen port set to `:8000` by default, you can override it with the `-webhook-li | [text-filter](https://github.com/mr-linch/go-tg/tree/main/examples/text-filter) | Text Filter usage | Text filter, reply keyboard markup | | [webapps](https://github.com/mr-linch/go-tg/tree/main/examples/webapps) | Parse and validate Login Widget & WebApp data, host simple webapp | WebApps, Login Widget, Embed webhook to http.Mux | | [session-filter](https://github.com/mr-linch/go-tg/tree/main/examples/session-filter) | Simple form filling with persistent session | Router, Session Manager, Session Filters | +| [menu](https://github.com/mr-linch/go-tg/tree/main/examples/menu) | Hiearchical menu with API integration | ButtonLayout, TextMessageBuilder, CallbackDataFilter | +| [retry-flood](https://github.com/mr-linch/go-tg/tree/main/examples/retry-flood) | Retry on flood error | Interceptors | \ No newline at end of file diff --git a/examples/menu/client.go b/examples/menu/client.go new file mode 100644 index 0000000..e9177e6 --- /dev/null +++ b/examples/menu/client.go @@ -0,0 +1,148 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" +) + +// Primitive API client for https://jsonplaceholder.typicode.com +type API struct { + BaseURL string + Client *http.Client +} + +type User struct { + ID int `json:"id"` + Name string `json:"name"` + Username string `json:"username"` + Email string `json:"email"` + Address struct { + Street string `json:"street"` + Suite string `json:"suite"` + City string `json:"city"` + Zipcode string `json:"zipcode"` + Geo struct { + Lat float64 `json:"lat,string"` + Lng float64 `json:"lng,string"` + } `json:"geo"` + } `json:"address"` + Phone string `json:"phone"` + Website string `json:"website"` + Company struct { + Name string `json:"name"` + CatchPhrase string `json:"catchPhrase"` + Bs string `json:"bs"` + } `json:"company"` +} + +type Post struct { + UserID int `json:"userId"` + ID int `json:"id"` + Title string `json:"title"` + Body string `json:"body"` +} + +type Comment struct { + PostID int `json:"postId"` + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + Body string `json:"body"` +} + +func (a *API) request(ctx context.Context, path string, params url.Values, dst any) error { + // better use url.JoinPath, but it's not available at go1.18 that specified as minimal version + endpoint := a.BaseURL + path + + if len(params) > 0 { + endpoinAsURL, err := url.Parse(endpoint) + if err != nil { + return fmt.Errorf("parse endpoint: %w", err) + } + + endpoinAsURL.RawQuery = params.Encode() + + endpoint = endpoinAsURL.String() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, http.NoBody) + if err != nil { + return fmt.Errorf("build http request: %w", err) + } + + res, err := a.Client.Do(req) + if err != nil { + return fmt.Errorf("execute http request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return fmt.Errorf("http status: %d", res.StatusCode) + } + + if err := json.NewDecoder(res.Body).Decode(dst); err != nil { + return fmt.Errorf("unmarshal: %w", err) + } + + return nil +} + +func (a *API) Users(ctx context.Context) (users []User, err error) { + err = a.request(ctx, "/users", nil, &users) + return +} + +func (a *API) User(ctx context.Context, id int) (user User, err error) { + err = a.request(ctx, fmt.Sprintf("/users/%d", id), nil, &user) + return +} + +type PostsParams struct { + UserID int `json:"userId"` +} + +func (a *API) Posts(ctx context.Context, params *PostsParams) (posts []Post, err error) { + vs := url.Values{} + + if params != nil { + if params.UserID != 0 { + vs.Set("userId", strconv.Itoa(params.UserID)) + } + } + + err = a.request(ctx, "/posts", vs, &posts) + + return +} + +func (a *API) Post(ctx context.Context, id int) (post Post, err error) { + err = a.request(ctx, fmt.Sprintf("/posts/%d", id), nil, &post) + return +} + +type CommentsParams struct { + PostID int `json:"postId"` +} + +func (a *API) Comments(ctx context.Context, params *CommentsParams) (comments []Comment, err error) { + vs := url.Values{} + + if params != nil { + if params.PostID != 0 { + vs.Set("postId", strconv.Itoa(params.PostID)) + } + } + + err = a.request(ctx, "/comments", vs, &comments) + + return +} + +func (a *API) Comment(ctx context.Context, id int) (comment Comment, err error) { + err = a.request(ctx, fmt.Sprintf("/comments/%d", id), nil, &comment) + return +} diff --git a/examples/menu/main.go b/examples/menu/main.go index 64a8467..abaa28a 100644 --- a/examples/menu/main.go +++ b/examples/menu/main.go @@ -3,73 +3,296 @@ package main import ( "context" + "fmt" + "net/http" "strconv" - "time" "github.com/mr-linch/go-tg" "github.com/mr-linch/go-tg/examples" "github.com/mr-linch/go-tg/tgb" ) -var pm = tg.HTML +type userDetailsCallbackData struct { + UserID int +} + +type userLocationCallbackData struct { + UserID int + Lat float64 + Lng float64 +} + +type postDetailsCallbackData struct { + UserID int + PostID int +} + +type commentDetailsCallbackData struct { + UserID int + PostID int + CommentID int +} + +var ( + userDetailsCallbackDataFilter = tgb.NewCallbackDataFilter[userDetailsCallbackData]( + "user_details", + ) + + userLocationCallbackDataFilter = tgb.NewCallbackDataFilter[userLocationCallbackData]( + "user_location", + ) + + userListCallbackDataFilter = tgb.NewCallbackDataFilter[struct{}]( + "user_list", + ) + + postDetailsCallbackDataFilter = tgb.NewCallbackDataFilter[postDetailsCallbackData]( + "post_details", + ) + + commentDetailsCallbackDataFilter = tgb.NewCallbackDataFilter[commentDetailsCallbackData]( + "comment_details", + ) +) + +func newUserListMessage(pm tg.ParseMode, users []User) *tgb.TextMessageCallBuilder { + buttons := make([]tg.InlineKeyboardButton, 0, len(users)) + for _, user := range users { + buttons = append(buttons, userDetailsCallbackDataFilter.MustButton( + user.Name, + userDetailsCallbackData{UserID: user.ID}, + )) + } -func newMenuMainMessage() *tgb.TextMessageCallBuilder { return tgb.NewTextMessageCallBuilder( pm.Text( - pm.Bold("👋 Hi, I'm demo of", pm.Code("tg.TextMessageCallBuilder")), + pm.Bold("👥 Users"), + "", + pm.Line("Total users: ", strconv.Itoa(len(users))), "", - pm.Italic("Use attached keyboard or commands to navigate"), + pm.Italic("Select user to view details:"), ), - ).ReplyMarkup(tg.NewInlineKeyboardMarkup( - tg.NewButtonRow( - tg.NewInlineKeyboardButtonCallback("menu 1", "menu_1"), - tg.NewInlineKeyboardButtonCallback("menu 2", "menu_2"), - tg.NewInlineKeyboardButtonCallback("menu 3", "menu_3"), + ). + ParseMode(pm). + ReplyMarkup( + tg.NewInlineKeyboardMarkup( + tg.NewButtonLayout(2, buttons...).Keyboard()..., + ), + ) +} + +func newUserDetailsMessage(pm tg.ParseMode, user User, posts []Post) *tgb.TextMessageCallBuilder { + buttons := make([]tg.InlineKeyboardButton, 0, len(posts)+1) + + for _, post := range posts { + buttons = append(buttons, postDetailsCallbackDataFilter.MustButton( + post.Title, + postDetailsCallbackData{PostID: post.ID, UserID: user.ID}, + )) + } + + layout := tg.NewButtonLayout[tg.InlineKeyboardButton](2) + + layout.Row(userLocationCallbackDataFilter.MustButton("📍 Location", userLocationCallbackData{ + UserID: user.ID, + Lat: user.Address.Geo.Lat, + Lng: user.Address.Geo.Lng, + })) + + layout.Add(buttons...) + + layout.Row(userListCallbackDataFilter.MustButton("🔙 Back", struct{}{})) + + buttons = append(buttons, userListCallbackDataFilter.MustButton("🔙 Back", struct{}{})) + + return tgb.NewTextMessageCallBuilder( + pm.Text( + pm.Bold("👤 User Details"), + "", + pm.Line("ID: ", strconv.Itoa(user.ID)), + pm.Line("Name: ", user.Name), + pm.Line("Username: ", user.Username), + pm.Line("Email: ", user.Email), + "", + pm.Bold("Address:"), + pm.Line("Street: ", user.Address.Street), + pm.Line("Suite: ", user.Address.Suite), + pm.Line("City: ", user.Address.City), + pm.Line("Zipcode: ", user.Address.Zipcode), + "", + pm.Line("Phone: ", user.Phone), + pm.Line("Website: ", user.Website), + "", + pm.Bold("Company:"), + pm.Line("Name: ", user.Company.Name), + pm.Line("Catch Phrase: ", user.Company.CatchPhrase), + pm.Line("Bs: ", user.Company.Bs), ), - )).ParseMode(pm) + ). + ReplyMarkup(tg.NewInlineKeyboardMarkup( + layout.Keyboard()..., + )). + ParseMode(pm) } -func newSubmenu(n int) *tgb.TextMessageCallBuilder { +func newPostDetails(pm tg.ParseMode, userID int, post Post, comments []Comment) *tgb.TextMessageCallBuilder { + buttons := make([]tg.InlineKeyboardButton, 0, len(comments)+1) + + for _, comment := range comments { + buttons = append(buttons, commentDetailsCallbackDataFilter.MustButton("💬 "+comment.Name, commentDetailsCallbackData{ + UserID: userID, + PostID: post.ID, + CommentID: comment.ID, + })) + } + + buttons = append(buttons, userDetailsCallbackDataFilter.MustButton("🔙 Back", userDetailsCallbackData{ + UserID: userID, + })) + return tgb.NewTextMessageCallBuilder( pm.Text( - pm.Bold("Menu ", strconv.Itoa(n)), + pm.Bold("📝 Post Details"), "", - pm.Bold("Now:", " ", pm.Code(time.Now().Format(time.RFC3339))), + pm.Line(pm.Bold("ID: "), strconv.Itoa(post.ID)), + pm.Line(pm.Bold("Title: "), post.Title), "", - pm.Italic("Use attached keyboard or commands to navigate"), + pm.Blockquote(post.Body), ), - ).ReplyMarkup( - tg.NewInlineKeyboardMarkup( - tg.NewButtonRow( - tg.NewInlineKeyboardButtonCallback("go back", "menu_main"), - tg.NewInlineKeyboardButtonCallback("refresh", "menu_"+strconv.Itoa(n)), - ), + ). + ParseMode(pm). + ReplyMarkup(tg.NewInlineKeyboardMarkup( + tg.NewButtonLayout(1, buttons...).Keyboard()..., + )) +} + +func newCommentDetails(pm tg.ParseMode, userID int, postID int, comment Comment) *tgb.TextMessageCallBuilder { + buttons := []tg.InlineKeyboardButton{ + postDetailsCallbackDataFilter.MustButton("🔙 Back to Post", postDetailsCallbackData{ + UserID: userID, + PostID: postID, + }), + + userDetailsCallbackDataFilter.MustButton("🔙 Back to User", userDetailsCallbackData{ + UserID: userID, + }), + } + + return tgb.NewTextMessageCallBuilder( + pm.Text( + pm.Bold("💬 Comment Details"), + "", + pm.Line(pm.Bold("ID: "), strconv.Itoa(comment.ID)), + pm.Line(pm.Bold("Name: "), comment.Name), + pm.Line(pm.Bold("Email: "), comment.Email), + "", + pm.Blockquote(comment.Body), ), - ).ParseMode(pm) + ). + ParseMode(pm). + ReplyMarkup(tg.NewInlineKeyboardMarkup( + tg.NewButtonLayout(1, buttons...).Keyboard()..., + )) } func main() { + client := API{ + BaseURL: "https://jsonplaceholder.typicode.com", + Client: http.DefaultClient, + } + + newUserListBuilder := func(ctx context.Context) (*tgb.TextMessageCallBuilder, error) { + users, err := client.Users(ctx) + if err != nil { + return nil, fmt.Errorf("get users: %w", err) + } + + return newUserListMessage(tg.HTML, users), nil + } + examples.Run(tgb.NewRouter(). // start message and cbq handlers Message(func(ctx context.Context, msg *tgb.MessageUpdate) error { - return msg.Update.Reply(ctx, newMenuMainMessage().AsSend(msg.Chat)) + builder, err := newUserListBuilder(ctx) + if err != nil { + return err + } + + return msg.Update.Reply(ctx, builder.AsSend(msg.Chat)) }, tgb.Command("start")). CallbackQuery(func(ctx context.Context, cbq *tgb.CallbackQueryUpdate) error { - return cbq.Update.Reply(ctx, newMenuMainMessage().AsEditTextFromCBQ(cbq.CallbackQuery)) - }, tgb.TextEqual("menu_main")). + _ = cbq.Answer().DoVoid(ctx) - // switch menu handlers - CallbackQuery(func(ctx context.Context, cbq *tgb.CallbackQueryUpdate) error { - _ = cbq.Update.Reply(ctx, cbq.Answer()) + builder, err := newUserListBuilder(ctx) + if err != nil { + return err + } + + return cbq.Update.Reply(ctx, builder.AsEditTextFromCBQ(cbq.CallbackQuery)) + }, userListCallbackDataFilter.Filter()). + CallbackQuery(userDetailsCallbackDataFilter.Handler(func(ctx context.Context, cbq *tgb.CallbackQueryUpdate, cbd userDetailsCallbackData) error { + _ = cbq.Answer().DoVoid(ctx) + + user, err := client.User(ctx, cbd.UserID) + if err != nil { + return fmt.Errorf("get user: %w", err) + } + + posts, err := client.Posts(ctx, &PostsParams{ + UserID: cbd.UserID, + }) + if err != nil { + return fmt.Errorf("get posts: %w", err) + } + + return cbq.Update.Reply(ctx, + newUserDetailsMessage(tg.HTML, user, posts). + AsEditTextFromCBQ(cbq.CallbackQuery), + ) + }), userDetailsCallbackDataFilter.Filter()). + CallbackQuery(userLocationCallbackDataFilter.Handler(func(ctx context.Context, cbq *tgb.CallbackQueryUpdate, cbd userLocationCallbackData) error { + _ = cbq.Answer().DoVoid(ctx) + + return cbq.Update.Reply(ctx, + tg.NewSendLocationCall( + cbq.Message.Chat(), + cbd.Lat, + cbd.Lng, + ), + ) + }), userLocationCallbackDataFilter.Filter()). + CallbackQuery(postDetailsCallbackDataFilter.Handler(func(ctx context.Context, cbq *tgb.CallbackQueryUpdate, cbd postDetailsCallbackData) error { + _ = cbq.Answer().DoVoid(ctx) + + post, err := client.Post(ctx, cbd.PostID) + if err != nil { + return fmt.Errorf("get post: %w", err) + } + + comments, err := client.Comments(ctx, &CommentsParams{ + PostID: cbd.PostID, + }) + if err != nil { + return fmt.Errorf("get comments: %w", err) + } - menuNum := cbq.Data[len("menu_"):] + return cbq.Update.Reply(ctx, + newPostDetails(tg.HTML, cbd.UserID, post, comments). + AsEditTextFromCBQ(cbq.CallbackQuery), + ) + }), postDetailsCallbackDataFilter.Filter()). + CallbackQuery(commentDetailsCallbackDataFilter.Handler(func(ctx context.Context, cbq *tgb.CallbackQueryUpdate, cbd commentDetailsCallbackData) error { + _ = cbq.Answer().DoVoid(ctx) - n, err := strconv.Atoi(menuNum) + comment, err := client.Comment(ctx, cbd.CommentID) if err != nil { - return cbq.AnswerText("invalid menu number", true).DoVoid(ctx) + return fmt.Errorf("get comment: %w", err) } - return cbq.Update.Reply(ctx, newSubmenu(n).AsEditTextFromCBQ(cbq.CallbackQuery)) - }, tgb.Any(tgb.TextHasPrefix("menu_"))), + return cbq.Update.Reply(ctx, + newCommentDetails(tg.HTML, cbd.UserID, cbd.PostID, comment). + AsEditTextFromCBQ(cbq.CallbackQuery), + ) + }), commentDetailsCallbackDataFilter.Filter()), ) } diff --git a/tgb/callback_data.go b/tgb/callback_data.go new file mode 100644 index 0000000..425212b --- /dev/null +++ b/tgb/callback_data.go @@ -0,0 +1,432 @@ +package tgb + +import ( + "context" + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/mr-linch/go-tg" +) + +// CallbackDataCodec is a helper for parsing and serializing callback data. +type CallbackDataCodec struct { + delimiter rune + intBase int + floatFmt byte + floatPrec int + disableLengthCheck bool +} + +const callbackDataMaxLen = 64 + +// CallbackDataIsTooLongError is returned when callback data length is too long. +type CallbackDataIsTooLongError struct { + Length int +} + +// Error returns a string representation of the error. +func (e *CallbackDataIsTooLongError) Error() string { + return fmt.Sprintf("callback data length is too long: %v, max: %v", e.Length, callbackDataMaxLen) +} + +// NewCallbackDataParser creates a new CallbackDataParser with default options. +type CallbackDataCodecOption func(*CallbackDataCodec) + +// WithCallbackDataCodecDelimiter sets a delimiter for callback data. +// Default is ':'. +func WithCallbackDataCodecDelimiter(delimiter rune) CallbackDataCodecOption { + return func(p *CallbackDataCodec) { + p.delimiter = delimiter + } +} + +// WithCallbackDataCodecIntBase sets a base for integer fields in callback data. +// Default is 36. +func WithCallbackDataCodecIntBase(base int) CallbackDataCodecOption { + return func(p *CallbackDataCodec) { + p.intBase = base + } +} + +// WithCallbackDataCodecFloatFmt sets a format for float fields in callback data. +// Default is 'f'. +func WithCallbackDataCodecFloatFmt(fmt byte) CallbackDataCodecOption { + return func(p *CallbackDataCodec) { + p.floatFmt = fmt + } +} + +// WithCallbackDataCodecFloatPrec sets a precision for float fields in callback data. +// Default is -1. +func WithCallbackDataCodecFloatPrec(prec int) CallbackDataCodecOption { + return func(p *CallbackDataCodec) { + p.floatPrec = prec + } +} + +// WithCallbackDataCodecDisableLengthCheck disables length check for callback data. +// Default is false. +func WithCallbackDataCodecDisableLengthCheck(disable bool) CallbackDataCodecOption { + return func(p *CallbackDataCodec) { + p.disableLengthCheck = disable + } +} + +// NewCallackDataCodec creates a new CallbackDataParser with custom options. +// With no options it will use ':' as a delimiter, 36 as a base for integer fields, 'f' as a format and -1 as a precision for float fields. +func NewCallackDataCodec(opts ...CallbackDataCodecOption) *CallbackDataCodec { + parser := &CallbackDataCodec{ + delimiter: ':', + intBase: 36, + floatFmt: 'f', + floatPrec: -1, + disableLengthCheck: false, + } + + for _, opt := range opts { + opt(parser) + } + + return parser +} + +func (p *CallbackDataCodec) getIntFieldBaseOrDefault(field reflect.StructField) (int, error) { + baseStr, ok := field.Tag.Lookup("tgbase") + if !ok { + return p.intBase, nil + } + + base, err := strconv.Atoi(baseStr) + if err != nil { + return 0, fmt.Errorf("invalid base value: %w", err) + } + + return base, nil +} + +func (p *CallbackDataCodec) getFloatFieldFmtOrDefault(field reflect.StructField) (byte, error) { + fmtStr, ok := field.Tag.Lookup("tgfmt") + if !ok { + return p.floatFmt, nil + } + + if len(fmtStr) != 1 { + return 0, fmt.Errorf("invalid fmt value: %v", fmtStr) + } + + return fmtStr[0], nil +} + +func (p *CallbackDataCodec) getFloatFieldPrecOrDefault(field reflect.StructField) (int, error) { + precStr, ok := field.Tag.Lookup("tgprec") + if !ok { + return p.floatPrec, nil + } + + prec, err := strconv.Atoi(precStr) + if err != nil { + return 0, fmt.Errorf("invalid prec value: %w", err) + } + + return prec, nil +} + +// MarshalCallbackData serializes a struct into callback data. +// This data will be in format prefix:field_value_1:field_value_2:...:field_value_n +// Only plain structures are supported. +func (p *CallbackDataCodec) Encode(src any) (string, error) { + structValue := reflect.ValueOf(src) + + if structValue.Type().Kind() == reflect.Ptr { + structValue = structValue.Elem() + } + + if !structValue.IsValid() { + return "", fmt.Errorf("src is nil") + } + + if structValue.Kind() != reflect.Struct { + return "", fmt.Errorf("src should be a struct") + } + + var result strings.Builder + + fieldsCount := structValue.NumField() + + structType := structValue.Type() + + for i := 0; i < fieldsCount; i++ { + if i > 0 { + result.WriteRune(p.delimiter) + } + + field := structValue.Field(i) + structField := structType.Field(i) + + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + result.WriteString("1") + } else { + result.WriteString("0") + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + base, err := p.getIntFieldBaseOrDefault(structField) + if err != nil { + return "", fmt.Errorf("field %v: %w", structField.Name, err) + } + + result.WriteString(strconv.FormatInt(field.Int(), base)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + base, err := p.getIntFieldBaseOrDefault(structField) + if err != nil { + return "", fmt.Errorf("field %v: %w", structField.Name, err) + } + + result.WriteString(strconv.FormatUint(field.Uint(), base)) + case reflect.String: + result.WriteString(field.String()) + case reflect.Float32, reflect.Float64: + format, err := p.getFloatFieldFmtOrDefault(structField) + if err != nil { + return "", fmt.Errorf("field %v: %w", structField.Name, err) + } + + prec, err := p.getFloatFieldPrecOrDefault(structField) + if err != nil { + return "", fmt.Errorf("field %v: %w", structField.Name, err) + } + + result.WriteString(strconv.FormatFloat(field.Float(), format, prec, 64)) + default: + return "", fmt.Errorf("unsupported field type: %v", field.Kind()) + } + } + + if !p.disableLengthCheck && result.Len() > callbackDataMaxLen { + return "", &CallbackDataIsTooLongError{Length: result.Len()} + } + + return result.String(), nil +} + +func (p *CallbackDataCodec) Decode(data string, dst any) error { + structValue := reflect.ValueOf(dst) + + if structValue.Type().Kind() != reflect.Ptr { + return fmt.Errorf("dst should be a pointer to a struct") + } + + structValue = structValue.Elem() + + if structValue.Kind() != reflect.Struct { + return fmt.Errorf("dst should be a pointer to a struct") + } + + fieldsCount := structValue.NumField() + + structType := structValue.Type() + + var values []string + if len(data) > 0 { + values = strings.Split(data, string(p.delimiter)) + } + + if len(values) != fieldsCount { + return fmt.Errorf("invalid data length: expected %v, got %v", fieldsCount, len(values)) + } + + for i := 0; i < fieldsCount; i++ { + field := structValue.Field(i) + structField := structType.Field(i) + + switch field.Kind() { + case reflect.Bool: + if values[i] == "1" { + field.SetBool(true) + } else if values[i] == "0" { + field.SetBool(false) + } else { + return fmt.Errorf("invalid bool value: %v", values[i]) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + base, err := p.getIntFieldBaseOrDefault(structField) + if err != nil { + return fmt.Errorf("field %v: %w", structField.Name, err) + } + + value, err := strconv.ParseInt(values[i], base, 64) + if err != nil { + return fmt.Errorf("field %v: %w", structField.Name, err) + } + + field.SetInt(value) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + base, err := p.getIntFieldBaseOrDefault(structField) + if err != nil { + return fmt.Errorf("field %v: %w", structField.Name, err) + } + + value, err := strconv.ParseUint(values[i], base, 64) + if err != nil { + return fmt.Errorf("field %v: %w", structField.Name, err) + } + + field.SetUint(value) + case reflect.String: + field.SetString(values[i]) + case reflect.Float32, reflect.Float64: + value, err := strconv.ParseFloat(values[i], 64) + if err != nil { + return fmt.Errorf("field %v: %w", structField.Name, err) + } + + field.SetFloat(value) + default: + return fmt.Errorf("unsupported field type: %v", field.Kind()) + } + } + + return nil +} + +var DefaultCallbackDataCodec = NewCallackDataCodec() + +// EncodeCallbackData serializes a struct into callback data using default parser. +func EncodeCallbackData(src any) (string, error) { + return DefaultCallbackDataCodec.Encode(src) +} + +// DecodeCallbackData deserializes callback data into a struct using default parser. +func DecodeCallbackData(data string, dst any) error { + return DefaultCallbackDataCodec.Decode(data, dst) +} + +type CallbackDataFilter[T any] struct { + prefix string + codec *CallbackDataCodec +} + +// NewCallbackDataFilter creates a new CallbackDataPrefixFilter with default options. +func NewCallbackDataFilter[T any](prefix string, opts ...CallbackDataCodecOption) *CallbackDataFilter[T] { + return &CallbackDataFilter[T]{ + prefix: prefix, + codec: NewCallackDataCodec(opts...), + } +} + +// MustButton returns a new tg.InlineKeyboardButton with the given data as callback data. +// If an error occurs while encoding, empty button will be returned. +func (p *CallbackDataFilter[T]) MustButton(text string, v T) tg.InlineKeyboardButton { + data, err := p.Encode(v) + if err != nil { + return tg.InlineKeyboardButton{} + } + + return tg.NewInlineKeyboardButtonCallback(text, data) +} + +// Button returns a new tg.InlineKeyboardButton with the given data as callback data. +// If an error occurs while encoding, it will be returned. +func (p *CallbackDataFilter[T]) Button(text string, v T) (tg.InlineKeyboardButton, error) { + data, err := p.Encode(v) + if err != nil { + return tg.InlineKeyboardButton{}, fmt.Errorf("encode: %w", err) + } + + return tg.NewInlineKeyboardButtonCallback(text, data), nil +} + +// Encode serializes a struct into callback data using the filter's parser. +func (p *CallbackDataFilter[T]) Encode(src T) (string, error) { + body, err := p.codec.Encode(src) + if err != nil { + return "", fmt.Errorf("body decode: %w", err) + } + + var builder strings.Builder + + builder.WriteString(p.prefix) + builder.WriteRune(p.codec.delimiter) + builder.WriteString(body) + + return builder.String(), nil +} + +// Decode deserializes callback data into a struct using the filter's codec. +// It checks if the data has the correct prefix. +// If not, an error will be returned. +func (p *CallbackDataFilter[T]) Decode(data string) (T, error) { + var dst T + if !strings.HasPrefix(data, p.prefix) { + return dst, fmt.Errorf("invalid prefix: expected %v, got %v", p.prefix, data) + } + + data = strings.TrimPrefix(data, p.prefix+string(p.codec.delimiter)) + + err := p.codec.Decode(data, &dst) + if err != nil { + return dst, fmt.Errorf("body decode: %w", err) + } + + return dst, nil +} + +// Filter returns a tgb.Filter for the given prefix +// It checks if the data has the correct prefix. +// If not, it will return false. +func (p *CallbackDataFilter[T]) Filter() Filter { + prefixWithDelimiter := p.prefix + string(p.codec.delimiter) + + return FilterFunc(func(ctx context.Context, update *Update) (bool, error) { + if update.CallbackQuery == nil { + return false, nil + } + + if strings.HasPrefix(update.CallbackQuery.Data, prefixWithDelimiter) { + return true, nil + } + + return false, nil + }) +} + +// FilterFunc returns a tgb.Filter for the given prefix and a custom data check function. +// It checks if the data has the correct prefix and if the custom function returns true. +func (p *CallbackDataFilter[T]) FilterFunc(check func(v T) bool) Filter { + prefixWithDelimiter := p.prefix + string(p.codec.delimiter) + + return FilterFunc(func(ctx context.Context, update *Update) (bool, error) { + if update.CallbackQuery == nil { + return false, nil + } + + v, err := p.Decode(update.CallbackQuery.Data) + if err != nil { + return false, fmt.Errorf("decode: %w", err) + } + + if strings.HasPrefix(update.CallbackQuery.Data, prefixWithDelimiter) && check(v) { + return true, nil + } + + return false, nil + }) +} + +type CallbackDataFilterHandler[T any] func(ctx context.Context, cbq *CallbackQueryUpdate, cbd T) error + +// Handler returns a tgb.CallbackQueryHandler that wraps the given handler with decoded callback data. +// If an error occurs while decoding, it will be returned and passed handler will not be called. +func (p *CallbackDataFilter[T]) Handler(handler CallbackDataFilterHandler[T]) CallbackQueryHandler { + return func(ctx context.Context, cqu *CallbackQueryUpdate) error { + cbd, err := p.Decode(cqu.CallbackQuery.Data) + if err != nil { + return fmt.Errorf("decode: %w", err) + } + + return handler(ctx, cqu, cbd) + } +} diff --git a/tgb/callback_data_test.go b/tgb/callback_data_test.go new file mode 100644 index 0000000..fd23098 --- /dev/null +++ b/tgb/callback_data_test.go @@ -0,0 +1,582 @@ +package tgb + +import ( + "context" + "testing" + + tg "github.com/mr-linch/go-tg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCallbackDataParser(t *testing.T) { + parser := NewCallackDataCodec( + WithCallbackDataCodecDelimiter('$'), + WithCallbackDataCodecIntBase(16), + WithCallbackDataCodecFloatFmt('e'), + WithCallbackDataCodecFloatPrec(3), + WithCallbackDataCodecDisableLengthCheck(true), + ) + + assert.Equal(t, '$', parser.delimiter) + assert.Equal(t, 16, parser.intBase) + assert.Equal(t, byte('e'), parser.floatFmt) + assert.Equal(t, 3, parser.floatPrec) +} + +func TestCallbackDataParserEncode(t *testing.T) { + t.Run("NotStruct", func(t *testing.T) { + _, err := EncodeCallbackData(1) + assert.ErrorContains(t, err, "src should be a struct") + }) + + t.Run("Nil", func(t *testing.T) { + type test struct { + } + var nilStruct *test + _, err := EncodeCallbackData(nilStruct) + assert.ErrorContains(t, err, "src is nil") + }) + + t.Run("Empty", func(t *testing.T) { + type test struct{} + cbd, err := EncodeCallbackData(test{}) + require.NoError(t, err) + + assert.Equal(t, "", cbd) + }) + + t.Run("AllTypes", func(t *testing.T) { + type test struct { + Bool bool + BoolFalse bool + Int int + Uint uint `tgbase:"10"` + String string + Float32 float32 `tgfmt:"f" tgprec:"2"` + Float64 float64 `tgprec:"3"` + Floag64DefaultPrec float64 + } + + cbd, err := EncodeCallbackData(test{ + Bool: true, + Int: -1234567890, + Uint: 1234567890, + String: "xyz", + Float32: 123.456, + Float64: 123.4564, + Floag64DefaultPrec: 123.45, + }) + require.NoError(t, err) + + assert.Equal(t, "1:0:-kf12oi:1234567890:xyz:123.46:123.456:123.45", cbd) + }) + + t.Run("InvalidInt", func(t *testing.T) { + type test struct { + Int int `tgbase:"invalid"` + } + + _, err := EncodeCallbackData(test{}) + assert.ErrorContains(t, err, "invalid base") + }) + + t.Run("InvalidUint", func(t *testing.T) { + type test struct { + Uint uint `tgbase:"invalid"` + } + + _, err := EncodeCallbackData(test{}) + assert.ErrorContains(t, err, "invalid base") + }) + + t.Run("InvalidFloatFmt", func(t *testing.T) { + type test struct { + Float32 float32 `tgfmt:"invalid"` + } + + _, err := EncodeCallbackData(test{}) + assert.ErrorContains(t, err, "invalid fmt value") + }) + + t.Run("InvalidFloatPrec", func(t *testing.T) { + type test struct { + Float32 float32 `tgprec:"invalid"` + } + + _, err := EncodeCallbackData(test{}) + assert.ErrorContains(t, err, "invalid prec value") + }) + + t.Run("UnsupportedFieldType", func(t *testing.T) { + type test struct { + Unsupported chan int + } + + _, err := EncodeCallbackData(test{}) + assert.ErrorContains(t, err, "unsupported field type: chan") + }) + + t.Run("CallbackDataIsTooLong", func(t *testing.T) { + type test struct { + Str string + } + + _, err := EncodeCallbackData(test{ + Str: "12345678901234567890123456789012345678901234567890123456789012345678901234567890", + }) + assert.ErrorContains(t, err, "callback data length is too long: 80, max: 64") + }) +} + +func TestCallbackDataParserDecode(t *testing.T) { + t.Run("NotStruct", func(t *testing.T) { + var v int + err := DecodeCallbackData("", &v) + assert.ErrorContains(t, err, "dst should be a pointer to a struct") + }) + + t.Run("Nil", func(t *testing.T) { + type test struct { + } + var nilStruct *test + err := DecodeCallbackData("", nilStruct) + assert.ErrorContains(t, err, "dst should be a pointer to a struct") + + var notNilStruct test + err = DecodeCallbackData("", notNilStruct) + assert.ErrorContains(t, err, "dst should be a pointer to a struct") + }) + + t.Run("InvalidDataLength", func(t *testing.T) { + var dst struct { + A int + B int + } + + err := DecodeCallbackData("1", &dst) + + assert.ErrorContains(t, err, "invalid data length") + }) + + t.Run("InvalidBoolValue", func(t *testing.T) { + type test struct { + Bool bool + } + + var dst test + err := DecodeCallbackData("invalid", &dst) + assert.ErrorContains(t, err, "invalid bool value") + }) + + t.Run("InvalidInt", func(t *testing.T) { + var dst struct { + Int int `tgbase:"invalid"` + } + + err := DecodeCallbackData("invalid", &dst) + assert.ErrorContains(t, err, "invalid syntax") + + var dst2 struct { + Int int `tgbase:"102"` + } + + err = DecodeCallbackData("invalid", &dst2) + assert.ErrorContains(t, err, "invalid base 102") + }) + + t.Run("InvalidInt", func(t *testing.T) { + var dst struct { + Uint uint `tgbase:"invalid"` + } + + err := DecodeCallbackData("invalid", &dst) + assert.ErrorContains(t, err, "invalid syntax") + + var dst2 struct { + Uint uint `tgbase:"102"` + } + + err = DecodeCallbackData("invalid", &dst2) + assert.ErrorContains(t, err, "invalid base 102") + }) + + t.Run("InvalidFloat", func(t *testing.T) { + var dst struct { + Float32 float32 `tgfmt:"invalid"` + } + + err := DecodeCallbackData("invalid", &dst) + assert.ErrorContains(t, err, "invalid syntax") + + var dst2 struct { + Float32 float32 `tgfmt:"e"` + } + + err = DecodeCallbackData("invalid", &dst2) + assert.ErrorContains(t, err, "invalid syntax") + + var dst3 struct { + Float32 float32 `tgfmt:"e" tgprec:"invalid"` + } + + err = DecodeCallbackData("invalid", &dst3) + assert.ErrorContains(t, err, "invalid syntax") + + }) + + t.Run("Empty", func(t *testing.T) { + type test struct{} + var dst test + err := DecodeCallbackData("", &dst) + require.NoError(t, err) + }) + + t.Run("AllTypes", func(t *testing.T) { + type test struct { + Bool bool + FalseBool bool + Int int + Uint uint `tgbase:"10"` + String string + Float32 float32 `tgfmt:"f" tgprec:"1"` + Float64 float64 `tgprec:"1"` + } + + var dst test + err := DecodeCallbackData("1:0:-kf12oi:1234567890:xyz:123.4:123.5", &dst) + require.NoError(t, err) + + assert.Equal(t, test{ + Bool: true, + Int: -1234567890, + Uint: 1234567890, + String: "xyz", + Float32: 123.4, + Float64: 123.5, + }, dst) + }) + + t.Run("UnsupportedFieldType", func(t *testing.T) { + type test struct { + Unsupported chan int + } + + var dst test + err := DecodeCallbackData("1", &dst) + assert.ErrorContains(t, err, "unsupported field type: chan") + }) +} + +func TestCallbackDataFilter(t *testing.T) { + t.Run("ButtonError", func(t *testing.T) { + type test struct { + invalidType chan int // nolint:unused + } + + filter := NewCallbackDataFilter[test]("prefix") + + _, err := filter.Button("test", test{}) + require.Error(t, err) + }) + + t.Run("ButtonOk", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + btn, err := filter.Button("test", test{Bool: true}) + require.NoError(t, err) + assert.Equal(t, "prefix:1", btn.CallbackData) + }) + + t.Run("MustButtonOk", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + btn := filter.MustButton("test", test{Bool: true}) + assert.Equal(t, "prefix:1", btn.CallbackData) + }) + + t.Run("MustButtonError", func(t *testing.T) { + type test struct { + invalidType chan int // nolint:unused + } + + filter := NewCallbackDataFilter[test]("prefix") + + x := filter.MustButton("test", test{}) + assert.Zero(t, x) + }) + + t.Run("CallbackDataEmpty", func(t *testing.T) { + type empty struct{} + + filter := NewCallbackDataFilter[empty]("prefix") + + btn := filter.MustButton("test", empty{}) + + assert.Equal(t, "prefix:", btn.CallbackData) + }) + + t.Run("Decode", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + btn := filter.MustButton("test", test{Bool: true}) + + decoded, err := filter.Decode(btn.CallbackData) + require.NoError(t, err) + assert.Equal(t, test{Bool: true}, decoded) + }) + + t.Run("DecodeErrorInvalidPrefix", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + _, err := filter.Decode("invalid:1") + assert.ErrorContains(t, err, "invalid prefix") + }) + + t.Run("DecodeErrorInvalidData", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + _, err := filter.Decode("prefix:invalid") + assert.ErrorContains(t, err, "invalid bool value") + }) + + t.Run("Handler", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + calls := 0 + + handler := filter.Handler(func(ctx context.Context, cbq *CallbackQueryUpdate, cbd test) error { + calls++ + assert.Equal(t, test{Bool: true}, cbd) + return nil + }) + + err := handler(context.Background(), &CallbackQueryUpdate{ + CallbackQuery: &tg.CallbackQuery{ + Data: filter.MustButton("test", test{Bool: true}).CallbackData, + }, + }) + + assert.NoError(t, err) + assert.Equal(t, 1, calls) + }) + + t.Run("HandlerError", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + handler := filter.Handler(func(ctx context.Context, cbq *CallbackQueryUpdate, cbd test) error { + return assert.AnError + }) + + err := handler(context.Background(), &CallbackQueryUpdate{ + CallbackQuery: &tg.CallbackQuery{ + Data: filter.MustButton("test", test{Bool: true}).CallbackData, + }, + }) + + assert.ErrorContains(t, err, "assert.AnError") + }) + + t.Run("HandlerErrorInvalidData", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + handler := filter.Handler(func(ctx context.Context, cbq *CallbackQueryUpdate, cbd test) error { + return assert.AnError + }) + + err := handler(context.Background(), &CallbackQueryUpdate{ + CallbackQuery: &tg.CallbackQuery{ + Data: "invalid:1", + }, + }) + + assert.ErrorContains(t, err, "invalid prefix") + }) + + t.Run("HandlerFilterTrue", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.Filter().Allow(context.Background(), &Update{ + Update: &tg.Update{ + CallbackQuery: &tg.CallbackQuery{ + Data: "prefix:1", + }, + }, + }) + require.NoError(t, err) + assert.True(t, allowed) + }) + + t.Run("HandlerFilterFalse", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.Filter().Allow(context.Background(), &Update{ + Update: &tg.Update{ + CallbackQuery: &tg.CallbackQuery{ + Data: "prefix-other:1", + }, + }, + }) + require.NoError(t, err) + assert.False(t, allowed) + }) + + t.Run("HanlerFilterNotCallbackQuery", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.Filter().Allow(context.Background(), &Update{ + Update: &tg.Update{ + Message: &tg.Message{}, + }, + }) + require.NoError(t, err) + assert.False(t, allowed) + }) + + t.Run("HandlerFilterFuncNotCallbackQuery", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.FilterFunc(func(v test) bool { + return true + }).Allow(context.Background(), &Update{ + Update: &tg.Update{ + Message: &tg.Message{}, + }, + }) + + require.NoError(t, err) + assert.False(t, allowed) + }) + + t.Run("HandlerFilterFuncDecodeError", func(t *testing.T) { + type test struct { + Bool chan int + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.FilterFunc(func(v test) bool { + return true + }).Allow(context.Background(), &Update{ + Update: &tg.Update{ + CallbackQuery: &tg.CallbackQuery{ + Data: "prefix:invalid", + }, + }, + }) + + require.Error(t, err) + assert.False(t, allowed) + }) + + t.Run("HandlerFilterFuncTrue", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.FilterFunc(func(v test) bool { + return true + }).Allow(context.Background(), &Update{ + Update: &tg.Update{ + CallbackQuery: &tg.CallbackQuery{ + Data: "prefix:1", + }, + }, + }) + + require.NoError(t, err) + assert.True(t, allowed) + }) + + t.Run("HandlerFilterFuncFalse", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.FilterFunc(func(v test) bool { + return false + }).Allow(context.Background(), &Update{ + Update: &tg.Update{ + CallbackQuery: &tg.CallbackQuery{ + Data: "prefix:1", + }, + }, + }) + + require.NoError(t, err) + assert.False(t, allowed) + }) + + t.Run("HandlerFilterFuncOKParsed", func(t *testing.T) { + type test struct { + Bool bool + } + + filter := NewCallbackDataFilter[test]("prefix") + + allowed, err := filter.FilterFunc(func(v test) bool { + return v.Bool + }).Allow(context.Background(), &Update{ + Update: &tg.Update{ + CallbackQuery: &tg.CallbackQuery{ + Data: "prefix:1", + }, + }, + }) + + require.NoError(t, err) + assert.True(t, allowed) + }) +}