From af6a9cbe71311dd80afafce1373ec755594ae69f Mon Sep 17 00:00:00 2001 From: Alejandro Durante Date: Mon, 21 Oct 2024 19:07:07 -0300 Subject: [PATCH] feat(context): ensure pool returns ErrPoolStopped when ctx is canceled --- internal/dispatcher/dispatcher.go | 4 ++- internal/dispatcher/dispatcher_test.go | 12 +++++---- pool.go | 21 ++++++++++++++-- pool_test.go | 35 ++++++++++++++++++++++++++ subpool_test.go | 21 ++++++++++++++++ 5 files changed, 85 insertions(+), 8 deletions(-) diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index 83ec26c..470485a 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -12,6 +12,7 @@ import ( var ErrDispatcherClosed = errors.New("dispatcher has been closed") type Dispatcher[T any] struct { + ctx context.Context bufferHasElements chan struct{} buffer *linkedbuffer.LinkedBuffer[T] dispatchFunc func([]T) @@ -24,6 +25,7 @@ type Dispatcher[T any] struct { // and process each element serially using the dispatchFunc func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize int) *Dispatcher[T] { dispatcher := &Dispatcher[T]{ + ctx: ctx, buffer: linkedbuffer.NewLinkedBuffer[T](10, batchSize), bufferHasElements: make(chan struct{}, 1), dispatchFunc: dispatchFunc, @@ -40,7 +42,7 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize // Write writes values to the dispatcher func (d *Dispatcher[T]) Write(values ...T) error { // Check if the dispatcher has been closed - if d.closed.Load() { + if d.closed.Load() || d.ctx.Err() != nil { return ErrDispatcherClosed } diff --git a/internal/dispatcher/dispatcher_test.go b/internal/dispatcher/dispatcher_test.go index 992cbe2..48c9837 100644 --- a/internal/dispatcher/dispatcher_test.go +++ b/internal/dispatcher/dispatcher_test.go @@ -80,13 +80,15 @@ func TestDispatcherWithContextCanceled(t *testing.T) { // Cancel the context cancel() - // Write to the dispatcher - dispatcher.Write(1) - time.Sleep(5 * time.Millisecond) + + // Attempt to write to the dispatcher + err := dispatcher.Write(1) + + assert.Equal(t, ErrDispatcherClosed, err) // Assert counters - assert.Equal(t, uint64(1), dispatcher.Len()) - assert.Equal(t, uint64(1), dispatcher.WriteCount()) + assert.Equal(t, uint64(0), dispatcher.Len()) + assert.Equal(t, uint64(0), dispatcher.WriteCount()) assert.Equal(t, uint64(0), dispatcher.ReadCount()) } diff --git a/pool.go b/pool.go index ca170c1..fe5664d 100644 --- a/pool.go +++ b/pool.go @@ -15,6 +15,12 @@ const DEFAULT_TASKS_CHAN_LENGTH = 2048 var ErrPoolStopped = errors.New("pool stopped") +var poolStoppedFuture = func() Task { + future, resolve := future.NewFuture(context.Background()) + resolve(ErrPoolStopped) + return future +}() + // basePool is the base interface for all pool types. type basePool interface { // Returns the number of worker goroutines that are currently active (executing a task) in the pool. @@ -46,6 +52,9 @@ type basePool interface { // Stops the pool and waits for all tasks to complete. StopAndWait() + + // Returns true if the pool has been stopped or its context has been cancelled. + Stopped() bool } // Represents a pool of goroutines that can execute tasks concurrently. @@ -71,6 +80,7 @@ type Pool interface { // pool is an implementation of the Pool interface. type pool struct { ctx context.Context + cancel context.CancelCauseFunc maxConcurrency int tasks chan any tasksLen int @@ -85,6 +95,10 @@ func (p *pool) Context() context.Context { return p.ctx } +func (p *pool) Stopped() bool { + return p.ctx.Err() != nil +} + func (p *pool) MaxConcurrency() int { return p.maxConcurrency } @@ -138,13 +152,12 @@ func (p *pool) SubmitErr(task func() error) Task { } func (p *pool) submit(task any) Task { - future, resolve := future.NewFuture(p.Context()) wrapped := wrapTask[struct{}, func(error)](task, resolve) if err := p.dispatcher.Write(wrapped); err != nil { - resolve(ErrPoolStopped) + return poolStoppedFuture } return future @@ -157,6 +170,8 @@ func (p *pool) Stop() Task { close(p.tasks) p.workerWaitGroup.Wait() + + p.cancel(ErrPoolStopped) }) } @@ -285,6 +300,8 @@ func newPool(maxConcurrency int, options ...Option) *pool { option(pool) } + pool.ctx, pool.cancel = context.WithCancelCause(pool.ctx) + pool.dispatcher = dispatcher.NewDispatcher(pool.ctx, pool.dispatch, tasksLen) return pool diff --git a/pool_test.go b/pool_test.go index 9aa855c..8f6b273 100644 --- a/pool_test.go +++ b/pool_test.go @@ -153,6 +153,7 @@ func TestPoolSubmitOnStoppedPool(t *testing.T) { err = pool.Go(func() {}) assert.Equal(t, ErrPoolStopped, err) + assert.Equal(t, true, pool.Stopped()) } func TestNewPoolWithInvalidMaxConcurrency(t *testing.T) { @@ -160,3 +161,37 @@ func TestNewPoolWithInvalidMaxConcurrency(t *testing.T) { NewPool(-1) }) } + +func TestPoolStoppedAfterCancel(t *testing.T) { + + ctx, cancel := context.WithCancel(context.Background()) + + pool := NewPool(10, WithContext(ctx)) + + err := pool.Submit(func() { + cancel() + }).Wait() + + // If the context is canceled during the task execution, the task should return the context error. + assert.Equal(t, context.Canceled, err) + + err = pool.Submit(func() {}).Wait() + + // If the context is canceled, the pool should be stopped and the task should return the pool stopped error. + assert.Equal(t, ErrPoolStopped, err) + assert.True(t, pool.Stopped()) + + err = pool.Go(func() {}) + + assert.Equal(t, ErrPoolStopped, err) + + pool.StopAndWait() + + err = pool.Submit(func() {}).Wait() + + assert.Equal(t, ErrPoolStopped, err) + + err = pool.Go(func() {}) + + assert.Equal(t, ErrPoolStopped, err) +} diff --git a/subpool_test.go b/subpool_test.go index d2a309d..29c5d4e 100644 --- a/subpool_test.go +++ b/subpool_test.go @@ -1,6 +1,7 @@ package pond import ( + "context" "errors" "sync" "sync/atomic" @@ -170,3 +171,23 @@ func TestSubpoolMaxConcurrency(t *testing.T) { assert.Equal(t, 10, subpool.MaxConcurrency()) } + +func TestSubpoolStoppedAfterCancel(t *testing.T) { + + ctx, cancel := context.WithCancel(context.Background()) + + pool := NewPool(10, WithContext(ctx)) + subpool := pool.NewSubpool(5) + + cancel() + + time.Sleep(1 * time.Millisecond) + + err := pool.Submit(func() {}).Wait() + + assert.Equal(t, ErrPoolStopped, err) + + err = subpool.Submit(func() {}).Wait() + + assert.Equal(t, ErrPoolStopped, err) +}