diff --git a/README.md b/README.md index e8d27aa..a935aac 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ - [Helper methods](#helper-methods) - [Sending files](#sending-files) - [Downloading files](#downloading-files) + - [Interceptors](#interceptors) - [Updates](#updates) - [Handlers](#handlers) - [Typed Handlers](#typed-handlers) @@ -350,6 +351,47 @@ defer f.Close() // ... ``` +### Interceptors + +Interceptors are used to modify or process the request before it is sent to the server and the response before it is returned to the caller. It's like a [[tgb.Middleware](https://pkg.go.dev/github.com/mr-linch/go-tg/tgb#Middleware)], but for outgoing requests. + +All interceptors should be registered on the client before the request is made. + +```go +client := tg.New("", + tg.WithClientInterceptors( + tg.Interceptor(func(ctx context.Context, req *tg.Request, dst any, invoker tg.InterceptorInvoker) error { + started := time.Now() + + // before request + err := invoker(ctx, req, dst) + // after request + + log.Print("call %s took %s", req.Method, time.Since(started)) + + return err + }), + ), +) +``` + + +Arguments of the interceptor are: + - `ctx` - context of the request; + - `req` - request object [tg.Request](https://pkg.go.dev/github.com/mr-linch/go-tg#Request); + - `dst` - pointer to destination for the response, can be `nil` if the request is made with `DoVoid` method; + - `invoker` - function for calling the next interceptor or the actual request. + +Contrib package has some useful interceptors: + - [InterceptorRetryFloodError](https://pkg.go.dev/github.com/mr-linch/go-tg#NewInterceptorRetryFloodError) - retry request if the server returns a flood error. Parameters can be customized via options; + - [InterceptorRetryInternalServerError](https://pkg.go.dev/github.com/mr-linch/go-tg#NewInterceptorRetryInternalServerError) - retry request if the server returns an error. Parameters can be customized via options; + - [InterceptorMethodFilter](https://pkg.go.dev/github.com/mr-linch/go-tg#NewInterceptorMethodFilter) - call underlying interceptor only for specified methods; + - [InterceptorDefaultParseMethod](https://pkg.go.dev/github.com/mr-linch/go-tg#NewInterceptorDefaultParseMethod) - set default `parse_mode` for messages if not specified. + +Interceptors are called in the order they are registered. + +Example of using retry flood interceptor: [examples/retry-flood](https://github.com/mr-linch/go-tg/blob/main/examples/retry-flood/main.go) + ## Updates Everything related to receiving and processing updates is in the [`tgb`](https://pkg.go.dev/github.com/mr-linch/go-tg/tgb) package. diff --git a/client.go b/client.go index 8b7ed83..f469ddb 100644 --- a/client.go +++ b/client.go @@ -35,6 +35,9 @@ type Client struct { // contains cached bot info me *User meLock sync.Mutex + + interceptors []Interceptor + invoker InterceptorInvoker } // ClientOption is a function that sets some option for Client. @@ -62,6 +65,13 @@ func WithClientTestEnv() ClientOption { } } +// WithClientInterceptor adds interceptor to client. +func WithClientInterceptors(ints ...Interceptor) ClientOption { + return func(c *Client) { + c.interceptors = append(c.interceptors, ints...) + } +} + // New creates new Client with given token and options. func New(token string, options ...ClientOption) *Client { c := &Client{ @@ -78,9 +88,25 @@ func New(token string, options ...ClientOption) *Client { option(c) } + c.invoker = c.buildInvoker() + return c } +func (client *Client) buildInvoker() InterceptorInvoker { + invoker := client.invoke + + for i := len(client.interceptors) - 1; i >= 0; i-- { + invoker = func(next InterceptorInvoker, interceptor Interceptor) InterceptorInvoker { + return func(ctx context.Context, req *Request, dst any) error { + return interceptor(ctx, req, dst, next) + } + }(invoker, client.interceptors[i]) + } + + return invoker +} + func (client *Client) Token() string { return client.token } @@ -245,7 +271,7 @@ func (client *Client) executeStreaming( } } -func (client *Client) Do(ctx context.Context, req *Request, dst interface{}) error { +func (client *Client) invoke(ctx context.Context, req *Request, dst any) error { res, err := client.execute(ctx, req) if err != nil { return fmt.Errorf("execute: %w", err) @@ -268,6 +294,10 @@ func (client *Client) Do(ctx context.Context, req *Request, dst interface{}) err return nil } +func (client *Client) Do(ctx context.Context, req *Request, dst interface{}) error { + return client.invoker(ctx, req, dst) +} + // Download file by path from Client.GetFile method. // Don't forget to close ReadCloser. func (client *Client) Download(ctx context.Context, path string) (io.ReadCloser, error) { diff --git a/client_test.go b/client_test.go index 85c1084..542ef1c 100644 --- a/client_test.go +++ b/client_test.go @@ -157,3 +157,35 @@ func TestClient_Execute(t *testing.T) { } }) } + +func TestClientInterceptors(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/bot1234:secret/getMe", r.URL.Path) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true,"result":{"id":5556648742,"is_bot":true,"first_name":"go_tg_local_bot","username":"go_tg_local_bot","can_join_groups":true,"can_read_all_group_messages":false,"supports_inline_queries":false}}`)) + })) + + defer ts.Close() + + calls := 0 + + client := New( + "1234:secret", + WithClientDoer(ts.Client()), + WithClientServerURL(ts.URL), + WithClientInterceptors(func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + calls++ + return invoker(ctx, req, dst) + }), + ) + ctx := context.Background() + + err := client.Do(ctx, NewRequest("getMe"), nil) + + assert.NoError(t, err) + assert.Equal(t, 1, calls) + }) +} diff --git a/examples/echo-bot/main.go b/examples/echo-bot/main.go index d8f5385..fcfb275 100644 --- a/examples/echo-bot/main.go +++ b/examples/echo-bot/main.go @@ -29,7 +29,7 @@ func main() { tg.HTML.Line("Send me a message and I will echo it back to you. Also you can send me a reaction and I will react with the same emoji."), tg.HTML.Italic("🚀 Powered by", tg.HTML.Spoiler("go-tg")), ), - ).ParseMode(tg.HTML).LinkPreviewOptions(tg.LinkPreviewOptions{ + ).LinkPreviewOptions(tg.LinkPreviewOptions{ URL: "https://github.com/mr-linch/go-tg", PreferLargeMedia: true, }).DoVoid(ctx) @@ -41,7 +41,11 @@ func main() { return fmt.Errorf("answer chat action: %w", err) } - time.Sleep(time.Second) + select { + case <-time.After(1 * time.Second): + case <-ctx.Done(): + return ctx.Err() + } return msg.AnswerPhoto(tg.NewFileArgUpload( tg.NewInputFileBytes("gopher.png", gopherPNG), diff --git a/examples/retry-flood/main.go b/examples/retry-flood/main.go new file mode 100644 index 0000000..db4e9ab --- /dev/null +++ b/examples/retry-flood/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "context" + "log" + "sync" + "time" + + "github.com/mr-linch/go-tg" + "github.com/mr-linch/go-tg/examples" + "github.com/mr-linch/go-tg/tgb" +) + +func main() { + pm := tg.HTML + + onStart := func(ctx context.Context, msg *tgb.MessageUpdate) error { + return msg.Answer(pm.Text( + "👋 Hi, I'm retry flood demo, send me /spam command for start.", + "🔁 I will retry when receive flood wait error", + "Stop spam with shutdown bot service", + )).DoVoid(ctx) + } + + onSpam := func(ctx context.Context, mu *tgb.MessageUpdate) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + var wg sync.WaitGroup + + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := mu.Answer(pm.Text("🔁 spamming...")).DoVoid(ctx); err != nil { + log.Printf("answer: %v", err) + } + }() + } + + wg.Wait() + } + } + } + + examples.Run(tgb.NewRouter(). + Message(onSpam, tgb.Command("spam")). + ChannelPost(onSpam, tgb.Command("spam")). + Message(onStart). + ChannelPost(onStart). + Error(func(ctx context.Context, update *tgb.Update, err error) error { + log.Printf("error in handler: %v", err) + return nil + }), + + tg.WithClientInterceptors( + tg.Interceptor(func(ctx context.Context, req *tg.Request, dst any, invoker tg.InterceptorInvoker) error { + defer func(started time.Time) { + log.Printf("request: %s took: %s", req.Method, time.Since(started)) + }(time.Now()) + return invoker(ctx, req, dst) + }), + tg.NewInterceptorRetryFloodError( + // we override the default timeAfter function to log the retry flood delay + tg.WithInterceptorRetryFloodErrorTimeAfter(func(sleep time.Duration) <-chan time.Time { + log.Printf("retry flood error after %s", sleep) + return time.After(sleep) + }), + ), + ), + ) +} diff --git a/examples/run.go b/examples/run.go index a56a9bf..78b848c 100644 --- a/examples/run.go +++ b/examples/run.go @@ -15,19 +15,24 @@ import ( // Run runs bot with given router. // Exit on error. -func Run(handler tgb.Handler) { +func Run(handler tgb.Handler, opts ...tg.ClientOption) { ctx := context.Background() ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, os.Kill, syscall.SIGTERM) defer cancel() - if err := run(ctx, handler); err != nil { + if err := run(ctx, handler, nil, opts...); err != nil { log.Printf("error: %v", err) defer os.Exit(1) } } -func run(ctx context.Context, handler tgb.Handler) error { +func run( + ctx context.Context, + handler tgb.Handler, + do func(ctx context.Context, client *tg.Client) error, + opts ...tg.ClientOption, +) error { // define flags var ( flagToken string @@ -49,9 +54,9 @@ func run(ctx context.Context, handler tgb.Handler) error { return fmt.Errorf("token is required, provide it with -token flag") } - opts := []tg.ClientOption{ + opts = append(opts, tg.WithClientServerURL(flagServer), - } + ) if flagTestEnv { opts = append(opts, tg.WithClientTestEnv()) @@ -66,7 +71,9 @@ func run(ctx context.Context, handler tgb.Handler) error { log.Printf("authorized as %s", me.Username.Link()) - if flagWebhookURL != "" { + if do != nil { + return do(ctx, client) + } else if flagWebhookURL != "" { err = tgb.NewWebhook( handler, client, diff --git a/interceptors.go b/interceptors.go new file mode 100644 index 0000000..e4e3ef2 --- /dev/null +++ b/interceptors.go @@ -0,0 +1,207 @@ +package tg + +import ( + "context" + "errors" + "math" + "math/rand" + "net/http" + "time" +) + +type InterceptorInvoker func(ctx context.Context, req *Request, dst any) error + +// Interceptor is a function that intercepts request and response. +type Interceptor func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error + +type interceptorRetryFloodErrorOpts struct { + tries int + maxRetryAfter time.Duration + timeAfter func(time.Duration) <-chan time.Time +} + +// InterceptorRetryFloodErrorOption is an option for NewRetryFloodErrorInterceptor. +type InterceptorRetryFloodErrorOption func(*interceptorRetryFloodErrorOpts) + +// WithInterceptorRetryFloodErrorTries sets the number of tries. +func WithInterceptorRetryFloodErrorTries(tries int) InterceptorRetryFloodErrorOption { + return func(o *interceptorRetryFloodErrorOpts) { + o.tries = tries + } +} + +// WithInterceptorRetryFloodErrorMaxRetryAfter sets the maximum retry after duration. +func WithInterceptorRetryFloodErrorMaxRetryAfter(maxRetryAfter time.Duration) InterceptorRetryFloodErrorOption { + return func(o *interceptorRetryFloodErrorOpts) { + o.maxRetryAfter = maxRetryAfter + } +} + +// WithInterceptorRetryFloodErrorTimeAfter sets the time.After function. +func WithInterceptorRetryFloodErrorTimeAfter(timeAfter func(time.Duration) <-chan time.Time) InterceptorRetryFloodErrorOption { + return func(o *interceptorRetryFloodErrorOpts) { + o.timeAfter = timeAfter + } +} + +// NewInterceptorRetryFloodError returns a new interceptor that retries the request if the error is flood error. +// With that interceptor, calling of method that hit limit will be look like it will look like the request just takes unusually long. +// Under the hood, multiple HTTP requests are being performed, with the appropriate delays in between. +// +// Default tries is 3, maxRetryAfter is 1 hour, timeAfter is time.After. +// The interceptor will retry the request if the error is flood error with RetryAfter less than maxRetryAfter. +// The interceptor will wait for RetryAfter duration before retrying the request. +// The interceptor will retry the request for tries times. +func NewInterceptorRetryFloodError(opts ...InterceptorRetryFloodErrorOption) Interceptor { + options := interceptorRetryFloodErrorOpts{ + tries: 3, + maxRetryAfter: time.Hour, + timeAfter: time.After, + } + + for _, o := range opts { + o(&options) + } + + return func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + var err error + LOOP: + for i := 0; i < options.tries; i++ { + err = invoker(ctx, req, dst) + if err == nil { + return nil + } + + var tgErr *Error + if errors.As(err, &tgErr) && tgErr.Code == http.StatusTooManyRequests && tgErr.Parameters != nil { + if tgErr.Parameters.RetryAfterDuration() > options.maxRetryAfter { + return err + } + + select { + case <-options.timeAfter(tgErr.Parameters.RetryAfterDuration()): + continue LOOP + case <-ctx.Done(): + return ctx.Err() + } + } + + break + } + + return err + } +} + +type interceptorRetryInternalServerErrorOpts struct { + tries int + delay time.Duration + timeAfter func(time.Duration) <-chan time.Time +} + +// RetryInternalServerErrorOption is an option for NewRetryInternalServerErrorInterceptor. +type RetryInternalServerErrorOption func(*interceptorRetryInternalServerErrorOpts) + +// WithInterceptorRetryInternalServerErrorTries sets the number of tries. +func WithInterceptorRetryInternalServerErrorTries(tries int) RetryInternalServerErrorOption { + return func(o *interceptorRetryInternalServerErrorOpts) { + o.tries = tries + } +} + +// WithInterceptorRetryInternalServerErrorDelay sets the delay between tries. +// The delay calculated as delay * 2^i + random jitter, where i is the number of tries. +func WithInterceptorRetryInternalServerErrorDelay(delay time.Duration) RetryInternalServerErrorOption { + return func(o *interceptorRetryInternalServerErrorOpts) { + o.delay = delay + } +} + +// WithInterceptorRetryInternalServerErrorTimeAfter sets the time.After function. +func WithInterceptorRetryInternalServerErrorTimeAfter(timeAfter func(time.Duration) <-chan time.Time) RetryInternalServerErrorOption { + return func(o *interceptorRetryInternalServerErrorOpts) { + o.timeAfter = timeAfter + } +} + +// NewInterceptorRetryInternalServerError returns a new interceptor that retries the request if the error is internal server error. +// +// With that interceptor, calling of method that hit limit will be look like it will look like the request just takes unusually long. +// Under the hood, multiple HTTP requests are being performed, with the appropriate delays in between. +// +// Default tries is 10, delay is 100ms, timeAfter is time.After. +// The interceptor will retry the request if the error is internal server error. +// The interceptor will wait for delay * 2^i + random jitter before retrying the request, where i is the number of tries. +// The interceptor will retry the request for ten times. +func NewInterceptorRetryInternalServerError(opts ...RetryInternalServerErrorOption) Interceptor { + options := &interceptorRetryInternalServerErrorOpts{ + tries: 10, + delay: time.Millisecond * 100, + timeAfter: time.After, + } + + for _, o := range opts { + o(options) + } + + return func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + var err error + LOOP: + for i := 0; i < options.tries; i++ { + err = invoker(ctx, req, dst) + if err == nil { + return nil + } + + var tgErr *Error + if errors.As(err, &tgErr) && tgErr.Code == http.StatusInternalServerError { + // do backoff delay + backoffDelay := options.delay * time.Duration(math.Pow(2, float64(i))) + jitter := time.Duration(rand.Int63n(int64(backoffDelay))) + + select { + case <-options.timeAfter(backoffDelay + jitter): + continue LOOP + case <-ctx.Done(): + return ctx.Err() + } + } + + break + } + + return err + } +} + +// NewInterceptorDefaultParseMethod returns a new interceptor that sets the parse_method to the request if it is empty. +// Use in combination with NewInterceptorMethodFilter to filter and specify only needed methods. +// Like: +// +// NewInterceptorMethodFilter(NewInterceptorDefaultParseMethod(tg.HTML), "sendMessage", "editMessageText") +func NewInterceptorDefaultParseMethod(pm ParseMode) Interceptor { + return func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + if !req.Has("parse_mode") { + req.Stringer("parse_mode", pm) + } + + return invoker(ctx, req, dst) + } +} + +// ТewInterceptorMethodFilter returns a new filtering interceptor +// that calls the interceptor only for specified methods. +func NewInterceptorMethodFilter(interceptor Interceptor, methods ...string) Interceptor { + methodMap := make(map[string]struct{}, len(methods)) + for _, method := range methods { + methodMap[method] = struct{}{} + } + + return func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + if _, ok := methodMap[req.Method]; ok { + return interceptor(ctx, req, dst, invoker) + } + + return nil + } +} diff --git a/interceptors_test.go b/interceptors_test.go new file mode 100644 index 0000000..083c853 --- /dev/null +++ b/interceptors_test.go @@ -0,0 +1,271 @@ +package tg + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewInterceptorRetryFloodError(t *testing.T) { + t.Run("TestNoError", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return nil + }) + + interceptor := NewInterceptorRetryFloodError() + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.NoError(t, err, "should no return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) + + t.Run("NoTgError", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return errors.New("test") + }) + + interceptor := NewInterceptorRetryFloodError() + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) + + t.Run("Retry", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return &Error{Code: 429, Parameters: &ResponseParameters{RetryAfter: 1}} + }) + + var timeAfterCalls int + + interceptor := NewInterceptorRetryFloodError( + WithInterceptorRetryFloodErrorTries(3), + WithInterceptorRetryFloodErrorMaxRetryAfter(time.Second*2), + WithInterceptorRetryFloodErrorTimeAfter(func(time.Duration) <-chan time.Time { + timeAfterCalls++ + result := make(chan time.Time, 1) + result <- time.Now() + return result + }), + ) + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 3, calls, "should call invoker 3 times") + assert.Equal(t, 3, timeAfterCalls, "should call timeAfter 3 times") + }) + + t.Run("MaxRetryAfter", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return &Error{Code: 429, Parameters: &ResponseParameters{RetryAfter: 2}} + }) + + var timeAfterCalls int + + interceptor := NewInterceptorRetryFloodError( + WithInterceptorRetryFloodErrorTries(3), + WithInterceptorRetryFloodErrorMaxRetryAfter(time.Second), + WithInterceptorRetryFloodErrorTimeAfter(func(time.Duration) <-chan time.Time { + timeAfterCalls++ + result := make(chan time.Time, 1) + result <- time.Now() + return result + }), + ) + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 1, calls, "should call invoker once") + assert.Equal(t, 0, timeAfterCalls, "should call timeAfter once") + }) + + t.Run("Timeout", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return &Error{Code: 429, Parameters: &ResponseParameters{RetryAfter: 1}} + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + + interceptor := NewInterceptorRetryFloodError( + WithInterceptorRetryFloodErrorTries(10), + WithInterceptorRetryFloodErrorMaxRetryAfter(time.Second*2), + ) + + err := interceptor(ctx, &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) +} + +func TestNewInterceptorRetryInternalServerError(t *testing.T) { + t.Run("TestNoError", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return nil + }) + + interceptor := NewInterceptorRetryInternalServerError() + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.NoError(t, err, "should no return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) + + t.Run("NoTgError", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return errors.New("test") + }) + + interceptor := NewInterceptorRetryInternalServerError() + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) + + t.Run("Retry", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return &Error{Code: 500} + }) + + var timeAfterCalls int + + interceptor := NewInterceptorRetryInternalServerError( + WithInterceptorRetryInternalServerErrorTries(3), + WithInterceptorRetryInternalServerErrorDelay(time.Millisecond), + WithInterceptorRetryInternalServerErrorTimeAfter(func(time.Duration) <-chan time.Time { + defer func() { timeAfterCalls++ }() + + result := make(chan time.Time, 1) + result <- time.Now() + return result + }), + ) + + err := interceptor(context.Background(), &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 3, calls, "should call invoker 3 times") + assert.Equal(t, 3, timeAfterCalls, "should call timeAfter 3 times") + }) + + t.Run("Timeout", func(t *testing.T) { + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + return &Error{Code: 500} + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + + interceptor := NewInterceptorRetryInternalServerError( + WithInterceptorRetryInternalServerErrorTries(10), + WithInterceptorRetryInternalServerErrorDelay(time.Millisecond), + ) + + err := interceptor(ctx, &Request{}, nil, invoker) + + assert.Error(t, err, "should return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) +} + +func TestNewInterceptorMethodFilter(t *testing.T) { + t.Run("InWhitelist", func(t *testing.T) { + req := NewRequest("sendMessage") + + var calls int + + interceptor := Interceptor(func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + calls++ + return invoker(ctx, req, dst) + }) + + interceptor = NewInterceptorMethodFilter(interceptor, "sendMessage") + + err := interceptor(context.Background(), req, nil, InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + return nil + })) + + assert.NoError(t, err, "should no return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) + + t.Run("NotInWhitelist", func(t *testing.T) { + req := NewRequest("editMessageText") + + var calls int + + interceptor := Interceptor(func(ctx context.Context, req *Request, dst any, invoker InterceptorInvoker) error { + calls++ + return invoker(ctx, req, dst) + }) + + interceptor = NewInterceptorMethodFilter(interceptor, "sendMessage") + + err := interceptor(context.Background(), req, nil, InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + return nil + })) + + assert.NoError(t, err, "should no return error") + assert.Equal(t, 0, calls, "should call invoker once") + }) +} + +func TestNewInterceptorDefaultParseMethod(t *testing.T) { + t.Run("Ok", func(t *testing.T) { + req := NewRequest("sendMessage") + dst := &Response{} + + var calls int + + invoker := InterceptorInvoker(func(ctx context.Context, req *Request, dst any) error { + calls++ + assert.Equal(t, HTML.String(), req.args["parse_mode"], "should set parse_mode to HTML") + return nil + }) + + interceptor := NewInterceptorDefaultParseMethod(HTML) + + err := interceptor(context.Background(), req, dst, invoker) + + assert.NoError(t, err, "should no return error") + assert.Equal(t, 1, calls, "should call invoker once") + }) +} diff --git a/request.go b/request.go index 5b9ede9..c655ae0 100644 --- a/request.go +++ b/request.go @@ -83,6 +83,14 @@ func (r *Request) File(name string, arg FileArg) *Request { return r.InputFile(name, arg.Upload) } +func (r *Request) Has(name string) bool { + _, inJSON := r.json[name] + _, inArgs := r.args[name] + _, inFiles := r.files[name] + + return inJSON || inArgs || inFiles +} + func (r *Request) InputMediaSlice(name string, im []InputMedia) *Request { for _, v := range im { r.InputMedia(v) diff --git a/request_test.go b/request_test.go index 123f4f3..386d99f 100644 --- a/request_test.go +++ b/request_test.go @@ -223,3 +223,12 @@ func TestRequest_MarshalJSON(t *testing.T) { assert.Error(t, err) }) } + +func TestRequest_Has(t *testing.T) { + r := NewRequest("sendMessage") + + r.String("chat_id", "1") + + assert.True(t, r.Has("chat_id")) + assert.False(t, r.Has("text")) +} diff --git a/types_gen_ext.go b/types_gen_ext.go index 615ab91..79ba730 100644 --- a/types_gen_ext.go +++ b/types_gen_ext.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "strconv" + "time" ) type ChatID int64 @@ -1622,3 +1623,8 @@ func (mim *MaybeInaccessibleMessage) UnmarshalJSON(v []byte) error { return json.Unmarshal(v, mim.Message) } } + +// RetryAfterDuration returns duration for retry after. +func (rp *ResponseParameters) RetryAfterDuration() time.Duration { + return time.Duration(rp.RetryAfter) * time.Second +}