Add ErrKeepAndRetryTaskLater (#11)
This will cause a retry but, when possible, it will avoid requeueing
the message and instead allow the same message to be processed again.

Also fix a bug so that retriable errors don't get recorded as errors
on the tracing span (it appears this was the intended behavior, and
it's definitely what I want).
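
For illustration (not part of the commit), here is a minimal sketch of a task handler that opts into this behavior; the handler, its deliver helper, and the 10-minute delay are hypothetical:

package example

import (
	"time"

	"github.com/RichardKnop/machinery/v1/tasks"
)

// deliver stands in for whatever fallible work the task performs (hypothetical).
func deliver(payload string) error { return nil }

// ProcessWebhook asks the worker to retry the same message later instead of
// publishing a fresh copy when delivery fails.
func ProcessWebhook(payload string) error {
	if err := deliver(payload); err != nil {
		// The worker maps this error to keepAndRetryTaskIn -> Server.RetryTaskAt;
		// on SQS the broker extends the message's visibility timeout, and brokers
		// without that support fall back to re-publishing with the new ETA.
		return tasks.NewErrKeepAndRetryTaskLater(err, "delivery failed", 10*time.Minute)
	}
	return nil
}
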
cameron-dunn-sublime authored Apr 26, 2022
1 parent 2b032e4 commit 0fb9883
Showing 7 changed files with 134 additions and 15 deletions.
7 changes: 7 additions & 0 deletions v1/brokers/iface/interfaces.go
@@ -20,6 +20,13 @@ type Broker interface {
AdjustRoutingKey(s *tasks.Signature)
}

type RetrySameMessage interface {
Broker

// RetryMessage does not return an error because, at least with the current use case, all errors should just be ignored
RetryMessage(s *tasks.Signature)
}

// TaskProcessor - can process a delivered task
// This will probably always be a worker instance
type TaskProcessor interface {
47 changes: 45 additions & 2 deletions v1/brokers/sqs/sqs.go
@@ -24,7 +24,8 @@ import (
)

const (
maxAWSSQSDelay = time.Minute * 15 // Max supported SQS delay is 15 min: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
maxAWSSQSVisibilityTimeout = time.Hour * 12 // Max supported SQS visibility timeout is 12 hours: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ChangeMessageVisibility.html
)

// Broker represents an AWS SQS broker
@@ -185,6 +186,8 @@ func (b *Broker) Publish(ctx context.Context, signature *tasks.Signature) error
func (b *Broker) extend(by time.Duration, signature *tasks.Signature) error {
b.AdjustRoutingKey(signature)

by = restrictVisibilityTimeoutDelay(by)

visibilityInput := &awssqs.ChangeMessageVisibilityInput{
QueueUrl: aws.String(b.GetConfig().Broker + "/" + signature.RoutingKey),
ReceiptHandle: &signature.SQSReceiptHandle,
@@ -195,6 +198,34 @@ func (b *Broker) extend(by time.Duration, signature *tasks.Signature) error {
return err
}

func (b *Broker) RetryMessage(signature *tasks.Signature) {
b.AdjustRoutingKey(signature)

delay := signature.ETA.Sub(time.Now().UTC())
delay = restrictVisibilityTimeoutDelay(delay)

visibilityInput := &awssqs.ChangeMessageVisibilityInput{
QueueUrl: aws.String(b.GetConfig().Broker + "/" + signature.RoutingKey),
ReceiptHandle: &signature.SQSReceiptHandle,
VisibilityTimeout: aws.Int64(int64(delay.Seconds())),
}

_, err := b.service.ChangeMessageVisibility(visibilityInput)
if err != nil {
log.ERROR.Printf("ignoring error attempting to change visibility timeout. will re-attempt after default period. task %s", signature.UUID)
}
}

func restrictVisibilityTimeoutDelay(delay time.Duration) time.Duration {
if delay > maxAWSSQSVisibilityTimeout {
log.ERROR.Printf("attempted to retry a message with invalid delay: %s. using max.", delay.String())
delay = maxAWSSQSVisibilityTimeout
} else if delay < 0 {
delay = 0
}
return delay
}

// consume is a method which keeps consuming deliveries from a channel, until there is an error or a stop signal
func (b *Broker) consume(deliveries <-chan *awssqs.ReceiveMessageOutput, concurrency int, taskProcessor iface.TaskProcessor, pool chan struct{}) error {

@@ -233,6 +264,18 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor
sig.SQSReceiptHandle = *delivery.Messages[0].ReceiptHandle
}

if receiveCount := delivery.Messages[0].Attributes[awssqs.MessageSystemAttributeNameApproximateReceiveCount]; receiveCount != nil {
if rc, err := strconv.ParseInt(*receiveCount, 10, 64); err == nil {
sqsRetryCount := int(rc) - 1

// RetryCount may already be part of the signature if using certain retry mechanisms. To avoid overwriting,
// just use whichever one is higher.
if sqsRetryCount > sig.RetryCount {
sig.RetryCount = sqsRetryCount
}
}
}

sentTimeSinceEpochMilliString := delivery.Messages[0].Attributes[awssqs.MessageSystemAttributeNameSentTimestamp]
if sentTimeSinceEpochMilliString != nil {
if i, err := strconv.ParseInt(*sentTimeSinceEpochMilliString, 10, 64); err == nil {
@@ -252,7 +295,7 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor

err := taskProcessor.Process(sig, b.extend)
if err != nil {
// stop task deletion in case we want to send messages to dlq in sqs
// stop task deletion in case we want to send messages to dlq in sqs or retry from visibility timeout
if err == errs.ErrStopTaskDeletion {
return nil
}
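
As an aside (not part of the diff), a same-package test sketch for the new clamping helper; the test name is illustrative and it assumes restrictVisibilityTimeoutDelay stays unexported in package sqs:

package sqs

import (
	"testing"
	"time"
)

func TestRestrictVisibilityTimeoutDelay(t *testing.T) {
	// An ETA in the past produces a negative delay, which collapses to an immediate retry.
	if got := restrictVisibilityTimeoutDelay(-time.Minute); got != 0 {
		t.Errorf("negative delay: expected 0, got %s", got)
	}
	// Delays beyond the SQS limit are clamped to the 12-hour maximum.
	if got := restrictVisibilityTimeoutDelay(24 * time.Hour); got != maxAWSSQSVisibilityTimeout {
		t.Errorf("oversized delay: expected %s, got %s", maxAWSSQSVisibilityTimeout, got)
	}
	// In-range delays pass through unchanged.
	if got := restrictVisibilityTimeoutDelay(30 * time.Minute); got != 30*time.Minute {
		t.Errorf("in-range delay: expected 30m0s, got %s", got)
	}
}
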
13 changes: 13 additions & 0 deletions v1/server.go
@@ -12,6 +12,7 @@ import (

"github.com/RichardKnop/machinery/v1/backends/result"
"github.com/RichardKnop/machinery/v1/brokers/eager"
"github.com/RichardKnop/machinery/v1/brokers/errs"
"github.com/RichardKnop/machinery/v1/config"
"github.com/RichardKnop/machinery/v1/log"
"github.com/RichardKnop/machinery/v1/tasks"
@@ -222,6 +223,18 @@ func (server *Server) SendTask(signature *tasks.Signature) (*result.AsyncResult,
return server.SendTaskWithContext(context.Background(), signature)
}

// RetryTaskAt attempts to retry the same task at the given time, but falls back to sending the task.
func (server *Server) RetryTaskAt(signature *tasks.Signature) error {
if retrier, ok := server.broker.(brokersiface.RetrySameMessage); ok {
retrier.RetryMessage(signature)

return errs.ErrStopTaskDeletion
}

_, err := server.SendTaskWithContext(context.Background(), signature)
return err
}

// SendChainWithContext will inject the trace context in all the signature headers before publishing it
func (server *Server) SendChainWithContext(ctx context.Context, chain *tasks.Chain) (*result.ChainAsyncResult, error) {
span, _ := opentracing.StartSpanFromContext(ctx, "SendChain", tracing.ProducerOption(), tracing.MachineryTag, tracing.WorkflowChainTag)
23 changes: 23 additions & 0 deletions v1/tasks/errors.go
@@ -29,4 +29,27 @@ func NewErrRetryTaskLater(msg string, retryIn time.Duration) ErrRetryTaskLater {
// Retriable is interface that retriable errors should implement
type Retriable interface {
RetryIn() time.Duration
error
}

// ErrKeepAndRetryTaskLater can be returned from a task to retry it after the given duration while, when the broker supports it, keeping the original message instead of requeueing a copy
type ErrKeepAndRetryTaskLater struct {
msg string
err error
retryIn time.Duration
}

// RetryIn returns time.Duration from now when task should be retried
func (e ErrKeepAndRetryTaskLater) RetryIn() time.Duration {
return e.retryIn
}

// Error implements the error interface
func (e ErrKeepAndRetryTaskLater) Error() string {
return fmt.Sprintf("Task error: %v (%s). Will retry in: %s. Will attempt to re-process same message.", e.err, e.msg, e.retryIn)
}

// NewErrKeepAndRetryTaskLater returns new ErrKeepAndRetryTaskLater instance
func NewErrKeepAndRetryTaskLater(err error, msg string, retryIn time.Duration) ErrKeepAndRetryTaskLater {
return ErrKeepAndRetryTaskLater{err: err, msg: msg, retryIn: retryIn}
}
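
One consequence worth noting: because Retriable now embeds error, both retry error types satisfy it. A compile-time assertion sketch (not part of the diff) makes that explicit:

package tasks

// Both error types expose RetryIn() and Error(), so Task.Call and the worker
// can treat them uniformly as Retriable.
var (
	_ Retriable = ErrRetryTaskLater{}
	_ Retriable = ErrKeepAndRetryTaskLater{}
)
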
31 changes: 20 additions & 11 deletions v1/tasks/task.go
@@ -171,7 +171,7 @@ func (t *Task) Call() (taskResults []*TaskResult, err error) {
results := t.TaskFunc.Call(args)
signature := SignatureFromContext(t.Context)
recordTimeSinceIngestion := func(name string) {
if signature.IngestionTime != nil {
if signature != nil && signature.IngestionTime != nil {
span.SetTag(
name,
time.Now().Sub(*signature.IngestionTime).Microseconds())
@@ -193,24 +193,33 @@
if !lastResult.IsNil() {
recordTimeSinceIngestion("ingestion_to_err")

// If the result implements Retriable interface, return instance of Retriable
retriableErrorInterface := reflect.TypeOf((*Retriable)(nil)).Elem()
if lastResult.Type().Implements(retriableErrorInterface) {
return nil, lastResult.Interface().(ErrRetryTaskLater)
}
value := lastResult.Interface()

// Otherwise, check that the result implements the standard error interface,
// check that the result implements the standard error interface,
// if not, return ErrLastReturnValueMustBeError error
errorInterface := reflect.TypeOf((*error)(nil)).Elem()
if !lastResult.Type().Implements(errorInterface) {
asError, ok := value.(error)
if !ok {
return nil, ErrLastReturnValueMustBeError
}

_, isRetriable := asError.(Retriable)

if span != nil {
span.LogFields(opentracing_log.Error(lastResult.Interface().(error)))
if !isRetriable {
span.LogFields(opentracing_log.Error(asError))
} else {
span.SetTag("warning", asError)
}

span.SetTag("can_retry", isRetriable)
span.SetTag("did_fail", true)
}

// Return the standard error
return nil, lastResult.Interface().(error)
return nil, asError
}
if span != nil {
span.SetTag("did_fail", false)
}

recordTimeSinceIngestion("ingestion_to_success")
3 changes: 3 additions & 0 deletions v1/tasks/task_test.go
@@ -7,6 +7,8 @@ import (
"testing"
"time"

"github.com/opentracing/opentracing-go"

"github.com/RichardKnop/machinery/v1/tasks"
"github.com/stretchr/testify/assert"
)
@@ -18,6 +20,7 @@ func TestTaskCallErrorTest(t *testing.T) {
retriable := func() error { return tasks.NewErrRetryTaskLater("some error", 4*time.Hour) }

task, err := tasks.New(retriable, []tasks.Arg{})
_, task.Context = opentracing.StartSpanFromContext(context.Background(), "test")
assert.NoError(t, err)

// Invoke TryCall and validate that returned error can be cast to tasks.ErrRetryTaskLater
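
A companion test sketch for the new error type, mirroring the existing retriable test above (hypothetical; it assumes the same _test.go file and its existing context, opentracing, and assert imports):

func TestTaskCallKeepAndRetryErrorTest(t *testing.T) {
	keepAndRetry := func() error {
		return tasks.NewErrKeepAndRetryTaskLater(assert.AnError, "transient failure", 4*time.Hour)
	}

	task, err := tasks.New(keepAndRetry, []tasks.Arg{})
	assert.NoError(t, err)
	_, task.Context = opentracing.StartSpanFromContext(context.Background(), "test")

	// Task.Call should hand the retriable error back unmodified.
	_, callErr := task.Call()
	keepErr, ok := callErr.(tasks.ErrKeepAndRetryTaskLater)
	if assert.True(t, ok) {
		assert.Equal(t, 4*time.Hour, keepErr.RetryIn())
	}
}
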
25 changes: 23 additions & 2 deletions v1/worker.go
@@ -186,9 +186,14 @@ func (worker *Worker) Process(signature *tasks.Signature, extendFunc tasks.Exten
if err != nil {
// If a tasks.ErrRetryTaskLater was returned from the task,
// retry the task after specified duration
retriableErr, ok := interface{}(err).(tasks.ErrRetryTaskLater)
retryTaskLaterErr, ok := interface{}(err).(tasks.ErrRetryTaskLater)
if ok {
return worker.retryTaskIn(signature, retriableErr.RetryIn())
return worker.retryTaskIn(signature, retryTaskLaterErr.RetryIn())
}

keepAndRetryErr, ok := interface{}(err).(tasks.ErrKeepAndRetryTaskLater)
if ok {
return worker.keepAndRetryTaskIn(signature, keepAndRetryErr.RetryIn())
}

// Otherwise, execute default retry logic based on signature.RetryCount
@@ -245,6 +250,22 @@ func (worker *Worker) retryTaskIn(signature *tasks.Signature, retryIn time.Durat
return err
}

// keepAndRetryTaskIn attempts to keep the message on the queue but with a new ETA of now + retryIn
func (worker *Worker) keepAndRetryTaskIn(signature *tasks.Signature, retryIn time.Duration) error {
// Update task state to RETRY
if err := worker.server.GetBackend().SetStateRetry(signature); err != nil {
return fmt.Errorf("Set state to 'retry' for task %s returned error: %w", signature.UUID, err)
}

// Delay task by retryIn duration
eta := time.Now().UTC().Add(retryIn)
signature.ETA = &eta

log.WARNING.Printf("Task %s failed. Going to retry in %.0f seconds. Attempting to keep message.", signature.UUID, retryIn.Seconds())

return worker.server.RetryTaskAt(signature)
}

// taskSucceeded updates the task state and triggers success callbacks or a
// chord callback if this was the last task of a group with a chord callback
func (worker *Worker) taskSucceeded(signature *tasks.Signature, taskResults []*tasks.TaskResult) error {
