From 14e3c01e8793c3bfe6221a9654215e7c1efb6e68 Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Tue, 3 Dec 2024 10:25:11 -0800 Subject: [PATCH 1/5] add concurrencylimit entity to worker --- internal/common/autoscaler/autoscaler.go | 6 - internal/internal_poller_autoscaler.go | 31 +---- internal/internal_poller_autoscaler_test.go | 9 +- internal/internal_worker_base.go | 38 +++--- internal/worker/channel_permit.go | 75 ++++++++++++ internal/worker/concurrency.go | 41 +++++++ internal/worker/resizable_permit.go | 63 ++++++++++ internal/worker/resizable_permit_test.go | 127 ++++++++++++++++++++ 8 files changed, 340 insertions(+), 50 deletions(-) create mode 100644 internal/worker/channel_permit.go create mode 100644 internal/worker/concurrency.go create mode 100644 internal/worker/resizable_permit.go create mode 100644 internal/worker/resizable_permit_test.go diff --git a/internal/common/autoscaler/autoscaler.go b/internal/common/autoscaler/autoscaler.go index ecac9641c..c080917ce 100644 --- a/internal/common/autoscaler/autoscaler.go +++ b/internal/common/autoscaler/autoscaler.go @@ -24,12 +24,6 @@ package autoscaler type ( AutoScaler interface { Estimator - // Acquire X ResourceUnit of resource - Acquire(ResourceUnit) error - // Release X ResourceUnit of resource - Release(ResourceUnit) - // GetCurrent ResourceUnit of resource - GetCurrent() ResourceUnit // Start starts the autoscaler go routine that scales the ResourceUnit according to Estimator Start() // Stop stops the autoscaler if started or do nothing if not yet started diff --git a/internal/internal_poller_autoscaler.go b/internal/internal_poller_autoscaler.go index d88dd1f10..9185d5641 100644 --- a/internal/internal_poller_autoscaler.go +++ b/internal/internal_poller_autoscaler.go @@ -26,11 +26,11 @@ import ( "sync" "time" - "github.com/marusama/semaphore/v2" "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/cadence/internal/common/autoscaler" + "go.uber.org/cadence/internal/worker" ) // defaultPollerScalerCooldownInSeconds @@ -53,7 +53,7 @@ type ( isDryRun bool cooldownTime time.Duration logger *zap.Logger - sem semaphore.Semaphore // resizable semaphore to control number of concurrent pollers + permit worker.Permit ctx context.Context cancel context.CancelFunc wg *sync.WaitGroup // graceful stop @@ -82,6 +82,7 @@ type ( func newPollerScaler( options pollerAutoScalerOptions, logger *zap.Logger, + permit worker.Permit, hooks ...func()) *pollerAutoScaler { if !options.Enabled { return nil @@ -91,7 +92,7 @@ func newPollerScaler( isDryRun: options.DryRun, cooldownTime: options.Cooldown, logger: logger, - sem: semaphore.New(options.InitCount), + permit: permit, wg: &sync.WaitGroup{}, ctx: ctx, cancel: cancel, @@ -107,21 +108,6 @@ func newPollerScaler( } } -// Acquire concurrent poll quota -func (p *pollerAutoScaler) Acquire(resource autoscaler.ResourceUnit) error { - return p.sem.Acquire(p.ctx, int(resource)) -} - -// Release concurrent poll quota -func (p *pollerAutoScaler) Release(resource autoscaler.ResourceUnit) { - p.sem.Release(int(resource)) -} - -// GetCurrent poll quota -func (p *pollerAutoScaler) GetCurrent() autoscaler.ResourceUnit { - return autoscaler.ResourceUnit(p.sem.GetLimit()) -} - // Start an auto-scaler go routine and returns a done to stop it func (p *pollerAutoScaler) Start() { logger := p.logger.Sugar() @@ -133,7 +119,7 @@ func (p *pollerAutoScaler) Start() { case <-p.ctx.Done(): return case <-time.After(p.cooldownTime): - currentResource := autoscaler.ResourceUnit(p.sem.GetLimit()) + currentResource := autoscaler.ResourceUnit(p.permit.Quota()) currentUsages, err := p.pollerUsageEstimator.Estimate() if err != nil { logger.Warnw("poller autoscaler skip due to estimator error", "error", err) @@ -146,14 +132,9 @@ func (p *pollerAutoScaler) Start() { "recommend", uint64(proposedResource), "isDryRun", p.isDryRun) if !p.isDryRun { - p.sem.SetLimit(int(proposedResource)) + p.permit.SetQuota(int(proposedResource)) } p.pollerUsageEstimator.Reset() - - // hooks - for i := range p.onAutoScale { - p.onAutoScale[i]() - } } } }() diff --git a/internal/internal_poller_autoscaler_test.go b/internal/internal_poller_autoscaler_test.go index 68514602f..c1a3dfb4f 100644 --- a/internal/internal_poller_autoscaler_test.go +++ b/internal/internal_poller_autoscaler_test.go @@ -21,12 +21,14 @@ package internal import ( + "context" "math/rand" "sync" "testing" "time" "go.uber.org/cadence/internal/common/testlogger" + "go.uber.org/cadence/internal/worker" "github.com/stretchr/testify/assert" "go.uber.org/atomic" @@ -171,6 +173,7 @@ func Test_pollerAutoscaler(t *testing.T) { TargetUtilization: float64(tt.args.targetMilliUsage) / 1000, }, testlogger.NewZap(t), + worker.NewResizablePermit(tt.args.initialPollerCount), // hook function that collects number of iterations func() { autoscalerEpoch.Add(1) @@ -190,9 +193,9 @@ func Test_pollerAutoscaler(t *testing.T) { go func() { defer wg.Done() for pollResult := range pollChan { - pollerScaler.Acquire(1) + pollerScaler.permit.Acquire(context.Background()) pollerScaler.CollectUsage(pollResult) - pollerScaler.Release(1) + pollerScaler.permit.Release() } }() } @@ -201,7 +204,7 @@ func Test_pollerAutoscaler(t *testing.T) { return autoscalerEpoch.Load() == uint64(tt.args.autoScalerEpoch) }, tt.args.cooldownTime+20*time.Millisecond, 10*time.Millisecond) pollerScaler.Stop() - res := pollerScaler.GetCurrent() + res := pollerScaler.permit.Quota() - pollerScaler.permit.Count() assert.Equal(t, tt.want, int(res)) }) } diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index ba9da7818..4fa30c6b0 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -33,6 +33,7 @@ import ( "time" "go.uber.org/cadence/internal/common/debug" + "go.uber.org/cadence/internal/worker" "github.com/uber-go/tally" "go.uber.org/zap" @@ -141,7 +142,7 @@ type ( logger *zap.Logger metricsScope tally.Scope - pollerRequestCh chan struct{} + concurrency *worker.ConcurrencyLimit pollerAutoScaler *pollerAutoScaler taskQueueCh chan interface{} sessionTokenBucket *sessionTokenBucket @@ -167,11 +168,18 @@ func createPollRetryPolicy() backoff.RetryPolicy { func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope tally.Scope, sessionTokenBucket *sessionTokenBucket) *baseWorker { ctx, cancel := context.WithCancel(context.Background()) + concurrency := &worker.ConcurrencyLimit{ + PollerPermit: worker.NewResizablePermit(options.pollerCount), + TaskPermit: worker.NewChannelPermit(options.maxConcurrentTask), + } + var pollerAS *pollerAutoScaler if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled { + concurrency.PollerPermit = worker.NewResizablePermit(pollerOptions.InitCount) pollerAS = newPollerScaler( pollerOptions, logger, + concurrency.PollerPermit, ) } @@ -182,7 +190,7 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t retrier: backoff.NewConcurrentRetrier(pollOperationRetryPolicy), logger: logger.With(zapcore.Field{Key: tagWorkerType, Type: zapcore.StringType, String: options.workerType}), metricsScope: tagScope(metricsScope, tagWorkerType, options.workerType), - pollerRequestCh: make(chan struct{}, options.maxConcurrentTask), + concurrency: concurrency, pollerAutoScaler: pollerAS, taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched. limiterContext: ctx, @@ -244,11 +252,13 @@ func (bw *baseWorker) runPoller() { select { case <-bw.shutdownCh: return - case <-bw.pollerRequestCh: - bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(cap(bw.pollerRequestCh))) - // This metric is used to monitor how many poll requests have been allocated - // and can be used to approximate number of concurrent task running (not pinpoint accurate) - bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(cap(bw.pollerRequestCh) - len(bw.pollerRequestCh))) + case <-bw.concurrency.TaskPermit.GetChan(): // don't poll unless there is a task permit + // TODO move to a centralized place inside the worker + // emit metrics on concurrent task permit quota and current task permit count + // NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process + // thus the metrics is only an estimated value of how many tasks are running concurrently + bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.concurrency.TaskPermit.Quota())) + bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.concurrency.TaskPermit.Count())) if bw.sessionTokenBucket != nil { bw.sessionTokenBucket.waitForAvailableToken() } @@ -260,10 +270,6 @@ func (bw *baseWorker) runPoller() { func (bw *baseWorker) runTaskDispatcher() { defer bw.shutdownWG.Done() - for i := 0; i < bw.options.maxConcurrentTask; i++ { - bw.pollerRequestCh <- struct{}{} - } - for { // wait for new task or shutdown select { @@ -294,10 +300,10 @@ func (bw *baseWorker) pollTask() { var task interface{} if bw.pollerAutoScaler != nil { - if pErr := bw.pollerAutoScaler.Acquire(1); pErr == nil { - defer bw.pollerAutoScaler.Release(1) + if pErr := bw.concurrency.PollerPermit.Acquire(bw.limiterContext); pErr == nil { + defer bw.concurrency.PollerPermit.Release() } else { - bw.logger.Warn("poller auto scaler acquire error", zap.Error(pErr)) + bw.logger.Warn("poller permit acquire error", zap.Error(pErr)) } } @@ -333,7 +339,7 @@ func (bw *baseWorker) pollTask() { case <-bw.shutdownCh: } } else { - bw.pollerRequestCh <- struct{}{} // poll failed, trigger a new poll + bw.concurrency.TaskPermit.Release() // poll failed, trigger a new poll by returning a task permit } } @@ -368,7 +374,7 @@ func (bw *baseWorker) processTask(task interface{}) { } if isPolledTask { - bw.pollerRequestCh <- struct{}{} + bw.concurrency.TaskPermit.Release() // task processed, trigger a new poll by returning a task permit } }() err := bw.options.taskWorker.ProcessTask(task) diff --git a/internal/worker/channel_permit.go b/internal/worker/channel_permit.go new file mode 100644 index 000000000..713b65dea --- /dev/null +++ b/internal/worker/channel_permit.go @@ -0,0 +1,75 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import ( + "context" + "fmt" +) + +// ChannelPermit is a handy wrapper entity over a buffered channel +type ChannelPermit interface { + Acquire(context.Context) error + Count() int + Quota() int + Release() + GetChan() <-chan struct{} // fetch the underlying channel +} + +type channelPermit struct { + channel chan struct{} +} + +// NewChannelPermit creates a static permit that's not resizable +func NewChannelPermit(count int) ChannelPermit { + channel := make(chan struct{}, count) + for i := 0; i < count; i++ { + channel <- struct{}{} + } + return &channelPermit{channel: channel} +} + +func (p *channelPermit) Acquire(ctx context.Context) error { + select { + case <-ctx.Done(): + return fmt.Errorf("failed to acquire permit before context is done") + case p.channel <- struct{}{}: + return nil + } +} + +// AcquireChan returns a permit ready channel +func (p *channelPermit) GetChan() <-chan struct{} { + return p.channel +} + +func (p *channelPermit) Release() { + p.channel <- struct{}{} +} + +// Count returns the number of permits available +func (p *channelPermit) Count() int { + return len(p.channel) +} + +func (p *channelPermit) Quota() int { + return cap(p.channel) +} diff --git a/internal/worker/concurrency.go b/internal/worker/concurrency.go new file mode 100644 index 000000000..81f44ba73 --- /dev/null +++ b/internal/worker/concurrency.go @@ -0,0 +1,41 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import "context" + +var _ Permit = (*resizablePermit)(nil) +var _ ChannelPermit = (*channelPermit)(nil) + +// ConcurrencyLimit contains synchronization primitives for dynamically controlling the concurrencies in workers +type ConcurrencyLimit struct { + PollerPermit Permit // controls concurrency of pollers + TaskPermit ChannelPermit // controls concurrency of task processing +} + +// Permit is an adaptive permit issuer to control concurrency +type Permit interface { + Acquire(context.Context) error + Count() int + Quota() int + Release() + SetQuota(int) +} diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go new file mode 100644 index 000000000..1404bbb73 --- /dev/null +++ b/internal/worker/resizable_permit.go @@ -0,0 +1,63 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import ( + "context" + "fmt" + + "github.com/marusama/semaphore/v2" +) + +type resizablePermit struct { + sem semaphore.Semaphore +} + +// NewResizablePermit creates a dynamic permit that's resizable +func NewResizablePermit(initCount int) Permit { + return &resizablePermit{sem: semaphore.New(initCount)} +} + +// Acquire is blocking until a permit is acquired or returns error after context is done +// Remember to call Release(count) to release the permit after usage +func (p *resizablePermit) Acquire(ctx context.Context) error { + if err := p.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire permit before context is done: %w", err) + } + return nil +} + +func (p *resizablePermit) Release() { + p.sem.Release(1) +} + +func (p *resizablePermit) Quota() int { + return p.sem.GetLimit() +} + +func (p *resizablePermit) SetQuota(c int) { + p.sem.SetLimit(c) +} + +// Count returns the number of permits available +func (p *resizablePermit) Count() int { + return p.sem.GetCount() +} diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go new file mode 100644 index 000000000..a9bfcf338 --- /dev/null +++ b/internal/worker/resizable_permit_test.go @@ -0,0 +1,127 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import ( + "context" + "sync" + "testing" + "time" + + "math/rand" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestPermit_Simulation(t *testing.T) { + tests := []struct { + name string + capacity []int // update every 50ms + goroutines int // each would block on acquiring 1 token for 100-150ms + maxTestDuration time.Duration + expectFailuresRange []int // range of failures, inclusive [min, max] + }{ + { + name: "enough permit, no blocking", + maxTestDuration: 200 * time.Millisecond, // at most need 150 ms, add 50 ms buffer + capacity: []int{10000}, + goroutines: 1000, + expectFailuresRange: []int{0, 0}, + }, + { + name: "not enough permit, blocking but all acquire", + maxTestDuration: 800 * time.Millisecond, // at most need 150ms * 1000 / 200 = 750ms to acquire all permit + capacity: []int{200}, + goroutines: 1000, + expectFailuresRange: []int{0, 0}, + }, + { + name: "not enough permit for some to acquire, fail some", + maxTestDuration: 250 * time.Millisecond, // at least need 100ms * 1000 / 200 = 500ms to acquire all permit + capacity: []int{200}, + goroutines: 1000, + expectFailuresRange: []int{1, 999}, // should at least pass some acquires + }, + { + name: "not enough permit at beginning but due to capacity change, blocking but all acquire", + maxTestDuration: 250 * time.Millisecond, + capacity: []int{200, 400, 600}, + goroutines: 1000, + expectFailuresRange: []int{0, 0}, + }, + { + name: "enough permit at beginning but due to capacity change, some would fail", + maxTestDuration: 250 * time.Millisecond, + capacity: []int{600, 400, 200}, + goroutines: 1000, + expectFailuresRange: []int{1, 999}, + }, + { + name: "not enough permit for any acquire, fail all", + maxTestDuration: 300 * time.Millisecond, + capacity: []int{0}, + goroutines: 1000, + expectFailuresRange: []int{1000, 1000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wg := &sync.WaitGroup{} + permit := NewResizablePermit(tt.capacity[0]) + wg.Add(1) + go func() { // update quota every 50ms + defer wg.Done() + for i := 1; i < len(tt.capacity); i++ { + time.Sleep(50 * time.Millisecond) + permit.SetQuota(tt.capacity[i]) + } + }() + failures := atomic.NewInt32(0) + ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration) + defer cancel() + for i := 0; i < tt.goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := permit.Acquire(ctx); err != nil { + failures.Inc() + return + } + time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond) + permit.Release() + }() + } + + wg.Wait() + // sanity check + assert.Equal(t, 0, permit.Count()) + assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota()) + + // expect failures in range + expectFailureMin := tt.expectFailuresRange[0] + expectFailureMax := tt.expectFailuresRange[1] + assert.GreaterOrEqual(t, int(failures.Load()), expectFailureMin) + assert.LessOrEqual(t, int(failures.Load()), expectFailureMax) + }) + } +} From 082d125ae6d8ee0848bcd130acd80edbe1180b78 Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Tue, 3 Dec 2024 11:01:09 -0800 Subject: [PATCH 2/5] wip --- internal/worker/concurrency_auto_scaler.go | 60 ++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 internal/worker/concurrency_auto_scaler.go diff --git a/internal/worker/concurrency_auto_scaler.go b/internal/worker/concurrency_auto_scaler.go new file mode 100644 index 000000000..8a2795599 --- /dev/null +++ b/internal/worker/concurrency_auto_scaler.go @@ -0,0 +1,60 @@ +package worker + +import ( + "context" + "sync" + "time" +) + +type ConcurrencyAutoScaler struct { + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + + concurrency ConcurrencyLimit + tick time.Duration + + PollerPermitLastUpdate time.Time +} + +type ConcurrencyAutoScalerOptions struct { + Concurrency ConcurrencyLimit + Tick time.Duration // frequency of auto tuning + Cooldown time.Duration // cooldown time of update + +} + +func NewConcurrencyAutoScaler(options ConcurrencyAutoScalerOptions) *ConcurrencyAutoScaler { + ctx, cancel := context.WithCancel(context.Background()) + + return &ConcurrencyAutoScaler{ + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, + concurrency: options.Concurrency, + tick: options.Tick, + } +} + +func (c *ConcurrencyAutoScaler) Start() { + c.wg.Add(1) + go func () { + defer c.wg.Done() + for { + select { + case <-c.ctx.Done(): + case <-time.Tick(c.tick): + + } + } + }() +} + +func (c *ConcurrencyAutoScaler) updatePollerPermit() { + c.PollerPermitLastUpdate = time.Now() +} + +func (c *ConcurrencyAutoScaler) Stop() { + c.cancel() + c.wg.Wait() +} From c5f2d53753c97f059b9321d9396f8d2f5f86db4f Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Thu, 5 Dec 2024 11:14:44 -0800 Subject: [PATCH 3/5] add PermitChannel --- internal/internal_worker_base.go | 9 ++- internal/worker/channel_permit.go | 75 ---------------------- internal/worker/concurrency.go | 13 +++- internal/worker/concurrency_auto_scaler.go | 60 ----------------- internal/worker/resizable_permit.go | 54 ++++++++++++++++ internal/worker/resizable_permit_test.go | 21 +++++- 6 files changed, 89 insertions(+), 143 deletions(-) delete mode 100644 internal/worker/channel_permit.go delete mode 100644 internal/worker/concurrency_auto_scaler.go diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index 4fa30c6b0..79e6b9086 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -170,12 +170,11 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t concurrency := &worker.ConcurrencyLimit{ PollerPermit: worker.NewResizablePermit(options.pollerCount), - TaskPermit: worker.NewChannelPermit(options.maxConcurrentTask), + TaskPermit: worker.NewResizablePermit(options.maxConcurrentTask), } var pollerAS *pollerAutoScaler if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled { - concurrency.PollerPermit = worker.NewResizablePermit(pollerOptions.InitCount) pollerAS = newPollerScaler( pollerOptions, logger, @@ -249,10 +248,14 @@ func (bw *baseWorker) runPoller() { bw.metricsScope.Counter(metrics.PollerStartCounter).Inc(1) for { + // permitChannel can be blocking without passing context because shutdownCh is used + permitChannel := bw.concurrency.PollerPermit.AcquireChan(context.Background()) select { case <-bw.shutdownCh: + permitChannel.Close() return - case <-bw.concurrency.TaskPermit.GetChan(): // don't poll unless there is a task permit + case <-permitChannel.C(): // don't poll unless there is a task permit + permitChannel.Close() // TODO move to a centralized place inside the worker // emit metrics on concurrent task permit quota and current task permit count // NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process diff --git a/internal/worker/channel_permit.go b/internal/worker/channel_permit.go deleted file mode 100644 index 713b65dea..000000000 --- a/internal/worker/channel_permit.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2017-2021 Uber Technologies Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package worker - -import ( - "context" - "fmt" -) - -// ChannelPermit is a handy wrapper entity over a buffered channel -type ChannelPermit interface { - Acquire(context.Context) error - Count() int - Quota() int - Release() - GetChan() <-chan struct{} // fetch the underlying channel -} - -type channelPermit struct { - channel chan struct{} -} - -// NewChannelPermit creates a static permit that's not resizable -func NewChannelPermit(count int) ChannelPermit { - channel := make(chan struct{}, count) - for i := 0; i < count; i++ { - channel <- struct{}{} - } - return &channelPermit{channel: channel} -} - -func (p *channelPermit) Acquire(ctx context.Context) error { - select { - case <-ctx.Done(): - return fmt.Errorf("failed to acquire permit before context is done") - case p.channel <- struct{}{}: - return nil - } -} - -// AcquireChan returns a permit ready channel -func (p *channelPermit) GetChan() <-chan struct{} { - return p.channel -} - -func (p *channelPermit) Release() { - p.channel <- struct{}{} -} - -// Count returns the number of permits available -func (p *channelPermit) Count() int { - return len(p.channel) -} - -func (p *channelPermit) Quota() int { - return cap(p.channel) -} diff --git a/internal/worker/concurrency.go b/internal/worker/concurrency.go index 81f44ba73..aa246bacf 100644 --- a/internal/worker/concurrency.go +++ b/internal/worker/concurrency.go @@ -23,19 +23,26 @@ package worker import "context" var _ Permit = (*resizablePermit)(nil) -var _ ChannelPermit = (*channelPermit)(nil) // ConcurrencyLimit contains synchronization primitives for dynamically controlling the concurrencies in workers type ConcurrencyLimit struct { - PollerPermit Permit // controls concurrency of pollers - TaskPermit ChannelPermit // controls concurrency of task processing + PollerPermit Permit // controls concurrency of pollers + TaskPermit Permit // controls concurrency of task processing } // Permit is an adaptive permit issuer to control concurrency type Permit interface { Acquire(context.Context) error + AcquireChan(context.Context) PermitChannel Count() int Quota() int Release() SetQuota(int) } + +// PermitChannel is a channel that can be used to wait for a permit to be available +// Remember to call Close() to avoid goroutine leak +type PermitChannel interface { + C() <-chan struct{} + Close() +} diff --git a/internal/worker/concurrency_auto_scaler.go b/internal/worker/concurrency_auto_scaler.go deleted file mode 100644 index 8a2795599..000000000 --- a/internal/worker/concurrency_auto_scaler.go +++ /dev/null @@ -1,60 +0,0 @@ -package worker - -import ( - "context" - "sync" - "time" -) - -type ConcurrencyAutoScaler struct { - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup - - concurrency ConcurrencyLimit - tick time.Duration - - PollerPermitLastUpdate time.Time -} - -type ConcurrencyAutoScalerOptions struct { - Concurrency ConcurrencyLimit - Tick time.Duration // frequency of auto tuning - Cooldown time.Duration // cooldown time of update - -} - -func NewConcurrencyAutoScaler(options ConcurrencyAutoScalerOptions) *ConcurrencyAutoScaler { - ctx, cancel := context.WithCancel(context.Background()) - - return &ConcurrencyAutoScaler{ - ctx: ctx, - cancel: cancel, - wg: &sync.WaitGroup{}, - concurrency: options.Concurrency, - tick: options.Tick, - } -} - -func (c *ConcurrencyAutoScaler) Start() { - c.wg.Add(1) - go func () { - defer c.wg.Done() - for { - select { - case <-c.ctx.Done(): - case <-time.Tick(c.tick): - - } - } - }() -} - -func (c *ConcurrencyAutoScaler) updatePollerPermit() { - c.PollerPermitLastUpdate = time.Now() -} - -func (c *ConcurrencyAutoScaler) Stop() { - c.cancel() - c.wg.Wait() -} diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go index 1404bbb73..c0ab89d13 100644 --- a/internal/worker/resizable_permit.go +++ b/internal/worker/resizable_permit.go @@ -23,6 +23,7 @@ package worker import ( "context" "fmt" + "sync" "github.com/marusama/semaphore/v2" ) @@ -45,14 +46,17 @@ func (p *resizablePermit) Acquire(ctx context.Context) error { return nil } +// Release release one permit func (p *resizablePermit) Release() { p.sem.Release(1) } +// Quota returns the maximum number of permits that can be acquired func (p *resizablePermit) Quota() int { return p.sem.GetLimit() } +// SetQuota sets the maximum number of permits that can be acquired func (p *resizablePermit) SetQuota(c int) { p.sem.SetLimit(c) } @@ -61,3 +65,53 @@ func (p *resizablePermit) SetQuota(c int) { func (p *resizablePermit) Count() int { return p.sem.GetCount() } + +// AcquireChan creates a PermitChannel that can be used to wait for a permit +// After usage: +// 1. avoid goroutine leak by calling permitChannel.Close() +// 2. release permit by calling permit.Release() +func (p *resizablePermit) AcquireChan(ctx context.Context) PermitChannel { + ctx, cancel := context.WithCancel(ctx) + pc := &permitChannel{ + p: p, + c: make(chan struct{}), + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, + } + pc.start() + return pc +} + +type permitChannel struct { + p Permit + c chan struct{} + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup +} + +func (ch *permitChannel) C() <-chan struct{} { + return ch.c +} + +func (ch *permitChannel) start() { + ch.wg.Add(1) + go func() { + defer ch.wg.Done() + if err := ch.p.Acquire(ch.ctx); err != nil { + return + } + // avoid blocking on sending to the channel + select { + case ch.c <- struct{}{}: + case <-ch.ctx.Done(): // release if acquire is successful but fail to send to the channel + ch.p.Release() + } + }() +} + +func (ch *permitChannel) Close() { + ch.cancel() + ch.wg.Wait() +} diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go index a9bfcf338..4b7d39f01 100644 --- a/internal/worker/resizable_permit_test.go +++ b/internal/worker/resizable_permit_test.go @@ -99,7 +99,9 @@ func TestPermit_Simulation(t *testing.T) { failures := atomic.NewInt32(0) ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration) defer cancel() - for i := 0; i < tt.goroutines; i++ { + + aquireChan := tt.goroutines / 2 + for i := 0; i < tt.goroutines-aquireChan; i++ { wg.Add(1) go func() { defer wg.Done() @@ -111,10 +113,25 @@ func TestPermit_Simulation(t *testing.T) { permit.Release() }() } + for i := 0; i < aquireChan; i++ { + wg.Add(1) + go func() { + defer wg.Done() + permitChan := permit.AcquireChan(ctx) + select { + case <-permitChan.C(): + time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond) + permit.Release() + case <-ctx.Done(): + failures.Inc() + } + permitChan.Close() + }() + } wg.Wait() // sanity check - assert.Equal(t, 0, permit.Count()) + assert.Equal(t, 0, permit.Count(), "all permit should be released") assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota()) // expect failures in range From 40cd57cd195e79c96774b284edb3ea0f65df950b Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Fri, 6 Dec 2024 13:36:16 -0800 Subject: [PATCH 4/5] fix unit test and address comment on AcquireChan return --- internal/internal_poller_autoscaler.go | 5 +++++ internal/internal_poller_autoscaler_test.go | 5 +++-- internal/internal_worker_base.go | 9 ++++---- internal/worker/concurrency.go | 9 +------- internal/worker/resizable_permit.go | 10 +++++---- internal/worker/resizable_permit_test.go | 24 ++++++++++++++++++--- 6 files changed, 40 insertions(+), 22 deletions(-) diff --git a/internal/internal_poller_autoscaler.go b/internal/internal_poller_autoscaler.go index 9185d5641..2dc81e7ba 100644 --- a/internal/internal_poller_autoscaler.go +++ b/internal/internal_poller_autoscaler.go @@ -135,6 +135,11 @@ func (p *pollerAutoScaler) Start() { p.permit.SetQuota(int(proposedResource)) } p.pollerUsageEstimator.Reset() + + // hooks + for i := range p.onAutoScale { + p.onAutoScale[i]() + } } } }() diff --git a/internal/internal_poller_autoscaler_test.go b/internal/internal_poller_autoscaler_test.go index c1a3dfb4f..4a441b642 100644 --- a/internal/internal_poller_autoscaler_test.go +++ b/internal/internal_poller_autoscaler_test.go @@ -193,7 +193,8 @@ func Test_pollerAutoscaler(t *testing.T) { go func() { defer wg.Done() for pollResult := range pollChan { - pollerScaler.permit.Acquire(context.Background()) + err := pollerScaler.permit.Acquire(context.Background()) + assert.NoError(t, err) pollerScaler.CollectUsage(pollResult) pollerScaler.permit.Release() } @@ -202,7 +203,7 @@ func Test_pollerAutoscaler(t *testing.T) { assert.Eventually(t, func() bool { return autoscalerEpoch.Load() == uint64(tt.args.autoScalerEpoch) - }, tt.args.cooldownTime+20*time.Millisecond, 10*time.Millisecond) + }, tt.args.cooldownTime+100*time.Millisecond, 10*time.Millisecond) pollerScaler.Stop() res := pollerScaler.permit.Quota() - pollerScaler.permit.Count() assert.Equal(t, tt.want, int(res)) diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index 79e6b9086..b4bfb0ad6 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -248,14 +248,13 @@ func (bw *baseWorker) runPoller() { bw.metricsScope.Counter(metrics.PollerStartCounter).Inc(1) for { - // permitChannel can be blocking without passing context because shutdownCh is used - permitChannel := bw.concurrency.PollerPermit.AcquireChan(context.Background()) + permitChannel, channelDone := bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext) select { case <-bw.shutdownCh: - permitChannel.Close() + channelDone() return - case <-permitChannel.C(): // don't poll unless there is a task permit - permitChannel.Close() + case <-permitChannel: // don't poll unless there is a task permit + channelDone() // TODO move to a centralized place inside the worker // emit metrics on concurrent task permit quota and current task permit count // NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process diff --git a/internal/worker/concurrency.go b/internal/worker/concurrency.go index aa246bacf..8d0771b91 100644 --- a/internal/worker/concurrency.go +++ b/internal/worker/concurrency.go @@ -33,16 +33,9 @@ type ConcurrencyLimit struct { // Permit is an adaptive permit issuer to control concurrency type Permit interface { Acquire(context.Context) error - AcquireChan(context.Context) PermitChannel + AcquireChan(context.Context) (channel <-chan struct{}, done func()) Count() int Quota() int Release() SetQuota(int) } - -// PermitChannel is a channel that can be used to wait for a permit to be available -// Remember to call Close() to avoid goroutine leak -type PermitChannel interface { - C() <-chan struct{} - Close() -} diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go index c0ab89d13..31c504f63 100644 --- a/internal/worker/resizable_permit.go +++ b/internal/worker/resizable_permit.go @@ -70,7 +70,7 @@ func (p *resizablePermit) Count() int { // After usage: // 1. avoid goroutine leak by calling permitChannel.Close() // 2. release permit by calling permit.Release() -func (p *resizablePermit) AcquireChan(ctx context.Context) PermitChannel { +func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, func()) { ctx, cancel := context.WithCancel(ctx) pc := &permitChannel{ p: p, @@ -79,8 +79,10 @@ func (p *resizablePermit) AcquireChan(ctx context.Context) PermitChannel { cancel: cancel, wg: &sync.WaitGroup{}, } - pc.start() - return pc + pc.Start() + return pc.C(), func() { + pc.Close() + } } type permitChannel struct { @@ -95,7 +97,7 @@ func (ch *permitChannel) C() <-chan struct{} { return ch.c } -func (ch *permitChannel) start() { +func (ch *permitChannel) Start() { ch.wg.Add(1) go func() { defer ch.wg.Done() diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go index 4b7d39f01..21903fb91 100644 --- a/internal/worker/resizable_permit_test.go +++ b/internal/worker/resizable_permit_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/atomic" + "go.uber.org/goleak" ) func TestPermit_Simulation(t *testing.T) { @@ -86,6 +87,7 @@ func TestPermit_Simulation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + defer goleak.VerifyNone(t) wg := &sync.WaitGroup{} permit := NewResizablePermit(tt.capacity[0]) wg.Add(1) @@ -117,15 +119,15 @@ func TestPermit_Simulation(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - permitChan := permit.AcquireChan(ctx) + permitChan, done := permit.AcquireChan(ctx) select { - case <-permitChan.C(): + case <-permitChan: time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond) permit.Release() case <-ctx.Done(): failures.Inc() } - permitChan.Close() + done() }() } @@ -142,3 +144,19 @@ func TestPermit_Simulation(t *testing.T) { }) } } + +func Test_Permit_AcquireChan(t *testing.T) { + permit := NewResizablePermit(2) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + permitChan, done := permit.AcquireChan(ctx) + select { + case <-permitChan: + assert.Equal(t, 2, permit.Quota()) + assert.Equal(t, 1, permit.Count()) + case <-ctx.Done(): + t.Error("unexpected timeout") + } + done() + permit.Release() +} From 6342ff80afcf6d0dd56b4d35adaf5ec5db74dd3c Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Fri, 6 Dec 2024 14:44:29 -0800 Subject: [PATCH 5/5] add unit test and fix comments --- internal/worker/resizable_permit.go | 9 +- internal/worker/resizable_permit_test.go | 149 ++++++++++++++++++++--- 2 files changed, 140 insertions(+), 18 deletions(-) diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go index 31c504f63..de785bfce 100644 --- a/internal/worker/resizable_permit.go +++ b/internal/worker/resizable_permit.go @@ -66,10 +66,10 @@ func (p *resizablePermit) Count() int { return p.sem.GetCount() } -// AcquireChan creates a PermitChannel that can be used to wait for a permit -// After usage: -// 1. avoid goroutine leak by calling permitChannel.Close() -// 2. release permit by calling permit.Release() +// AcquireChan returns a channel that could be used to wait for the permit and a close function when done +// Notes: +// 1. avoid goroutine leak by calling the done function +// 2. if the read succeeded, release permit by calling permit.Release() func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, func()) { ctx, cancel := context.WithCancel(ctx) pc := &permitChannel{ @@ -85,6 +85,7 @@ func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, fun } } +// permitChannel is an implementation to acquire a permit through channel type permitChannel struct { p Permit c chan struct{} diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go index 21903fb91..141510bde 100644 --- a/internal/worker/resizable_permit_test.go +++ b/internal/worker/resizable_permit_test.go @@ -60,7 +60,7 @@ func TestPermit_Simulation(t *testing.T) { maxTestDuration: 250 * time.Millisecond, // at least need 100ms * 1000 / 200 = 500ms to acquire all permit capacity: []int{200}, goroutines: 1000, - expectFailuresRange: []int{1, 999}, // should at least pass some acquires + expectFailuresRange: []int{400, 600}, // should at least pass some acquires }, { name: "not enough permit at beginning but due to capacity change, blocking but all acquire", @@ -74,7 +74,7 @@ func TestPermit_Simulation(t *testing.T) { maxTestDuration: 250 * time.Millisecond, capacity: []int{600, 400, 200}, goroutines: 1000, - expectFailuresRange: []int{1, 999}, + expectFailuresRange: []int{1, 500}, // the worst case with 200 capacity will at least pass 500 acquires }, { name: "not enough permit for any acquire, fail all", @@ -145,18 +145,139 @@ func TestPermit_Simulation(t *testing.T) { } } -func Test_Permit_AcquireChan(t *testing.T) { - permit := NewResizablePermit(2) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - permitChan, done := permit.AcquireChan(ctx) - select { - case <-permitChan: - assert.Equal(t, 2, permit.Quota()) +// Test_Permit_Acquire tests the basic acquire functionality +// before each acquire will wait 100ms +func Test_Permit_Acquire(t *testing.T) { + + t.Run("acquire 1 permit", func(t *testing.T) { + permit := NewResizablePermit(1) + err := permit.Acquire(context.Background()) + assert.NoError(t, err) assert.Equal(t, 1, permit.Count()) - case <-ctx.Done(): - t.Error("unexpected timeout") + }) + + t.Run("acquire timeout", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + time.Sleep(100 * time.Millisecond) + err := permit.Acquire(ctx) + assert.ErrorContains(t, err, "context deadline exceeded") + assert.Empty(t, permit.Count()) + }) + + t.Run("cancel acquire", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := permit.Acquire(ctx) + assert.ErrorContains(t, err, "canceled") + assert.Empty(t, permit.Count()) + }) + + t.Run("acquire more than quota", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + err := permit.Acquire(ctx) + assert.NoError(t, err) + err = permit.Acquire(ctx) + assert.ErrorContains(t, err, "failed to acquire permit") + assert.Equal(t, 1, permit.Count()) + }) +} + +func Test_Permit_Release(t *testing.T) { + for _, tt := range []struct { + name string + quota, acquire, release int + expectPanic bool + }{ + {"release all acquired permits", 10, 5, 5, false}, + {"release partial acquired permit", 10, 5, 1, false}, + {"release non acquired permit", 10, 5, 0, false}, + {"release more than acquired permit", 10, 5, 10, true}, + } { + t.Run(tt.name, func(t *testing.T) { + permit := NewResizablePermit(tt.quota) + for i := 0; i < tt.acquire; i++ { + err := permit.Acquire(context.Background()) + assert.NoError(t, err) + } + releaseOp := func() { + for i := 0; i < tt.release; i++ { + permit.Release() + } + } + + if tt.expectPanic { + assert.Panics(t, releaseOp) + } else { + assert.NotPanics(t, releaseOp) + assert.Equal(t, tt.acquire-tt.release, permit.Count()) + } + }) } - done() - permit.Release() +} + +func Test_Permit_AcquireChan(t *testing.T) { + t.Run("acquire 1 permit", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + assert.Equal(t, 1, permit.Count()) + case <-ctx.Done(): + t.Errorf("permit not acquired") + } + }) + + t.Run("acquire timeout", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + time.Sleep(100 * time.Millisecond) + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + t.Errorf("permit acquired") + case <-ctx.Done(): + assert.Empty(t, permit.Count()) + } + }) + + t.Run("cancel acquire", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + t.Errorf("permit acquired") + case <-ctx.Done(): + assert.Empty(t, permit.Count()) + } + }) + + t.Run("acquire more than quota", func(t *testing.T) { + permit := NewResizablePermit(4) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + for i := 0; i < 10; i++ { + channel, done := permit.AcquireChan(ctx) + select { + case <-channel: + case <-ctx.Done(): + } + done() + } + + assert.Equal(t, 4, permit.Count()) + }) }