diff --git a/Makefile b/Makefile index 2a9ba68..9c9875b 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ test: - go test -race -v ./ + go test -race -v -timeout 1m ./ coverage: - go test -race -v -coverprofile=coverage.out -covermode=atomic ./ \ No newline at end of file + go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./ \ No newline at end of file diff --git a/group.go b/group.go index 4677bb3..1d83bb6 100644 --- a/group.go +++ b/group.go @@ -2,7 +2,7 @@ package pond import ( "errors" - "sync" + "sync/atomic" "github.com/alitto/pond/v2/internal/future" ) @@ -20,9 +20,6 @@ type TaskGroup interface { // Waits for all tasks in the group to complete. Wait() error - - // Done returns a channel that is closed when all tasks in the group have completed. - Done() <-chan struct{} } // ResultTaskGroup represents a group of tasks that can be executed concurrently. @@ -39,9 +36,6 @@ type ResultTaskGroup[O any] interface { // Waits for all tasks in the group to complete. Wait() ([]O, error) - - // Done returns a channel that is closed when all tasks in the group have completed. - Done() <-chan struct{} } type result[O any] struct { @@ -51,22 +45,16 @@ type result[O any] struct { type abstractTaskGroup[T func() | func() O, E func() error | func() (O, error), O any] struct { pool *pool - mutex sync.Mutex - nextIndex int + nextIndex atomic.Int64 future *future.CompositeFuture[*result[O]] futureResolver future.CompositeFutureResolver[*result[O]] } func (g *abstractTaskGroup[T, E, O]) Submit(tasks ...T) *abstractTaskGroup[T, E, O] { - g.mutex.Lock() - defer g.mutex.Unlock() - if len(tasks) == 0 { panic(errors.New("no tasks provided")) } - g.future.Add(len(tasks)) - for _, task := range tasks { g.submit(task) } @@ -75,15 +63,10 @@ func (g *abstractTaskGroup[T, E, O]) Submit(tasks ...T) *abstractTaskGroup[T, E, } func (g *abstractTaskGroup[T, E, O]) SubmitErr(tasks ...E) *abstractTaskGroup[T, E, O] { - g.mutex.Lock() - defer g.mutex.Unlock() - if len(tasks) == 0 { panic(errors.New("no tasks provided")) } - g.future.Add(len(tasks)) - for _, task := range tasks { g.submit(task) } @@ -92,8 +75,7 @@ func (g *abstractTaskGroup[T, E, O]) SubmitErr(tasks ...E) *abstractTaskGroup[T, } func (g *abstractTaskGroup[T, E, O]) submit(task any) { - index := g.nextIndex - g.nextIndex++ + index := int(g.nextIndex.Add(1) - 1) err := g.pool.Go(func() { output, err := invokeTask[O](task) @@ -111,10 +93,6 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) { } } -func (g *abstractTaskGroup[T, E, O]) Done() <-chan struct{} { - return g.future.Done() -} - type taskGroup struct { abstractTaskGroup[func(), func() error, struct{}] } @@ -130,7 +108,7 @@ func (g *taskGroup) SubmitErr(tasks ...func() error) TaskGroup { } func (g *taskGroup) Wait() error { - _, err := g.future.Wait() + _, err := g.future.Wait(int(g.nextIndex.Load())) return err } @@ -149,7 +127,7 @@ func (g *resultTaskGroup[O]) SubmitErr(tasks ...func() (O, error)) ResultTaskGro } func (g *resultTaskGroup[O]) Wait() ([]O, error) { - results, err := g.future.Wait() + results, err := g.future.Wait(int(g.nextIndex.Load())) if err != nil { return []O{}, err @@ -165,7 +143,7 @@ func (g *resultTaskGroup[O]) Wait() ([]O, error) { } func newTaskGroup(pool *pool) TaskGroup { - future, futureResolver := future.NewCompositeFuture[*result[struct{}]](pool.Context(), 0) + future, futureResolver := future.NewCompositeFuture[*result[struct{}]](pool.Context()) return &taskGroup{ abstractTaskGroup: abstractTaskGroup[func(), func() error, struct{}]{ @@ -177,7 +155,7 @@ func newTaskGroup(pool *pool) TaskGroup { } func newResultTaskGroup[O any](pool *pool) ResultTaskGroup[O] { - future, futureResolver := future.NewCompositeFuture[*result[O]](pool.Context(), 0) + future, futureResolver := future.NewCompositeFuture[*result[O]](pool.Context()) return &resultTaskGroup[O]{ abstractTaskGroup: abstractTaskGroup[func() O, func() (O, error), O]{ diff --git a/group_test.go b/group_test.go index ee10a1e..780a554 100644 --- a/group_test.go +++ b/group_test.go @@ -58,6 +58,29 @@ func TestResultTaskGroupWaitWithError(t *testing.T) { assert.Equal(t, 0, len(outputs)) } +func TestResultTaskGroupWaitWithErrorInLastTask(t *testing.T) { + + group := NewResultPool[int](10). + NewGroup() + + sampleErr := errors.New("sample error") + + group.SubmitErr(func() (int, error) { + return 1, nil + }) + + time.Sleep(10 * time.Millisecond) + + group.SubmitErr(func() (int, error) { + return 0, sampleErr + }) + + outputs, err := group.Wait() + + assert.Equal(t, sampleErr, err) + assert.Equal(t, 0, len(outputs)) +} + func TestResultTaskGroupWaitWithMultipleErrors(t *testing.T) { pool := NewResultPool[int](10) @@ -93,33 +116,6 @@ func TestTaskGroupWithStoppedPool(t *testing.T) { assert.Equal(t, ErrPoolStopped, err) } -func TestTaskGroupDone(t *testing.T) { - - pool := NewResultPool[int](10) - - group := pool.NewGroup() - - for i := 0; i < 5; i++ { - i := i - group.SubmitErr(func() (int, error) { - time.Sleep(1 * time.Millisecond) - return i, nil - }) - } - - <-group.Done() - - outputs, err := group.Wait() - - assert.Equal(t, nil, err) - assert.Equal(t, 5, len(outputs)) - assert.Equal(t, 0, outputs[0]) - assert.Equal(t, 1, outputs[1]) - assert.Equal(t, 2, outputs[2]) - assert.Equal(t, 3, outputs[3]) - assert.Equal(t, 4, outputs[4]) -} - func TestTaskGroupWithNoTasks(t *testing.T) { group := NewResultPool[int](10). diff --git a/internal/future/composite.go b/internal/future/composite.go index 6b2d38b..d73647a 100644 --- a/internal/future/composite.go +++ b/internal/future/composite.go @@ -14,39 +14,83 @@ type compositeResolution[V any] struct { err error } +type waitListener struct { + count int + ch chan struct{} +} + type CompositeFuture[V any] struct { - *ValueFuture[[]V] - resolver ValueFutureResolver[[]V] + ctx context.Context + cancel context.CancelCauseFunc resolutions []compositeResolution[V] - count int mutex sync.Mutex + listeners []waitListener } -func NewCompositeFuture[V any](ctx context.Context, count int) (*CompositeFuture[V], CompositeFutureResolver[V]) { - childFuture, resolver := NewValueFuture[[]V](ctx) +func NewCompositeFuture[V any](ctx context.Context) (*CompositeFuture[V], CompositeFutureResolver[V]) { + childCtx, cancel := context.WithCancelCause(ctx) future := &CompositeFuture[V]{ - ValueFuture: childFuture, - resolver: resolver, + ctx: childCtx, + cancel: cancel, resolutions: make([]compositeResolution[V], 0), } - if count > 0 { - future.Add(count) - } - return future, future.resolve } -func (f *CompositeFuture[V]) Add(delta int) { - if delta <= 0 { - panic(fmt.Errorf("delta must be greater than 0")) +func (f *CompositeFuture[V]) Wait(count int) ([]V, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + err := context.Cause(f.ctx) + + if len(f.resolutions) >= count || err != nil { + if err != nil { + return []V{}, err + } + + // Get sorted results + result := make([]V, count) + for _, resolution := range f.resolutions { + if resolution.index < count { + result[resolution.index] = resolution.value + } + } + + return result, nil + } + + // Register a listener + ch := make(chan struct{}) + f.listeners = append(f.listeners, waitListener{ + count: count, + ch: ch, + }) + + f.mutex.Unlock() + + // Wait for the listener to be notified or the context to be canceled + select { + case <-ch: + case <-f.ctx.Done(): } f.mutex.Lock() - defer f.mutex.Unlock() - f.count += delta + if err := context.Cause(f.ctx); err != nil { + return []V{}, err + } + + // Get sorted results + result := make([]V, count) + for _, resolution := range f.resolutions { + if resolution.index < count { + result[resolution.index] = resolution.value + } + } + + return result, nil } func (f *CompositeFuture[V]) resolve(index int, value V, err error) { @@ -56,9 +100,6 @@ func (f *CompositeFuture[V]) resolve(index int, value V, err error) { if index < 0 { panic(fmt.Errorf("index must be greater than or equal to 0")) } - if index >= f.count { - panic(fmt.Errorf("index must be less than %d", f.count)) - } // Save the resolution f.resolutions = append(f.resolutions, compositeResolution[V]{ @@ -67,20 +108,19 @@ func (f *CompositeFuture[V]) resolve(index int, value V, err error) { err: err, }) - pending := f.count - len(f.resolutions) - - if pending == 0 || err != nil { - if err != nil { - f.resolver([]V{}, err) - } else { + // Cancel the context if an error occurred + if err != nil { + f.cancel(err) + } - // Sort the resolutions - values := make([]V, f.count) - for _, resolution := range f.resolutions { - values[resolution.index] = resolution.value - } + // Notify listeners + for i := 0; i < len(f.listeners); i++ { + listener := f.listeners[i] - f.resolver(values, nil) + if err != nil || listener.count <= len(f.resolutions) { + close(listener.ch) + f.listeners = append(f.listeners[:i], f.listeners[i+1:]...) + i-- } } } diff --git a/internal/future/composite_test.go b/internal/future/composite_test.go index 7983ed5..21b0c2f 100644 --- a/internal/future/composite_test.go +++ b/internal/future/composite_test.go @@ -4,18 +4,19 @@ import ( "context" "errors" "testing" + "time" "github.com/alitto/pond/v2/internal/assert" ) func TestCompositeFutureWait(t *testing.T) { - future, resolve := NewCompositeFuture[string](context.Background(), 3) + future, resolve := NewCompositeFuture[string](context.Background()) resolve(0, "output1", nil) resolve(1, "output2", nil) resolve(2, "output3", nil) - outputs, err := future.Wait() + outputs, err := future.Wait(3) assert.Equal(t, nil, err) assert.Equal(t, 3, len(outputs)) @@ -25,15 +26,14 @@ func TestCompositeFutureWait(t *testing.T) { } func TestCompositeFutureWaitWithError(t *testing.T) { - future, resolve := NewCompositeFuture[string](context.Background(), 1) - future.Add(2) + future, resolve := NewCompositeFuture[string](context.Background()) sampleErr := errors.New("sample error") resolve(0, "output1", nil) resolve(1, "output2", nil) resolve(2, "output3", sampleErr) - outputs, err := future.Wait() + outputs, err := future.Wait(3) assert.Equal(t, sampleErr, err) assert.Equal(t, 0, len(outputs)) @@ -41,48 +41,40 @@ func TestCompositeFutureWaitWithError(t *testing.T) { func TestCompositeFutureWaitWithCanceledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - future, resolve := NewCompositeFuture[string](ctx, 2) + future, resolve := NewCompositeFuture[string](ctx) cancel() resolve(0, "output1", nil) - _, err := future.Wait() + _, err := future.Wait(2) assert.Equal(t, context.Canceled, err) } func TestCompositeFutureResolveWithIndexOutOfRange(t *testing.T) { - _, resolve := NewCompositeFuture[string](context.Background(), 2) - - assert.PanicsWithError(t, "index must be less than 2", func() { - resolve(2, "output1", nil) - }) + _, resolve := NewCompositeFuture[string](context.Background()) assert.PanicsWithError(t, "index must be greater than or equal to 0", func() { resolve(-1, "output1", nil) }) } -func TestCompositeFutureAddWithInvalidCount(t *testing.T) { - future, _ := NewCompositeFuture[string](context.Background(), 2) - - assert.PanicsWithError(t, "delta must be greater than 0", func() { - future.Add(0) - }) -} - -func TestCompositeFutureAdd(t *testing.T) { - future, resolve := NewCompositeFuture[string](context.Background(), 2) +func TestCompositeFutureWithMultipleWait(t *testing.T) { + future, resolve := NewCompositeFuture[string](context.Background()) resolve(0, "output1", nil) - future.Add(1) + outputs1, err := future.Wait(1) + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(outputs1)) + assert.Equal(t, "output1", outputs1[0]) resolve(1, "output2", nil) resolve(2, "output3", nil) - outputs, err := future.Wait() + outputs, err := future.Wait(3) assert.Equal(t, nil, err) assert.Equal(t, 3, len(outputs)) @@ -90,3 +82,53 @@ func TestCompositeFutureAdd(t *testing.T) { assert.Equal(t, "output2", outputs[1]) assert.Equal(t, "output3", outputs[2]) } + +func TestCompositeFutureWithErrorsAndMultipleWait(t *testing.T) { + future, resolve := NewCompositeFuture[string](context.Background()) + + sampleErr := errors.New("sample error") + resolve(0, "output1", sampleErr) + + outputs1, err := future.Wait(1) + + assert.Equal(t, sampleErr, err) + assert.Equal(t, 0, len(outputs1)) + + resolve(1, "output2", nil) + resolve(2, "output3", nil) + + outputs, err := future.Wait(3) + + assert.Equal(t, sampleErr, err) + assert.Equal(t, 0, len(outputs)) +} + +func TestCompositeFutureWaitBeforeResoluion(t *testing.T) { + future, resolve := NewCompositeFuture[string](context.Background()) + + go func() { + time.Sleep(10 * time.Millisecond) + resolve(0, "output1", nil) + }() + + outputs, err := future.Wait(1) + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(outputs)) + assert.Equal(t, "output1", outputs[0]) +} + +func TestCompositeFutureWaitBeforeContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + future, _ := NewCompositeFuture[string](ctx) + + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + outputs, err := future.Wait(1) + + assert.Equal(t, context.Canceled, err) + assert.Equal(t, 0, len(outputs)) +} diff --git a/pool.go b/pool.go index 40ba2c9..ca170c1 100644 --- a/pool.go +++ b/pool.go @@ -52,6 +52,9 @@ type basePool interface { type Pool interface { basePool + // Submits a task to the pool without waiting for it to complete. + Go(task func()) error + // Submits a task to the pool and returns a future that can be used to wait for the task to complete. Submit(task func()) Task @@ -179,20 +182,24 @@ func (p *pool) dispatch(incomingTasks []any) { func (p *pool) dispatchTask(task any) { workerCount := int(p.workerCount.Load()) - if workerCount < p.tasksLen { - // If there are less workers than the size of the channel, start workers - p.startWorker() - } - // Attempt to submit task without blocking select { case p.tasks <- task: // Task submitted + + // If we could submit the task without blocking, it means one of two things: + // 1. There are idle workers (all workers are busy) + // 2. There are no workers + // In either case, we should launch a new worker if the number of workers is less than the maximum concurrency + if workerCount < p.maxConcurrency { + // Launch a new worker + p.startWorker() + } return default: } - // There are no idle workers, create more + // Task queue is full, launch a new worker if the number of workers is less than the maximum concurrency if workerCount < p.maxConcurrency { // Launch a new worker p.startWorker() diff --git a/task.go b/task.go index c0b43fd..6814073 100644 --- a/task.go +++ b/task.go @@ -43,6 +43,8 @@ func (t wrappedTask[R, C]) Run() error { c(err) case func(R, error): c(result, err) + default: + panic(fmt.Sprintf("unsupported callback type: %#v", t.callback)) } return err