Skip to content

Commit

Permalink
fix(pool): launch workers after queueing tasks (#74)
Browse files Browse the repository at this point in the history
* fix(pool): Stop hangs

* Reduce workflow timeout to 1m
  • Loading branch information
alitto authored Oct 20, 2024
1 parent bea3bad commit d0bc8c6
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 119 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test:
go test -race -v ./
go test -race -v -timeout 1m ./

coverage:
go test -race -v -coverprofile=coverage.out -covermode=atomic ./
go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./
36 changes: 7 additions & 29 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package pond

import (
"errors"
"sync"
"sync/atomic"

"github.com/alitto/pond/v2/internal/future"
)
Expand All @@ -20,9 +20,6 @@ type TaskGroup interface {

// Waits for all tasks in the group to complete.
Wait() error

// Done returns a channel that is closed when all tasks in the group have completed.
Done() <-chan struct{}
}

// ResultTaskGroup represents a group of tasks that can be executed concurrently.
Expand All @@ -39,9 +36,6 @@ type ResultTaskGroup[O any] interface {

// Waits for all tasks in the group to complete.
Wait() ([]O, error)

// Done returns a channel that is closed when all tasks in the group have completed.
Done() <-chan struct{}
}

type result[O any] struct {
Expand All @@ -51,22 +45,16 @@ type result[O any] struct {

type abstractTaskGroup[T func() | func() O, E func() error | func() (O, error), O any] struct {
pool *pool
mutex sync.Mutex
nextIndex int
nextIndex atomic.Int64
future *future.CompositeFuture[*result[O]]
futureResolver future.CompositeFutureResolver[*result[O]]
}

func (g *abstractTaskGroup[T, E, O]) Submit(tasks ...T) *abstractTaskGroup[T, E, O] {
g.mutex.Lock()
defer g.mutex.Unlock()

if len(tasks) == 0 {
panic(errors.New("no tasks provided"))
}

g.future.Add(len(tasks))

for _, task := range tasks {
g.submit(task)
}
Expand All @@ -75,15 +63,10 @@ func (g *abstractTaskGroup[T, E, O]) Submit(tasks ...T) *abstractTaskGroup[T, E,
}

func (g *abstractTaskGroup[T, E, O]) SubmitErr(tasks ...E) *abstractTaskGroup[T, E, O] {
g.mutex.Lock()
defer g.mutex.Unlock()

if len(tasks) == 0 {
panic(errors.New("no tasks provided"))
}

g.future.Add(len(tasks))

for _, task := range tasks {
g.submit(task)
}
Expand All @@ -92,8 +75,7 @@ func (g *abstractTaskGroup[T, E, O]) SubmitErr(tasks ...E) *abstractTaskGroup[T,
}

func (g *abstractTaskGroup[T, E, O]) submit(task any) {
index := g.nextIndex
g.nextIndex++
index := int(g.nextIndex.Add(1) - 1)

err := g.pool.Go(func() {
output, err := invokeTask[O](task)
Expand All @@ -111,10 +93,6 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {
}
}

func (g *abstractTaskGroup[T, E, O]) Done() <-chan struct{} {
return g.future.Done()
}

type taskGroup struct {
abstractTaskGroup[func(), func() error, struct{}]
}
Expand All @@ -130,7 +108,7 @@ func (g *taskGroup) SubmitErr(tasks ...func() error) TaskGroup {
}

func (g *taskGroup) Wait() error {
_, err := g.future.Wait()
_, err := g.future.Wait(int(g.nextIndex.Load()))
return err
}

Expand All @@ -149,7 +127,7 @@ func (g *resultTaskGroup[O]) SubmitErr(tasks ...func() (O, error)) ResultTaskGro
}

func (g *resultTaskGroup[O]) Wait() ([]O, error) {
results, err := g.future.Wait()
results, err := g.future.Wait(int(g.nextIndex.Load()))

if err != nil {
return []O{}, err
Expand All @@ -165,7 +143,7 @@ func (g *resultTaskGroup[O]) Wait() ([]O, error) {
}

func newTaskGroup(pool *pool) TaskGroup {
future, futureResolver := future.NewCompositeFuture[*result[struct{}]](pool.Context(), 0)
future, futureResolver := future.NewCompositeFuture[*result[struct{}]](pool.Context())

return &taskGroup{
abstractTaskGroup: abstractTaskGroup[func(), func() error, struct{}]{
Expand All @@ -177,7 +155,7 @@ func newTaskGroup(pool *pool) TaskGroup {
}

func newResultTaskGroup[O any](pool *pool) ResultTaskGroup[O] {
future, futureResolver := future.NewCompositeFuture[*result[O]](pool.Context(), 0)
future, futureResolver := future.NewCompositeFuture[*result[O]](pool.Context())

return &resultTaskGroup[O]{
abstractTaskGroup: abstractTaskGroup[func() O, func() (O, error), O]{
Expand Down
50 changes: 23 additions & 27 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ func TestResultTaskGroupWaitWithError(t *testing.T) {
assert.Equal(t, 0, len(outputs))
}

func TestResultTaskGroupWaitWithErrorInLastTask(t *testing.T) {

group := NewResultPool[int](10).
NewGroup()

sampleErr := errors.New("sample error")

group.SubmitErr(func() (int, error) {
return 1, nil
})

time.Sleep(10 * time.Millisecond)

group.SubmitErr(func() (int, error) {
return 0, sampleErr
})

outputs, err := group.Wait()

assert.Equal(t, sampleErr, err)
assert.Equal(t, 0, len(outputs))
}

func TestResultTaskGroupWaitWithMultipleErrors(t *testing.T) {

pool := NewResultPool[int](10)
Expand Down Expand Up @@ -93,33 +116,6 @@ func TestTaskGroupWithStoppedPool(t *testing.T) {
assert.Equal(t, ErrPoolStopped, err)
}

func TestTaskGroupDone(t *testing.T) {

pool := NewResultPool[int](10)

group := pool.NewGroup()

for i := 0; i < 5; i++ {
i := i
group.SubmitErr(func() (int, error) {
time.Sleep(1 * time.Millisecond)
return i, nil
})
}

<-group.Done()

outputs, err := group.Wait()

assert.Equal(t, nil, err)
assert.Equal(t, 5, len(outputs))
assert.Equal(t, 0, outputs[0])
assert.Equal(t, 1, outputs[1])
assert.Equal(t, 2, outputs[2])
assert.Equal(t, 3, outputs[3])
assert.Equal(t, 4, outputs[4])
}

func TestTaskGroupWithNoTasks(t *testing.T) {

group := NewResultPool[int](10).
Expand Down
102 changes: 71 additions & 31 deletions internal/future/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,83 @@ type compositeResolution[V any] struct {
err error
}

type waitListener struct {
count int
ch chan struct{}
}

type CompositeFuture[V any] struct {
*ValueFuture[[]V]
resolver ValueFutureResolver[[]V]
ctx context.Context
cancel context.CancelCauseFunc
resolutions []compositeResolution[V]
count int
mutex sync.Mutex
listeners []waitListener
}

func NewCompositeFuture[V any](ctx context.Context, count int) (*CompositeFuture[V], CompositeFutureResolver[V]) {
childFuture, resolver := NewValueFuture[[]V](ctx)
func NewCompositeFuture[V any](ctx context.Context) (*CompositeFuture[V], CompositeFutureResolver[V]) {
childCtx, cancel := context.WithCancelCause(ctx)

future := &CompositeFuture[V]{
ValueFuture: childFuture,
resolver: resolver,
ctx: childCtx,
cancel: cancel,
resolutions: make([]compositeResolution[V], 0),
}

if count > 0 {
future.Add(count)
}

return future, future.resolve
}

func (f *CompositeFuture[V]) Add(delta int) {
if delta <= 0 {
panic(fmt.Errorf("delta must be greater than 0"))
func (f *CompositeFuture[V]) Wait(count int) ([]V, error) {
f.mutex.Lock()
defer f.mutex.Unlock()

err := context.Cause(f.ctx)

if len(f.resolutions) >= count || err != nil {
if err != nil {
return []V{}, err
}

// Get sorted results
result := make([]V, count)
for _, resolution := range f.resolutions {
if resolution.index < count {
result[resolution.index] = resolution.value
}
}

return result, nil
}

// Register a listener
ch := make(chan struct{})
f.listeners = append(f.listeners, waitListener{
count: count,
ch: ch,
})

f.mutex.Unlock()

// Wait for the listener to be notified or the context to be canceled
select {
case <-ch:
case <-f.ctx.Done():
}

f.mutex.Lock()
defer f.mutex.Unlock()

f.count += delta
if err := context.Cause(f.ctx); err != nil {
return []V{}, err
}

// Get sorted results
result := make([]V, count)
for _, resolution := range f.resolutions {
if resolution.index < count {
result[resolution.index] = resolution.value
}
}

return result, nil
}

func (f *CompositeFuture[V]) resolve(index int, value V, err error) {
Expand All @@ -56,9 +100,6 @@ func (f *CompositeFuture[V]) resolve(index int, value V, err error) {
if index < 0 {
panic(fmt.Errorf("index must be greater than or equal to 0"))
}
if index >= f.count {
panic(fmt.Errorf("index must be less than %d", f.count))
}

// Save the resolution
f.resolutions = append(f.resolutions, compositeResolution[V]{
Expand All @@ -67,20 +108,19 @@ func (f *CompositeFuture[V]) resolve(index int, value V, err error) {
err: err,
})

pending := f.count - len(f.resolutions)

if pending == 0 || err != nil {
if err != nil {
f.resolver([]V{}, err)
} else {
// Cancel the context if an error occurred
if err != nil {
f.cancel(err)
}

// Sort the resolutions
values := make([]V, f.count)
for _, resolution := range f.resolutions {
values[resolution.index] = resolution.value
}
// Notify listeners
for i := 0; i < len(f.listeners); i++ {
listener := f.listeners[i]

f.resolver(values, nil)
if err != nil || listener.count <= len(f.resolutions) {
close(listener.ch)
f.listeners = append(f.listeners[:i], f.listeners[i+1:]...)
i--
}
}
}
Loading

0 comments on commit d0bc8c6

Please sign in to comment.