Skip to content

Commit

Permalink
Fix unwrapping of DeadlineExceeded errors into ErrTaskTimeout.
Browse files Browse the repository at this point in the history
context.Cause() was called on the parent context, not the task context.
This resulted in the error handler being called with a nil err.
  • Loading branch information
bojanz committed May 15, 2024
1 parent 445ed61 commit 688360a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
18 changes: 11 additions & 7 deletions nanoq.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,9 @@ func (p *Processor) processTask(ctx context.Context, t Task) error {
return fmt.Errorf("task %v canceled: %v", t.ID, context.Cause(ctx))
}

if errors.Is(err, context.DeadlineExceeded) {
// Extract a more specific timeout error, if any.
err = context.Cause(ctx)
}
if p.errorHandler != nil {
p.errorHandler(ctx, t, err)
}

if t.Retries < t.MaxRetries && !errors.Is(err, ErrSkipRetry) {
retryIn := p.retryPolicy(t)
if err := p.client.RetryTask(ctx, t, retryIn); err != nil {
Expand Down Expand Up @@ -455,8 +450,17 @@ func callHandler(ctx context.Context, h Handler, t Task) (err error) {
err = fmt.Errorf("panic [%s:%d]: %v: %w", file, line, r, ErrSkipRetry)
}
}()
ctx, cancel := context.WithTimeoutCause(ctx, t.Timeout(), ErrTaskTimeout)
taskCtx, cancel := context.WithTimeoutCause(ctx, t.Timeout(), ErrTaskTimeout)
defer cancel()

return h(ctx, t)
err = h(taskCtx, t)
if err != nil && errors.Is(err, context.DeadlineExceeded) {
// Extract a more specific timeout error, if any.
// context.Cause returns nil if the canceled context is a child of taskCtx.
if cerr := context.Cause(taskCtx); cerr != nil {
err = cerr
}
}

return err
}
55 changes: 55 additions & 0 deletions nanoq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,58 @@ func TestProcessor_Run_Cancel(t *testing.T) {
t.Error(err)
}
}

func TestProcessor_Run_Timeout(t *testing.T) {
db, mock, _ := sqlmock.New()
defer db.Close()
client := nanoq.NewClient(sqlx.NewDb(db, "sqlmock"))
processor := nanoq.NewProcessor(client, zerolog.Nop())
processor.Handle("my-type", func(ctx context.Context, task nanoq.Task) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
continue
}
}
})
errorHandlerCalled := 0
processor.OnError(func(ctx context.Context, task nanoq.Task, err error) {
if !errors.Is(err, nanoq.ErrTaskTimeout) {
t.Errorf("error handler called with unexpected error: %v", err)
}
errorHandlerCalled++
})

// Task claim, timeout_seconds=1.
mock.ExpectBegin()
rows := sqlmock.NewRows([]string{"id", "fingerprint", "type", "payload", "retries", "max_retries", "timeout_seconds", "created_at", "scheduled_at"}).
AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "0", "1", time.Now(), time.Now())
mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows)

mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM").
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()

mock.ExpectBegin()
mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM").
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()

ctx, cancel := context.WithCancel(context.Background())
go processor.Run(ctx, 1, 1*time.Millisecond)
time.Sleep(2 * time.Second)
cancel()
// Wait for the processor to shut down.
time.Sleep(2 * time.Millisecond)

err := mock.ExpectationsWereMet()
if err != nil {
t.Error(err)
}

if errorHandlerCalled != 1 {
t.Errorf("erorr handler called %v times instead of %v", errorHandlerCalled, 1)
}
}

0 comments on commit 688360a

Please sign in to comment.