Skip to content

Commit

Permalink
feat(taskgroup): wait for ongoing tasks complete when group is stopped or context is cancelled (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto authored Nov 10, 2024
1 parent 11ecf70 commit f1d2a44
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 40 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Run the tests of the root package and every sub-package with the race
# detector enabled.
test:
	go test -race -v -timeout 1m ./...

# Run the tests of the root package and every sub-package with the race
# detector enabled, writing an atomic-mode coverage profile to coverage.out.
coverage:
	go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./...
31 changes: 26 additions & 5 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pond
import (
"context"
"errors"
"sync"
"sync/atomic"

"github.com/alitto/pond/v2/internal/future"
Expand All @@ -22,6 +23,10 @@ type TaskGroup interface {
SubmitErr(tasks ...func() error) TaskGroup

// Waits for all tasks in the group to complete.
// If any of the tasks return an error, the group will return the first error encountered.
// If the context is cancelled, the group will return the context error.
// If the group is stopped, the group will return ErrGroupStopped.
// If a task is running when the context is cancelled or the group is stopped, the task will be allowed to complete before returning.
Wait() error

// Returns a channel that is closed when all tasks in the group have completed, a task returns an error, or the group is stopped.
Expand All @@ -43,7 +48,11 @@ type ResultTaskGroup[O any] interface {
// Submits a task to the group that can return an error.
SubmitErr(tasks ...func() (O, error)) ResultTaskGroup[O]

// Waits for all tasks in the group to complete.
// Waits for all tasks in the group to complete and returns the results of each task in the order they were submitted.
// If any of the tasks return an error, the group will return the first error encountered.
// If the context is cancelled, the group will return the context error.
// If the group is stopped, the group will return ErrGroupStopped.
// If a task is running when the context is cancelled or the group is stopped, the task will be allowed to complete before returning.
Wait() ([]O, error)

// Returns a channel that is closed when all tasks in the group have completed, a task returns an error, or the group is stopped.
Expand All @@ -61,6 +70,7 @@ type result[O any] struct {
// abstractTaskGroup holds the state behind the task group types in this file:
// the pool that runs the tasks, submission bookkeeping, and the composite
// future that collects one result per submitted task.
type abstractTaskGroup[T func() | func() O, E func() error | func() (O, error), O any] struct {
// pool receives every submitted task (submit hands tasks over via pool.Go).
pool *pool
// nextIndex counts the tasks submitted so far; submit derives each task's index from it
// and the Wait methods pass its current value to the future as the number of results to wait for.
nextIndex atomic.Int64
// taskWaitGroup is incremented once per submitted task and decremented when the pool has
// finished running that task (or refused it), so waiting on it blocks until no task is in flight.
taskWaitGroup sync.WaitGroup
// future collects per-task results and exposes the group's context.
future *future.CompositeFuture[*result[O]]
// futureResolver reports the outcome of the task at a given index to future.
futureResolver future.CompositeFutureResolver[*result[O]]
}
Expand Down Expand Up @@ -92,7 +102,11 @@ func (g *abstractTaskGroup[T, E, O]) SubmitErr(tasks ...E) *abstractTaskGroup[T,
func (g *abstractTaskGroup[T, E, O]) submit(task any) {
index := int(g.nextIndex.Add(1) - 1)

g.taskWaitGroup.Add(1)

err := g.pool.Go(func() {
defer g.taskWaitGroup.Done()

// Check if the context has been cancelled to prevent running tasks that are not needed
if err := g.future.Context().Err(); err != nil {
g.futureResolver(index, &result[O]{
Expand All @@ -111,6 +125,8 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {
})

if err != nil {
g.taskWaitGroup.Done()

g.futureResolver(index, &result[O]{
Err: err,
}, err)
Expand All @@ -133,6 +149,9 @@ func (g *taskGroup) SubmitErr(tasks ...func() error) TaskGroup {

// Wait blocks until the group's future has settled and every task handed to
// the pool has finished, then returns the error the future settled with (nil
// when all tasks succeeded).
func (g *taskGroup) Wait() error {
	submitted := int(g.nextIndex.Load())

	// Settle the future first. The wait group can reach zero before the future
	// is resolved when Wait is called while tasks are still being submitted,
	// so on its own it is not a reliable completion signal.
	_, waitErr := g.future.Wait(submitted)

	// Only then wait for tasks that are still in flight.
	g.taskWaitGroup.Wait()

	return waitErr
}

Expand All @@ -153,14 +172,16 @@ func (g *resultTaskGroup[O]) SubmitErr(tasks ...func() (O, error)) ResultTaskGro
func (g *resultTaskGroup[O]) Wait() ([]O, error) {
results, err := g.future.Wait(int(g.nextIndex.Load()))

if err != nil {
return []O{}, err
}
// This wait group could reach zero before the future is resolved if called in between tasks being submitted and the future being resolved.
// That's why we wait for the future to be resolved before waiting for the wait group.
g.taskWaitGroup.Wait()

values := make([]O, len(results))

for i, result := range results {
values[i] = result.Output
if result != nil {
values[i] = result.Output
}
}

return values, err
Expand Down
73 changes: 69 additions & 4 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestResultTaskGroupWait(t *testing.T) {

func TestResultTaskGroupWaitWithError(t *testing.T) {

group := NewResultPool[int](10).
group := NewResultPool[int](1).
NewGroup()

sampleErr := errors.New("sample error")
Expand All @@ -57,7 +57,12 @@ func TestResultTaskGroupWaitWithError(t *testing.T) {
outputs, err := group.Wait()

assert.Equal(t, sampleErr, err)
assert.Equal(t, 0, len(outputs))
assert.Equal(t, 5, len(outputs))
assert.Equal(t, 0, outputs[0])
assert.Equal(t, 1, outputs[1])
assert.Equal(t, 2, outputs[2])
assert.Equal(t, 0, outputs[3]) // This task returned an error
assert.Equal(t, 0, outputs[4]) // This task was not executed
}

func TestResultTaskGroupWaitWithErrorInLastTask(t *testing.T) {
Expand All @@ -80,7 +85,9 @@ func TestResultTaskGroupWaitWithErrorInLastTask(t *testing.T) {
outputs, err := group.Wait()

assert.Equal(t, sampleErr, err)
assert.Equal(t, 0, len(outputs))
assert.Equal(t, 2, len(outputs))
assert.Equal(t, 1, outputs[0])
assert.Equal(t, 0, outputs[1])
}

func TestResultTaskGroupWaitWithMultipleErrors(t *testing.T) {
Expand All @@ -95,6 +102,7 @@ func TestResultTaskGroupWaitWithMultipleErrors(t *testing.T) {
i := i
group.SubmitErr(func() (int, error) {
if i%2 == 0 {
time.Sleep(10 * time.Millisecond)
return 0, sampleErr
}
return i, nil
Expand All @@ -103,8 +111,65 @@ func TestResultTaskGroupWaitWithMultipleErrors(t *testing.T) {

outputs, err := group.Wait()

assert.Equal(t, 0, len(outputs))
assert.Equal(t, sampleErr, err)
assert.Equal(t, 5, len(outputs))
assert.Equal(t, 0, outputs[0])
assert.Equal(t, 1, outputs[1])
assert.Equal(t, 0, outputs[2])
assert.Equal(t, 3, outputs[3])
assert.Equal(t, 0, outputs[4])
}

// TestResultTaskGroupWaitWithContextCanceledAndOngoingTasks verifies that, when
// the group's context is cancelled while a task is running, Wait returns
// context.Canceled together with a results slice that still has one slot per
// submitted task.
func TestResultTaskGroupWaitWithContextCanceledAndOngoingTasks(t *testing.T) {
// NOTE(review): the argument 1 is presumably the pool's maximum concurrency,
// which would force the two tasks to run one after the other — confirm.
pool := NewResultPool[string](1)

ctx, cancel := context.WithCancel(context.Background())

group := pool.NewGroupContext(ctx)

// First task: cancels the group's context from inside the task, then keeps
// running for a while so that it is still in flight after the cancellation.
group.Submit(func() string {
cancel() // cancel the context after the first task is started
time.Sleep(10 * time.Millisecond)
return "output1"
})

// Second task: submitted to the same group; it is still pending or starting
// when the first task cancels the context.
group.Submit(func() string {
time.Sleep(10 * time.Millisecond)
return "output2"
})

results, err := group.Wait()

// The context error is reported, the slice is sized to the number of
// submitted tasks, and both slots hold the zero value ("") rather than
// "output1"/"output2".
assert.Equal(t, context.Canceled, err)
assert.Equal(t, int(2), len(results))
assert.Equal(t, "", results[0])
assert.Equal(t, "", results[1])
}

// TestTaskGroupWaitWithContextCanceledAndOngoingTasks verifies that, when the
// group's context is cancelled while a task is running, Wait returns
// context.Canceled only after that running task has completed, and that the
// remaining task does not run.
func TestTaskGroupWaitWithContextCanceledAndOngoingTasks(t *testing.T) {
// NOTE(review): the argument 1 is presumably the pool's maximum concurrency,
// which would force the two tasks to run one after the other — confirm.
pool := NewPool(1)

// Number of task bodies that ran to completion.
var executedCount atomic.Int32

ctx, cancel := context.WithCancel(context.Background())

group := pool.NewGroupContext(ctx)

// First task: cancels the group's context, then keeps running before it
// increments the counter. The counter assertion below only holds if Wait
// lets this task finish.
group.Submit(func() {
cancel() // cancel the context after the first task is started
time.Sleep(10 * time.Millisecond)
executedCount.Add(1)
})

// Second task: expected not to execute once the context is cancelled.
group.Submit(func() {
time.Sleep(10 * time.Millisecond)
executedCount.Add(1)
})

err := group.Wait()

// Exactly one increment: the first task completed and the second never ran.
assert.Equal(t, context.Canceled, err)
assert.Equal(t, int32(1), executedCount.Load())
}

func TestTaskGroupWithStoppedPool(t *testing.T) {
Expand Down
14 changes: 8 additions & 6 deletions internal/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestDispatcherWithContextCanceledAfterWrite(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

receivedCount := atomic.Uint64{}
var receivedCount atomic.Uint64
receiveFunc := func(elems []int) {
for range elems {
receivedCount.Add(1)
Expand All @@ -110,16 +110,18 @@ func TestDispatcherWithContextCanceledAfterWrite(t *testing.T) {
assert.Equal(t, uint64(0), dispatcher.WriteCount())
assert.Equal(t, uint64(0), dispatcher.ReadCount())

// Cancel the context
dispatcher.Write(1)
time.Sleep(5 * time.Millisecond) // Wait for the dispatcher to process the element
time.Sleep(5 * time.Millisecond)

// Cancel the context
cancel()

dispatcher.Write(1)
time.Sleep(5 * time.Millisecond) // Wait for the dispatcher to process the element
time.Sleep(5 * time.Millisecond)

// Assert counters
assert.Equal(t, uint64(1), dispatcher.Len())
assert.Equal(t, uint64(2), dispatcher.WriteCount())
assert.Equal(t, uint64(0), dispatcher.Len())
assert.Equal(t, uint64(1), dispatcher.WriteCount())
assert.Equal(t, uint64(1), dispatcher.ReadCount())
}

Expand Down
57 changes: 41 additions & 16 deletions internal/future/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ type compositeResolution[V any] struct {
value V
}

// compositeErrorResolution pairs an error with the index of the resolution
// that reported it. It satisfies the error interface so it can be used as a
// context cancellation cause.
type compositeErrorResolution struct {
	index int   // index the failing resolution was reported for
	err   error // error carried by that resolution
}

// Error implements the error interface by returning the message of the
// wrapped error.
func (r *compositeErrorResolution) Error() string {
	wrapped := r.err
	return wrapped.Error()
}

type waitListener struct {
count int
ch chan struct{}
Expand Down Expand Up @@ -66,8 +75,14 @@ func (f *CompositeFuture[V]) Context() context.Context {
}

// Cancel cancels the future's context with the given cause and notifies the
// registered listeners so that pending waiters can observe the cancellation.
//
// The cause is handed to the context as-is. It is deliberately not routed
// through resolve, which wraps errors in a compositeErrorResolution to mark
// them as task failures.
func (f *CompositeFuture[V]) Cancel(cause error) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	// Cancel the context
	f.cancel(cause)

	// Notify listeners
	f.notifyListeners()
}

func (f *CompositeFuture[V]) Wait(count int) ([]V, error) {
Expand Down Expand Up @@ -109,9 +124,12 @@ func (f *CompositeFuture[V]) resolve(index int, value V, err error) {

// Cancel the context if an error occurred
if err != nil {
f.cancel(err)
f.cancel(&compositeErrorResolution{
index: index,
err: err,
})
} else if context.Cause(f.ctx) == nil {
// Save the resolution as long as the context is not canceled
// Save the resolution
f.resolutions = append(f.resolutions, compositeResolution[V]{
index: index,
value: value,
Expand All @@ -122,28 +140,35 @@ func (f *CompositeFuture[V]) resolve(index int, value V, err error) {
f.notifyListeners()
}

func (f *CompositeFuture[V]) getResult(count int) ([]V, error) {
// If we have enough results, return them
if len(f.resolutions) >= count {
func (f *CompositeFuture[V]) getResult(count int) (values []V, err error) {

cause := context.Cause(f.ctx)

// Get sorted results
result := make([]V, count)
// If we have enough results, return them
if cause != nil || len(f.resolutions) >= count {
// Get sorted resolution values
values = make([]V, count)
for _, resolution := range f.resolutions {
if resolution.index < count {
result[resolution.index] = resolution.value
values[resolution.index] = resolution.value
}
}

return result, nil
}

err := context.Cause(f.ctx)
if cause == nil {
return
}

if err != nil {
return []V{}, err
if errorResolution, ok := cause.(*compositeErrorResolution); ok {
// Unwrap the error resolution
err = errorResolution.err
} else if len(f.resolutions) < count {
// If the context is canceled and we have collected enough results, return nil error
// because we assume that context cancellation happened after the last resolution.
err = cause
}

return nil, nil
return
}

func (f *CompositeFuture[V]) notifyListeners() {
Expand Down
Loading

0 comments on commit f1d2a44

Please sign in to comment.