diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go index 31c504f63..de785bfce 100644 --- a/internal/worker/resizable_permit.go +++ b/internal/worker/resizable_permit.go @@ -66,10 +66,10 @@ func (p *resizablePermit) Count() int { return p.sem.GetCount() } -// AcquireChan creates a PermitChannel that can be used to wait for a permit -// After usage: -// 1. avoid goroutine leak by calling permitChannel.Close() -// 2. release permit by calling permit.Release() +// AcquireChan returns a channel that could be used to wait for the permit and a close function when done +// Notes: +// 1. avoid goroutine leak by calling the done function +// 2. if the read succeeded, release permit by calling permit.Release() func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, func()) { ctx, cancel := context.WithCancel(ctx) pc := &permitChannel{ @@ -85,6 +85,7 @@ func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, fun } } +// permitChannel is an implementation to acquire a permit through channel type permitChannel struct { p Permit c chan struct{} diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go index 21903fb91..141510bde 100644 --- a/internal/worker/resizable_permit_test.go +++ b/internal/worker/resizable_permit_test.go @@ -60,7 +60,7 @@ func TestPermit_Simulation(t *testing.T) { maxTestDuration: 250 * time.Millisecond, // at least need 100ms * 1000 / 200 = 500ms to acquire all permit capacity: []int{200}, goroutines: 1000, - expectFailuresRange: []int{1, 999}, // should at least pass some acquires + expectFailuresRange: []int{400, 600}, // should at least pass some acquires }, { name: "not enough permit at beginning but due to capacity change, blocking but all acquire", @@ -74,7 +74,7 @@ func TestPermit_Simulation(t *testing.T) { maxTestDuration: 250 * time.Millisecond, capacity: []int{600, 400, 200}, goroutines: 1000, - expectFailuresRange: []int{1, 999}, + expectFailuresRange: []int{1, 500}, // the worst case with 200 capacity will at least pass 500 acquires }, { name: "not enough permit for any acquire, fail all", @@ -145,18 +145,139 @@ func TestPermit_Simulation(t *testing.T) { } } -func Test_Permit_AcquireChan(t *testing.T) { - permit := NewResizablePermit(2) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - permitChan, done := permit.AcquireChan(ctx) - select { - case <-permitChan: - assert.Equal(t, 2, permit.Quota()) +// Test_Permit_Acquire tests the basic acquire functionality +// before each acquire will wait 100ms +func Test_Permit_Acquire(t *testing.T) { + + t.Run("acquire 1 permit", func(t *testing.T) { + permit := NewResizablePermit(1) + err := permit.Acquire(context.Background()) + assert.NoError(t, err) assert.Equal(t, 1, permit.Count()) - case <-ctx.Done(): - t.Error("unexpected timeout") + }) + + t.Run("acquire timeout", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + time.Sleep(100 * time.Millisecond) + err := permit.Acquire(ctx) + assert.ErrorContains(t, err, "context deadline exceeded") + assert.Empty(t, permit.Count()) + }) + + t.Run("cancel acquire", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := permit.Acquire(ctx) + assert.ErrorContains(t, err, "canceled") + assert.Empty(t, permit.Count()) + }) + + t.Run("acquire more than quota", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + err := permit.Acquire(ctx) + assert.NoError(t, err) + err = permit.Acquire(ctx) + assert.ErrorContains(t, err, "failed to acquire permit") + assert.Equal(t, 1, permit.Count()) + }) +} + +func Test_Permit_Release(t *testing.T) { + for _, tt := range []struct { + name string + quota, acquire, release int + expectPanic bool + }{ + {"release all acquired permits", 10, 5, 5, false}, + {"release partial acquired permit", 10, 5, 1, false}, + {"release non acquired permit", 10, 5, 0, false}, + {"release more than acquired permit", 10, 5, 10, true}, + } { + t.Run(tt.name, func(t *testing.T) { + permit := NewResizablePermit(tt.quota) + for i := 0; i < tt.acquire; i++ { + err := permit.Acquire(context.Background()) + assert.NoError(t, err) + } + releaseOp := func() { + for i := 0; i < tt.release; i++ { + permit.Release() + } + } + + if tt.expectPanic { + assert.Panics(t, releaseOp) + } else { + assert.NotPanics(t, releaseOp) + assert.Equal(t, tt.acquire-tt.release, permit.Count()) + } + }) } - done() - permit.Release() +} + +func Test_Permit_AcquireChan(t *testing.T) { + t.Run("acquire 1 permit", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + assert.Equal(t, 1, permit.Count()) + case <-ctx.Done(): + t.Errorf("permit not acquired") + } + }) + + t.Run("acquire timeout", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + time.Sleep(100 * time.Millisecond) + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + t.Errorf("permit acquired") + case <-ctx.Done(): + assert.Empty(t, permit.Count()) + } + }) + + t.Run("cancel acquire", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + t.Errorf("permit acquired") + case <-ctx.Done(): + assert.Empty(t, permit.Count()) + } + }) + + t.Run("acquire more than quota", func(t *testing.T) { + permit := NewResizablePermit(4) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + for i := 0; i < 10; i++ { + channel, done := permit.AcquireChan(ctx) + select { + case <-channel: + case <-ctx.Done(): + } + done() + } + + assert.Equal(t, 4, permit.Count()) + }) }