Skip to content

Commit

Permalink
feat(context): ensure pool returns ErrPoolStopped when ctx is canceled
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto committed Oct 21, 2024
1 parent e6af32e commit af6a9cb
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 8 deletions.
4 changes: 3 additions & 1 deletion internal/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
var ErrDispatcherClosed = errors.New("dispatcher has been closed")

type Dispatcher[T any] struct {
ctx context.Context
bufferHasElements chan struct{}
buffer *linkedbuffer.LinkedBuffer[T]
dispatchFunc func([]T)
Expand All @@ -24,6 +25,7 @@ type Dispatcher[T any] struct {
// and process each element serially using the dispatchFunc
func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize int) *Dispatcher[T] {
dispatcher := &Dispatcher[T]{
ctx: ctx,
buffer: linkedbuffer.NewLinkedBuffer[T](10, batchSize),
bufferHasElements: make(chan struct{}, 1),
dispatchFunc: dispatchFunc,
Expand All @@ -40,7 +42,7 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
// Write writes values to the dispatcher
func (d *Dispatcher[T]) Write(values ...T) error {
// Check if the dispatcher has been closed
if d.closed.Load() {
if d.closed.Load() || d.ctx.Err() != nil {
return ErrDispatcherClosed
}

Expand Down
12 changes: 7 additions & 5 deletions internal/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,15 @@ func TestDispatcherWithContextCanceled(t *testing.T) {

// Cancel the context
cancel()
// Write to the dispatcher
dispatcher.Write(1)
time.Sleep(5 * time.Millisecond)

// Attempt to write to the dispatcher
err := dispatcher.Write(1)

assert.Equal(t, ErrDispatcherClosed, err)

// Assert counters
assert.Equal(t, uint64(1), dispatcher.Len())
assert.Equal(t, uint64(1), dispatcher.WriteCount())
assert.Equal(t, uint64(0), dispatcher.Len())
assert.Equal(t, uint64(0), dispatcher.WriteCount())
assert.Equal(t, uint64(0), dispatcher.ReadCount())
}

Expand Down
21 changes: 19 additions & 2 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ const DEFAULT_TASKS_CHAN_LENGTH = 2048

var ErrPoolStopped = errors.New("pool stopped")

var poolStoppedFuture = func() Task {
future, resolve := future.NewFuture(context.Background())
resolve(ErrPoolStopped)
return future
}()

// basePool is the base interface for all pool types.
type basePool interface {
// Returns the number of worker goroutines that are currently active (executing a task) in the pool.
Expand Down Expand Up @@ -46,6 +52,9 @@ type basePool interface {

// Stops the pool and waits for all tasks to complete.
StopAndWait()

// Returns true if the pool has been stopped or its context has been cancelled.
Stopped() bool
}

// Represents a pool of goroutines that can execute tasks concurrently.
Expand All @@ -71,6 +80,7 @@ type Pool interface {
// pool is an implementation of the Pool interface.
type pool struct {
ctx context.Context
cancel context.CancelCauseFunc
maxConcurrency int
tasks chan any
tasksLen int
Expand All @@ -85,6 +95,10 @@ func (p *pool) Context() context.Context {
return p.ctx
}

func (p *pool) Stopped() bool {
return p.ctx.Err() != nil
}

func (p *pool) MaxConcurrency() int {
return p.maxConcurrency
}
Expand Down Expand Up @@ -138,13 +152,12 @@ func (p *pool) SubmitErr(task func() error) Task {
}

func (p *pool) submit(task any) Task {

future, resolve := future.NewFuture(p.Context())

wrapped := wrapTask[struct{}, func(error)](task, resolve)

if err := p.dispatcher.Write(wrapped); err != nil {
resolve(ErrPoolStopped)
return poolStoppedFuture
}

return future
Expand All @@ -157,6 +170,8 @@ func (p *pool) Stop() Task {
close(p.tasks)

p.workerWaitGroup.Wait()

p.cancel(ErrPoolStopped)
})
}

Expand Down Expand Up @@ -285,6 +300,8 @@ func newPool(maxConcurrency int, options ...Option) *pool {
option(pool)
}

pool.ctx, pool.cancel = context.WithCancelCause(pool.ctx)

pool.dispatcher = dispatcher.NewDispatcher(pool.ctx, pool.dispatch, tasksLen)

return pool
Expand Down
35 changes: 35 additions & 0 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,45 @@ func TestPoolSubmitOnStoppedPool(t *testing.T) {
err = pool.Go(func() {})

assert.Equal(t, ErrPoolStopped, err)
assert.Equal(t, true, pool.Stopped())
}

func TestNewPoolWithInvalidMaxConcurrency(t *testing.T) {
assert.PanicsWithError(t, "maxConcurrency must be greater than 0", func() {
NewPool(-1)
})
}

func TestPoolStoppedAfterCancel(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

pool := NewPool(10, WithContext(ctx))

err := pool.Submit(func() {
cancel()
}).Wait()

// If the context is canceled during the task execution, the task should return the context error.
assert.Equal(t, context.Canceled, err)

err = pool.Submit(func() {}).Wait()

// If the context is canceled, the pool should be stopped and the task should return the pool stopped error.
assert.Equal(t, ErrPoolStopped, err)
assert.True(t, pool.Stopped())

err = pool.Go(func() {})

assert.Equal(t, ErrPoolStopped, err)

pool.StopAndWait()

err = pool.Submit(func() {}).Wait()

assert.Equal(t, ErrPoolStopped, err)

err = pool.Go(func() {})

assert.Equal(t, ErrPoolStopped, err)
}
21 changes: 21 additions & 0 deletions subpool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pond

import (
"context"
"errors"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -170,3 +171,23 @@ func TestSubpoolMaxConcurrency(t *testing.T) {

assert.Equal(t, 10, subpool.MaxConcurrency())
}

func TestSubpoolStoppedAfterCancel(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

pool := NewPool(10, WithContext(ctx))
subpool := pool.NewSubpool(5)

cancel()

time.Sleep(1 * time.Millisecond)

err := pool.Submit(func() {}).Wait()

assert.Equal(t, ErrPoolStopped, err)

err = subpool.Submit(func() {}).Wait()

assert.Equal(t, ErrPoolStopped, err)
}

0 comments on commit af6a9cb

Please sign in to comment.