fix(pool): fix race condition with small pool sizes (#83)
* fix(dispatcher): some tasks are missed

* double test runs

* revert initial linkedbuffer size

* remove brittle assertion

* Revert dispatcher chan size

* comment buffer reset

* add lock while dispatching tasks to avoid deadlocks
alitto authored Nov 13, 2024
1 parent f1d2a44 commit 543ed3a
Showing 7 changed files with 99 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -23,7 +23,7 @@ jobs:
go-version: ${{ matrix.go-version }}

- name: Test
run: make test
run: make test-ci
codecov:
name: Coverage report
runs-on: ubuntu-latest
5 changes: 4 additions & 1 deletion Makefile
@@ -1,5 +1,8 @@
test:
go test -race -v -timeout 1m ./...
go test -race -v -timeout 15s -count=1 ./...

test-ci:
go test -race -v -timeout 1m -count=3 ./...

coverage:
go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./...
5 changes: 3 additions & 2 deletions group_test.go
@@ -36,8 +36,9 @@ func TestResultTaskGroupWait(t *testing.T) {

func TestResultTaskGroupWaitWithError(t *testing.T) {

group := NewResultPool[int](1).
NewGroup()
pool := NewResultPool[int](1)

group := pool.NewGroup()

sampleErr := errors.New("sample error")

4 changes: 0 additions & 4 deletions internal/dispatcher/dispatcher.go
@@ -30,7 +30,6 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
bufferHasElements: make(chan struct{}, 1),
dispatchFunc: dispatchFunc,
batchSize: batchSize,
closed: atomic.Bool{},
}

dispatcher.waitGroup.Add(1)
@@ -118,9 +117,6 @@ func (d *Dispatcher[T]) run(ctx context.Context) {

// Submit the next batch of values
d.dispatchFunc(batch[0:batchSize])

// Reset batch
batch = batch[:0]
}

if !ok || d.closed.Load() {
9 changes: 3 additions & 6 deletions internal/linkedbuffer/linkedbuffer.go
@@ -17,7 +17,7 @@ type LinkedBuffer[T any] struct {
maxCapacity int
writeCount atomic.Uint64
readCount atomic.Uint64
mutex sync.RWMutex
mutex sync.Mutex
}

func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] {
@@ -78,28 +78,25 @@ func (b *LinkedBuffer[T]) Write(values []T) {

// Read reads values from the buffer and returns the number of elements read
func (b *LinkedBuffer[T]) Read(values []T) int {
b.mutex.Lock()
defer b.mutex.Unlock()

var readBuffer *Buffer[T]

for {
b.mutex.RLock()
readBuffer = b.readBuffer
b.mutex.RUnlock()

// Read element
n, err := readBuffer.Read(values)

if err == ErrEOF {
// Move to next buffer
b.mutex.Lock()
if readBuffer.next == nil {
b.mutex.Unlock()
return n
}
if b.readBuffer != readBuffer.next {
b.readBuffer = readBuffer.next
}
b.mutex.Unlock()
continue
}
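The rewritten Read now takes the mutex once for the whole call (and the field drops from sync.RWMutex to a plain sync.Mutex), so reading from the current buffer and advancing to the next one form a single critical section. A rough standalone sketch of that shape, with invented names (Cursor, node) rather than the library's LinkedBuffer:

package cursor

import "sync"

type node[T any] struct {
	data []T
	next *node[T]
}

// Cursor is a toy linked read cursor, not the library's type: one mutex
// guards both the copy out of the current node and the advance to the next
// node, so two concurrent Read calls can never observe a half-advanced head.
type Cursor[T any] struct {
	mu   sync.Mutex
	head *node[T]
}

func NewCursor[T any](first []T) *Cursor[T] {
	return &Cursor[T]{head: &node[T]{data: first}}
}

func (c *Cursor[T]) Read(out []T) int {
	c.mu.Lock()
	defer c.mu.Unlock()
	for {
		n := copy(out, c.head.data)
		c.head.data = c.head.data[n:]
		if n > 0 || c.head.next == nil {
			return n
		}
		// Current node is exhausted: advance while still holding the lock.
		c.head = c.head.next
	}
}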

39 changes: 39 additions & 0 deletions internal/linkedbuffer/linkedbuffer_test.go
@@ -91,3 +91,42 @@ func TestLinkedBufferLen(t *testing.T) {
buf.readCount.Add(1)
assert.Equal(t, uint64(0), buf.Len())
}

func TestLinkedBufferWithReusedBuffer(t *testing.T) {

buf := NewLinkedBuffer[int](2, 1)

values := make([]int, 1)

buf.Write([]int{1})
buf.Write([]int{2})

n := buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, values[0])

assert.Equal(t, 1, len(values))
assert.Equal(t, 1, cap(values))

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 2, values[0])

buf.Write([]int{3})
buf.Write([]int{4})

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 3, values[0])

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 4, values[0])
}
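The new test above drives the reuse path from a single goroutine. Since the commit is about a race (and the Makefile keeps -race on, with -count=3 in CI), a concurrent variant along the following lines could exercise the same path from two goroutines. This is only a sketch with an invented test name; it assumes it sits in the same file, reuses the assert helper already imported there, and relies on the test timeout to surface lost values:

// Sketch: concurrent variant of the reuse test above (hypothetical test).
func TestLinkedBufferConcurrentReuse(t *testing.T) {

	buf := NewLinkedBuffer[int](2, 1)

	const total = 1000

	// Writer: feed values one at a time so the small internal buffers are
	// recycled repeatedly while the reader is active.
	go func() {
		for i := 0; i < total; i++ {
			buf.Write([]int{i})
		}
	}()

	// Reader: drain from this goroutine until every written value has been
	// observed; if the buffer ever dropped a value, this loop would hang and
	// the test timeout would fail the run.
	values := make([]int, 1)
	seen := 0
	for seen < total {
		seen += buf.Read(values)
	}

	assert.Equal(t, uint64(0), buf.Len())
}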
66 changes: 49 additions & 17 deletions pool.go
@@ -14,6 +14,8 @@ import (

var MAX_TASKS_CHAN_LENGTH = runtime.NumCPU() * 128

var PERSISTENT_WORKER_COUNT = int64(runtime.NumCPU())

var ErrPoolStopped = errors.New("pool stopped")

var poolStoppedFuture = func() Task {
@@ -91,6 +93,7 @@ type pool struct {
workerCount atomic.Int64
workerWaitGroup sync.WaitGroup
dispatcher *dispatcher.Dispatcher[any]
dispatcherRunning sync.Mutex
successfulTaskCount atomic.Uint64
failedTaskCount atomic.Uint64
}
@@ -196,15 +199,16 @@ func (p *pool) NewGroupContext(ctx context.Context) TaskGroup {
}

func (p *pool) dispatch(incomingTasks []any) {
p.dispatcherRunning.Lock()
defer p.dispatcherRunning.Unlock()

// Submit tasks
for _, task := range incomingTasks {
p.dispatchTask(task)
}
}

func (p *pool) dispatchTask(task any) {
workerCount := int(p.workerCount.Load())

// Attempt to submit task without blocking
select {
case p.tasks <- task:
@@ -214,19 +218,13 @@ func (p *pool) dispatchTask(task any) {
// 1. There are no idle workers (all spawned workers are processing a task)
// 2. There are no workers in the pool
// In either case, we should launch a new worker as long as the number of workers is less than the size of the task queue.
if workerCount < p.tasksLen {
// Launch a new worker
p.startWorker()
}
p.startWorker(p.tasksLen)
return
default:
}

// Task queue is full, launch a new worker if the number of workers is less than the maximum concurrency
if workerCount < p.maxConcurrency {
// Launch a new worker
p.startWorker()
}
p.startWorker(p.maxConcurrency)

// Block until task is submitted
select {
@@ -238,15 +236,41 @@
}
}

func (p *pool) startWorker() {
func (p *pool) startWorker(limit int) {
if p.workerCount.Load() >= int64(limit) {
return
}
p.workerWaitGroup.Add(1)
p.workerCount.Add(1)
go p.worker()
workerNumber := p.workerCount.Add(1)
// Guarantee at least PERSISTENT_WORKER_COUNT workers are always running during dispatch to prevent deadlocks
canExitDuringDispatch := workerNumber > PERSISTENT_WORKER_COUNT
go p.worker(canExitDuringDispatch)
}

func (p *pool) worker() {
defer func() {
func (p *pool) workerCanExit(canExitDuringDispatch bool) bool {
if canExitDuringDispatch {
p.workerCount.Add(-1)
return true
}

// Check if the dispatcher is running
if !p.dispatcherRunning.TryLock() {
// Dispatcher is running, cannot exit yet
return false
}
if len(p.tasks) > 0 {
// There are tasks in the queue, cannot exit yet
p.dispatcherRunning.Unlock()
return false
}
p.workerCount.Add(-1)
p.dispatcherRunning.Unlock()

return true
}

func (p *pool) worker(canExitDuringDispatch bool) {
defer func() {
p.workerWaitGroup.Done()
}()

@@ -255,17 +279,20 @@ func (p *pool) worker() {
select {
case <-p.ctx.Done():
// Context cancelled, exit
p.workerCount.Add(-1)
return
default:
}

select {
case <-p.ctx.Done():
// Context cancelled, exit
p.workerCount.Add(-1)
return
case task, ok := <-p.tasks:
if !ok || task == nil {
// Channel closed or worker killed, exit
p.workerCount.Add(-1)
return
}

@@ -276,8 +303,13 @@ func (p *pool) worker() {
p.updateMetrics(err)

default:
// No tasks left, exit
return
// No tasks left

// Check if the worker can exit
if p.workerCanExit(canExitDuringDispatch) {
return
}
continue
}
}
}
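Taken together, the new pieces in this file (the dispatcherRunning mutex, the limit argument to startWorker, PERSISTENT_WORKER_COUNT and workerCanExit) form a handshake between the dispatcher and idle workers: the dispatcher holds a lock for the whole dispatch, and a worker that finds its queue empty may only exit if it can grab that lock and the queue is still empty. A rough, self-contained sketch of the same handshake, with invented names and none of the library's actual types:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// miniPool is a toy stand-in for the real pool, only to show the handshake.
type miniPool struct {
	tasks       chan func()
	dispatching sync.Mutex // held for a whole dispatch, TryLock'ed by idle workers
	workerCount atomic.Int64
	maxWorkers  int64
	workers     sync.WaitGroup
}

func (p *miniPool) dispatch(batch []func()) {
	p.dispatching.Lock()
	defer p.dispatching.Unlock()

	for _, task := range batch {
		select {
		case p.tasks <- task:
			p.startWorker() // make sure at least one worker will drain the queue
		default:
			p.startWorker()
			p.tasks <- task // queue full: block until a worker frees a slot
		}
	}
}

func (p *miniPool) startWorker() {
	// Both the increment here and the decrement in the worker happen while
	// p.dispatching is held, so the count stays consistent.
	if p.workerCount.Load() >= p.maxWorkers {
		return
	}
	p.workerCount.Add(1)
	p.workers.Add(1)

	go func() {
		defer p.workers.Done()
		for {
			select {
			case task := <-p.tasks:
				task()
			default:
				// Queue looks empty: exit only if no dispatch is in flight and
				// the queue is still empty, otherwise keep polling so a task
				// enqueued against a stale worker count is never stranded.
				if p.dispatching.TryLock() {
					canExit := len(p.tasks) == 0
					if canExit {
						p.workerCount.Add(-1)
					}
					p.dispatching.Unlock()
					if canExit {
						return
					}
				}
			}
		}
	}()
}

func main() {
	p := &miniPool{tasks: make(chan func(), 1), maxWorkers: 1}

	var done sync.WaitGroup
	done.Add(3)
	p.dispatch([]func(){
		func() { defer done.Done(); fmt.Println("task 1") },
		func() { defer done.Done(); fmt.Println("task 2") },
		func() { defer done.Done(); fmt.Println("task 3") },
	})

	done.Wait()
	p.workers.Wait() // the worker exits once the queue drains and dispatch has returned
}

On this reading, the pre-fix race was a worker deciding to exit between the dispatcher's worker-count check and its channel send, leaving a just-enqueued task with nobody to run it, which is most visible with small pools; the TryLock handshake closes that window.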