diff --git a/internal/common/autoscaler/autoscaler.go b/internal/common/autoscaler/autoscaler.go
index ecac9641c..c080917ce 100644
--- a/internal/common/autoscaler/autoscaler.go
+++ b/internal/common/autoscaler/autoscaler.go
@@ -24,12 +24,6 @@ package autoscaler
 type (
 	AutoScaler interface {
 		Estimator
-		// Acquire X ResourceUnit of resource
-		Acquire(ResourceUnit) error
-		// Release X ResourceUnit of resource
-		Release(ResourceUnit)
-		// GetCurrent ResourceUnit of resource
-		GetCurrent() ResourceUnit
 		// Start starts the autoscaler go routine that scales the ResourceUnit according to Estimator
 		Start()
 		// Stop stops the autoscaler if started or do nothing if not yet started
diff --git a/internal/internal_poller_autoscaler.go b/internal/internal_poller_autoscaler.go
index d88dd1f10..2dc81e7ba 100644
--- a/internal/internal_poller_autoscaler.go
+++ b/internal/internal_poller_autoscaler.go
@@ -26,11 +26,11 @@ import (
 	"sync"
 	"time"
 
-	"github.com/marusama/semaphore/v2"
 	"go.uber.org/atomic"
 	"go.uber.org/zap"
 
 	"go.uber.org/cadence/internal/common/autoscaler"
+	"go.uber.org/cadence/internal/worker"
 )
 
 // defaultPollerScalerCooldownInSeconds
@@ -53,7 +53,7 @@ type (
 		isDryRun     bool
 		cooldownTime time.Duration
 		logger       *zap.Logger
-		sem          semaphore.Semaphore // resizable semaphore to control number of concurrent pollers
+		permit       worker.Permit
 		ctx          context.Context
 		cancel       context.CancelFunc
 		wg           *sync.WaitGroup // graceful stop
@@ -82,6 +82,7 @@ type (
 func newPollerScaler(
 	options pollerAutoScalerOptions,
 	logger *zap.Logger,
+	permit worker.Permit,
 	hooks ...func()) *pollerAutoScaler {
 	if !options.Enabled {
 		return nil
@@ -91,7 +92,7 @@ func newPollerScaler(
 		isDryRun:     options.DryRun,
 		cooldownTime: options.Cooldown,
 		logger:       logger,
-		sem:          semaphore.New(options.InitCount),
+		permit:       permit,
 		wg:           &sync.WaitGroup{},
 		ctx:          ctx,
 		cancel:       cancel,
@@ -107,21 +108,6 @@ func newPollerScaler(
 	}
 }
 
-// Acquire concurrent poll quota
-func (p *pollerAutoScaler) Acquire(resource autoscaler.ResourceUnit) error {
-	return p.sem.Acquire(p.ctx, int(resource))
-}
-
-// Release concurrent poll quota
-func (p *pollerAutoScaler) Release(resource autoscaler.ResourceUnit) {
-	p.sem.Release(int(resource))
-}
-
-// GetCurrent poll quota
-func (p *pollerAutoScaler) GetCurrent() autoscaler.ResourceUnit {
-	return autoscaler.ResourceUnit(p.sem.GetLimit())
-}
-
 // Start an auto-scaler go routine and returns a done to stop it
 func (p *pollerAutoScaler) Start() {
 	logger := p.logger.Sugar()
@@ -133,7 +119,7 @@ func (p *pollerAutoScaler) Start() {
 			case <-p.ctx.Done():
 				return
 			case <-time.After(p.cooldownTime):
-				currentResource := autoscaler.ResourceUnit(p.sem.GetLimit())
+				currentResource := autoscaler.ResourceUnit(p.permit.Quota())
 				currentUsages, err := p.pollerUsageEstimator.Estimate()
 				if err != nil {
 					logger.Warnw("poller autoscaler skip due to estimator error", "error", err)
@@ -146,7 +132,7 @@
 					"recommend", uint64(proposedResource),
 					"isDryRun", p.isDryRun)
 				if !p.isDryRun {
-					p.sem.SetLimit(int(proposedResource))
+					p.permit.SetQuota(int(proposedResource))
 				}
 				p.pollerUsageEstimator.Reset()
diff --git a/internal/internal_poller_autoscaler_test.go b/internal/internal_poller_autoscaler_test.go
index 68514602f..4a441b642 100644
--- a/internal/internal_poller_autoscaler_test.go
+++ b/internal/internal_poller_autoscaler_test.go
@@ -21,12 +21,14 @@ package internal
 
 import (
+	"context"
 	"math/rand"
 	"sync"
 	"testing"
 	"time"
 
 	"go.uber.org/cadence/internal/common/testlogger"
+	"go.uber.org/cadence/internal/worker"
"github.com/stretchr/testify/assert" "go.uber.org/atomic" @@ -171,6 +173,7 @@ func Test_pollerAutoscaler(t *testing.T) { TargetUtilization: float64(tt.args.targetMilliUsage) / 1000, }, testlogger.NewZap(t), + worker.NewResizablePermit(tt.args.initialPollerCount), // hook function that collects number of iterations func() { autoscalerEpoch.Add(1) @@ -190,18 +193,19 @@ func Test_pollerAutoscaler(t *testing.T) { go func() { defer wg.Done() for pollResult := range pollChan { - pollerScaler.Acquire(1) + err := pollerScaler.permit.Acquire(context.Background()) + assert.NoError(t, err) pollerScaler.CollectUsage(pollResult) - pollerScaler.Release(1) + pollerScaler.permit.Release() } }() } assert.Eventually(t, func() bool { return autoscalerEpoch.Load() == uint64(tt.args.autoScalerEpoch) - }, tt.args.cooldownTime+20*time.Millisecond, 10*time.Millisecond) + }, tt.args.cooldownTime+100*time.Millisecond, 10*time.Millisecond) pollerScaler.Stop() - res := pollerScaler.GetCurrent() + res := pollerScaler.permit.Quota() - pollerScaler.permit.Count() assert.Equal(t, tt.want, int(res)) }) } diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index ba9da7818..b4bfb0ad6 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -33,6 +33,7 @@ import ( "time" "go.uber.org/cadence/internal/common/debug" + "go.uber.org/cadence/internal/worker" "github.com/uber-go/tally" "go.uber.org/zap" @@ -141,7 +142,7 @@ type ( logger *zap.Logger metricsScope tally.Scope - pollerRequestCh chan struct{} + concurrency *worker.ConcurrencyLimit pollerAutoScaler *pollerAutoScaler taskQueueCh chan interface{} sessionTokenBucket *sessionTokenBucket @@ -167,11 +168,17 @@ func createPollRetryPolicy() backoff.RetryPolicy { func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope tally.Scope, sessionTokenBucket *sessionTokenBucket) *baseWorker { ctx, cancel := context.WithCancel(context.Background()) + concurrency := &worker.ConcurrencyLimit{ + PollerPermit: worker.NewResizablePermit(options.pollerCount), + TaskPermit: worker.NewResizablePermit(options.maxConcurrentTask), + } + var pollerAS *pollerAutoScaler if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled { pollerAS = newPollerScaler( pollerOptions, logger, + concurrency.PollerPermit, ) } @@ -182,7 +189,7 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t retrier: backoff.NewConcurrentRetrier(pollOperationRetryPolicy), logger: logger.With(zapcore.Field{Key: tagWorkerType, Type: zapcore.StringType, String: options.workerType}), metricsScope: tagScope(metricsScope, tagWorkerType, options.workerType), - pollerRequestCh: make(chan struct{}, options.maxConcurrentTask), + concurrency: concurrency, pollerAutoScaler: pollerAS, taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched. 
 		limiterContext:   ctx,
@@ -241,14 +248,19 @@ func (bw *baseWorker) runPoller() {
 	bw.metricsScope.Counter(metrics.PollerStartCounter).Inc(1)
 
 	for {
+		permitChannel, channelDone := bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext)
 		select {
 		case <-bw.shutdownCh:
+			channelDone()
 			return
-		case <-bw.pollerRequestCh:
-			bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(cap(bw.pollerRequestCh)))
-			// This metric is used to monitor how many poll requests have been allocated
-			// and can be used to approximate number of concurrent task running (not pinpoint accurate)
-			bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(cap(bw.pollerRequestCh) - len(bw.pollerRequestCh)))
+		case <-permitChannel: // don't poll unless there is a task permit
+			channelDone()
+			// TODO move to a centralized place inside the worker
+			// emit metrics on concurrent task permit quota and current task permit count
+			// NOTE a task permit doesn't mean a task is running; the worker still needs to poll until it gets a task to process,
+			// so the metric is only an estimate of how many tasks are running concurrently
+			bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.concurrency.TaskPermit.Quota()))
+			bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.concurrency.TaskPermit.Count()))
 			if bw.sessionTokenBucket != nil {
 				bw.sessionTokenBucket.waitForAvailableToken()
 			}
@@ -260,10 +272,6 @@ func (bw *baseWorker) runTaskDispatcher() {
 	defer bw.shutdownWG.Done()
 
-	for i := 0; i < bw.options.maxConcurrentTask; i++ {
-		bw.pollerRequestCh <- struct{}{}
-	}
-
 	for {
 		// wait for new task or shutdown
 		select {
@@ -294,10 +302,10 @@ func (bw *baseWorker) pollTask() {
 	var task interface{}
 
 	if bw.pollerAutoScaler != nil {
-		if pErr := bw.pollerAutoScaler.Acquire(1); pErr == nil {
-			defer bw.pollerAutoScaler.Release(1)
+		if pErr := bw.concurrency.PollerPermit.Acquire(bw.limiterContext); pErr == nil {
+			defer bw.concurrency.PollerPermit.Release()
 		} else {
-			bw.logger.Warn("poller auto scaler acquire error", zap.Error(pErr))
+			bw.logger.Warn("poller permit acquire error", zap.Error(pErr))
 		}
 	}
 
@@ -333,7 +341,7 @@ func (bw *baseWorker) pollTask() {
 			case <-bw.shutdownCh:
 			}
 		} else {
-			bw.pollerRequestCh <- struct{}{} // poll failed, trigger a new poll
+			bw.concurrency.TaskPermit.Release() // poll failed, trigger a new poll by returning a task permit
 		}
 	}
 
@@ -368,7 +376,7 @@ func (bw *baseWorker) processTask(task interface{}) {
 		}
 
 		if isPolledTask {
-			bw.pollerRequestCh <- struct{}{}
+			bw.concurrency.TaskPermit.Release() // task processed, trigger a new poll by returning a task permit
 		}
 	}()
 	err := bw.options.taskWorker.ProcessTask(task)
diff --git a/internal/worker/concurrency.go b/internal/worker/concurrency.go
new file mode 100644
index 000000000..8d0771b91
--- /dev/null
+++ b/internal/worker/concurrency.go
@@ -0,0 +1,41 @@
+// Copyright (c) 2017-2021 Uber Technologies Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package worker
+
+import "context"
+
+var _ Permit = (*resizablePermit)(nil)
+
+// ConcurrencyLimit contains synchronization primitives for dynamically controlling the concurrencies in workers
+type ConcurrencyLimit struct {
+	PollerPermit Permit // controls concurrency of pollers
+	TaskPermit   Permit // controls concurrency of task processing
+}
+
+// Permit is an adaptive permit issuer to control concurrency
+type Permit interface {
+	Acquire(context.Context) error
+	AcquireChan(context.Context) (channel <-chan struct{}, done func())
+	Count() int
+	Quota() int
+	Release()
+	SetQuota(int)
+}
diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go
new file mode 100644
index 000000000..de785bfce
--- /dev/null
+++ b/internal/worker/resizable_permit.go
@@ -0,0 +1,120 @@
+// Copyright (c) 2017-2021 Uber Technologies Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package worker
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"github.com/marusama/semaphore/v2"
+)
+
+type resizablePermit struct {
+	sem semaphore.Semaphore
+}
+
+// NewResizablePermit creates a permit whose quota can be resized at runtime
+func NewResizablePermit(initCount int) Permit {
+	return &resizablePermit{sem: semaphore.New(initCount)}
+}
+
+// Acquire blocks until a permit is acquired, or returns an error once the context is done.
+// Remember to call Release() to return the permit after use.
+func (p *resizablePermit) Acquire(ctx context.Context) error {
+	if err := p.sem.Acquire(ctx, 1); err != nil {
+		return fmt.Errorf("failed to acquire permit before context is done: %w", err)
+	}
+	return nil
+}
+
+// Release releases one permit
+func (p *resizablePermit) Release() {
+	p.sem.Release(1)
+}
+
+// Quota returns the maximum number of permits that can be acquired
+func (p *resizablePermit) Quota() int {
+	return p.sem.GetLimit()
+}
+
+// SetQuota sets the maximum number of permits that can be acquired
+func (p *resizablePermit) SetQuota(c int) {
+	p.sem.SetLimit(c)
+}
+
+// Count returns the number of permits currently acquired (in use)
+func (p *resizablePermit) Count() int {
+	return p.sem.GetCount()
+}
+
+// AcquireChan returns a channel that can be used to wait for a permit, plus a done function to call when finished.
+// Notes:
+//  1. always call the done function to avoid a goroutine leak
+//  2. if the channel read succeeded, release the permit by calling permit.Release()
+func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, func()) {
+	ctx, cancel := context.WithCancel(ctx)
+	pc := &permitChannel{
+		p:      p,
+		c:      make(chan struct{}),
+		ctx:    ctx,
+		cancel: cancel,
+		wg:     &sync.WaitGroup{},
+	}
+	pc.Start()
+	return pc.C(), func() {
+		pc.Close()
+	}
+}
+
+// permitChannel is a helper to acquire a permit through a channel
+type permitChannel struct {
+	p      Permit
+	c      chan struct{}
+	ctx    context.Context
+	cancel context.CancelFunc
+	wg     *sync.WaitGroup
+}
+
+func (ch *permitChannel) C() <-chan struct{} {
+	return ch.c
+}
+
+func (ch *permitChannel) Start() {
+	ch.wg.Add(1)
+	go func() {
+		defer ch.wg.Done()
+		if err := ch.p.Acquire(ch.ctx); err != nil {
+			return
+		}
+		// avoid blocking on sending to the channel
+		select {
+		case ch.c <- struct{}{}:
+		case <-ch.ctx.Done(): // release if the acquire succeeded but the permit was never handed over
+			ch.p.Release()
+		}
+	}()
+}
+
+func (ch *permitChannel) Close() {
+	ch.cancel()
+	ch.wg.Wait()
+}
diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go
new file mode 100644
index 000000000..141510bde
--- /dev/null
+++ b/internal/worker/resizable_permit_test.go
@@ -0,0 +1,283 @@
+// Copyright (c) 2017-2021 Uber Technologies Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package worker
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"math/rand"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/atomic"
+	"go.uber.org/goleak"
+)
+
+func TestPermit_Simulation(t *testing.T) {
+	tests := []struct {
+		name                string
+		capacity            []int // update every 50ms
+		goroutines          int   // each would block on acquiring 1 token for 100-150ms
+		maxTestDuration     time.Duration
+		expectFailuresRange []int // range of failures, inclusive [min, max]
+	}{
+		{
+			name:                "enough permit, no blocking",
+			maxTestDuration:     200 * time.Millisecond, // at most need 150 ms, add 50 ms buffer
+			capacity:            []int{10000},
+			goroutines:          1000,
+			expectFailuresRange: []int{0, 0},
+		},
+		{
+			name:                "not enough permit, blocking but all acquire",
+			maxTestDuration:     800 * time.Millisecond, // at most need 150ms * 1000 / 200 = 750ms to acquire all permit
+			capacity:            []int{200},
+			goroutines:          1000,
+			expectFailuresRange: []int{0, 0},
+		},
+		{
+			name:                "not enough permit for some to acquire, fail some",
+			maxTestDuration:     250 * time.Millisecond, // at least need 100ms * 1000 / 200 = 500ms to acquire all permit
+			capacity:            []int{200},
+			goroutines:          1000,
+			expectFailuresRange: []int{400, 600}, // should at least pass some acquires
+		},
+		{
+			name:                "not enough permit at beginning but due to capacity change, blocking but all acquire",
+			maxTestDuration:     250 * time.Millisecond,
+			capacity:            []int{200, 400, 600},
+			goroutines:          1000,
+			expectFailuresRange: []int{0, 0},
+		},
+		{
+			name:                "enough permit at beginning but due to capacity change, some would fail",
+			maxTestDuration:     250 * time.Millisecond,
+			capacity:            []int{600, 400, 200},
+			goroutines:          1000,
+			expectFailuresRange: []int{1, 500}, // the worst case with 200 capacity will at least pass 500 acquires
+		},
+		{
+			name:                "not enough permit for any acquire, fail all",
+			maxTestDuration:     300 * time.Millisecond,
+			capacity:            []int{0},
+			goroutines:          1000,
+			expectFailuresRange: []int{1000, 1000},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			defer goleak.VerifyNone(t)
+			wg := &sync.WaitGroup{}
+			permit := NewResizablePermit(tt.capacity[0])
+			wg.Add(1)
+			go func() { // update quota every 50ms
+				defer wg.Done()
+				for i := 1; i < len(tt.capacity); i++ {
+					time.Sleep(50 * time.Millisecond)
+					permit.SetQuota(tt.capacity[i])
+				}
+			}()
+			failures := atomic.NewInt32(0)
+			ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration)
+			defer cancel()
+
+			acquireChan := tt.goroutines / 2
+			for i := 0; i < tt.goroutines-acquireChan; i++ {
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					if err := permit.Acquire(ctx); err != nil {
+						failures.Inc()
+						return
+					}
+					time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond)
+					permit.Release()
+				}()
+			}
+			for i := 0; i < acquireChan; i++ {
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					permitChan, done := permit.AcquireChan(ctx)
+					select {
+					case <-permitChan:
+						time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond)
+						permit.Release()
+					case <-ctx.Done():
+						failures.Inc()
+					}
+					done()
+				}()
+			}
+
+			wg.Wait()
+			// sanity check
+			assert.Equal(t, 0, permit.Count(), "all permit should be released")
+			assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota())
+
+			// expect failures in range
+			expectFailureMin := tt.expectFailuresRange[0]
+			expectFailureMax := tt.expectFailuresRange[1]
+			assert.GreaterOrEqual(t, int(failures.Load()), expectFailureMin)
+			assert.LessOrEqual(t, int(failures.Load()), expectFailureMax)
+		})
+	}
+}
+
+// Test_Permit_Acquire tests the basic Acquire functionality;
+// some subtests sleep 100ms first so the short-lived context is already expired when Acquire is called
+func Test_Permit_Acquire(t *testing.T) {
+
+	t.Run("acquire 1 permit", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		err := permit.Acquire(context.Background())
+		assert.NoError(t, err)
+		assert.Equal(t, 1, permit.Count())
+	})
+
+	t.Run("acquire timeout", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		defer cancel()
+		time.Sleep(100 * time.Millisecond)
+		err := permit.Acquire(ctx)
+		assert.ErrorContains(t, err, "context deadline exceeded")
+		assert.Empty(t, permit.Count())
+	})
+
+	t.Run("cancel acquire", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+		err := permit.Acquire(ctx)
+		assert.ErrorContains(t, err, "canceled")
+		assert.Empty(t, permit.Count())
+	})
+
+	t.Run("acquire more than quota", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		defer cancel()
+		err := permit.Acquire(ctx)
+		assert.NoError(t, err)
+		err = permit.Acquire(ctx)
+		assert.ErrorContains(t, err, "failed to acquire permit")
+		assert.Equal(t, 1, permit.Count())
+	})
+}
+
+func Test_Permit_Release(t *testing.T) {
+	for _, tt := range []struct {
+		name                    string
+		quota, acquire, release int
+		expectPanic             bool
+	}{
+		{"release all acquired permits", 10, 5, 5, false},
+		{"release partial acquired permit", 10, 5, 1, false},
+		{"release non acquired permit", 10, 5, 0, false},
+		{"release more than acquired permit", 10, 5, 10, true},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			permit := NewResizablePermit(tt.quota)
+			for i := 0; i < tt.acquire; i++ {
+				err := permit.Acquire(context.Background())
+				assert.NoError(t, err)
+			}
+			releaseOp := func() {
+				for i := 0; i < tt.release; i++ {
+					permit.Release()
+				}
+			}
+
+			if tt.expectPanic {
+				assert.Panics(t, releaseOp)
+			} else {
+				assert.NotPanics(t, releaseOp)
+				assert.Equal(t, tt.acquire-tt.release, permit.Count())
+			}
+		})
+	}
+}
+
+func Test_Permit_AcquireChan(t *testing.T) {
+	t.Run("acquire 1 permit", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		defer cancel()
+		channel, done := permit.AcquireChan(ctx)
+		defer done()
+		select {
+		case <-channel:
+			assert.Equal(t, 1, permit.Count())
+		case <-ctx.Done():
+			t.Errorf("permit not acquired")
+		}
+	})
+
+	t.Run("acquire timeout", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		defer cancel()
+		time.Sleep(100 * time.Millisecond)
+		channel, done := permit.AcquireChan(ctx)
+		defer done()
+		select {
+		case <-channel:
+			t.Errorf("permit acquired")
+		case <-ctx.Done():
+			assert.Empty(t, permit.Count())
+		}
+	})
+
+	t.Run("cancel acquire", func(t *testing.T) {
+		permit := NewResizablePermit(1)
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+		channel, done := permit.AcquireChan(ctx)
+		defer done()
+		select {
+		case <-channel:
+			t.Errorf("permit acquired")
+		case <-ctx.Done():
+			assert.Empty(t, permit.Count())
+		}
+	})
+
+	t.Run("acquire more than quota", func(t *testing.T) {
+		permit := NewResizablePermit(4)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		defer cancel()
+
+		for i := 0; i < 10; i++ {
+			channel, done := permit.AcquireChan(ctx)
+			select {
+			case <-channel:
+			case <-ctx.Done():
+			}
+			done()
+		}
+
+		assert.Equal(t, 4, permit.Count())
+	})
+}
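Usage sketch (not part of the patch): a minimal example of how the new Permit API introduced above is meant to be driven, mirroring the blocking Acquire path in pollTask and the AcquireChan path in runPoller. It assumes the code lives inside the cadence-go-client module, since internal/worker cannot be imported from outside; the quota values are illustrative.

package main

import (
	"context"
	"fmt"
	"time"

	"go.uber.org/cadence/internal/worker" // internal package: only importable from within this repository
)

func main() {
	permit := worker.NewResizablePermit(2) // start with a quota of 2

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// Blocking acquire, as pollTask does with PollerPermit.
	if err := permit.Acquire(ctx); err != nil {
		fmt.Println("no permit:", err)
	} else {
		permit.Release()
	}

	// Channel-based acquire, as runPoller does with TaskPermit.
	ch, done := permit.AcquireChan(ctx)
	select {
	case <-ch:
		permit.Release() // a successful read must be paired with a Release
	case <-ctx.Done():
	}
	done() // always call done() to stop the helper goroutine and avoid a leak

	permit.SetQuota(4) // the poller autoscaler resizes the quota like this at runtime
	fmt.Println(permit.Quota(), permit.Count())
}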