Skip to content

Commit

Permalink
fix unit test and address comment on AcquireChan return
Browse files Browse the repository at this point in the history
  • Loading branch information
shijiesheng committed Dec 6, 2024
1 parent c5f2d53 commit 40cd57c
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 22 deletions.
5 changes: 5 additions & 0 deletions internal/internal_poller_autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ func (p *pollerAutoScaler) Start() {
p.permit.SetQuota(int(proposedResource))
}
p.pollerUsageEstimator.Reset()

// hooks
for i := range p.onAutoScale {
p.onAutoScale[i]()
}
}
}
}()
Expand Down
5 changes: 3 additions & 2 deletions internal/internal_poller_autoscaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ func Test_pollerAutoscaler(t *testing.T) {
go func() {
defer wg.Done()
for pollResult := range pollChan {
pollerScaler.permit.Acquire(context.Background())
err := pollerScaler.permit.Acquire(context.Background())
assert.NoError(t, err)
pollerScaler.CollectUsage(pollResult)
pollerScaler.permit.Release()
}
Expand All @@ -202,7 +203,7 @@ func Test_pollerAutoscaler(t *testing.T) {

assert.Eventually(t, func() bool {
return autoscalerEpoch.Load() == uint64(tt.args.autoScalerEpoch)
}, tt.args.cooldownTime+20*time.Millisecond, 10*time.Millisecond)
}, tt.args.cooldownTime+100*time.Millisecond, 10*time.Millisecond)
pollerScaler.Stop()
res := pollerScaler.permit.Quota() - pollerScaler.permit.Count()
assert.Equal(t, tt.want, int(res))
Expand Down
9 changes: 4 additions & 5 deletions internal/internal_worker_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,13 @@ func (bw *baseWorker) runPoller() {
bw.metricsScope.Counter(metrics.PollerStartCounter).Inc(1)

for {
// permitChannel can be blocking without passing context because shutdownCh is used
permitChannel := bw.concurrency.PollerPermit.AcquireChan(context.Background())
permitChannel, channelDone := bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext)
select {
case <-bw.shutdownCh:
permitChannel.Close()
channelDone()
return
case <-permitChannel.C(): // don't poll unless there is a task permit
permitChannel.Close()
case <-permitChannel: // don't poll unless there is a task permit
channelDone()
// TODO move to a centralized place inside the worker
// emit metrics on concurrent task permit quota and current task permit count
// NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process
Expand Down
9 changes: 1 addition & 8 deletions internal/worker/concurrency.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,9 @@ type ConcurrencyLimit struct {
// Permit is an adaptive permit issuer to control concurrency
type Permit interface {
Acquire(context.Context) error
AcquireChan(context.Context) PermitChannel
AcquireChan(context.Context) (channel <-chan struct{}, done func())
Count() int
Quota() int
Release()
SetQuota(int)
}

// PermitChannel is a channel that can be used to wait for a permit to be available
// Remember to call Close() to avoid goroutine leak
type PermitChannel interface {
C() <-chan struct{}
Close()
}
10 changes: 6 additions & 4 deletions internal/worker/resizable_permit.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (p *resizablePermit) Count() int {
// After usage:
// 1. avoid goroutine leak by calling permitChannel.Close()
// 2. release permit by calling permit.Release()
func (p *resizablePermit) AcquireChan(ctx context.Context) PermitChannel {
func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, func()) {
ctx, cancel := context.WithCancel(ctx)
pc := &permitChannel{
p: p,
Expand All @@ -79,8 +79,10 @@ func (p *resizablePermit) AcquireChan(ctx context.Context) PermitChannel {
cancel: cancel,
wg: &sync.WaitGroup{},
}
pc.start()
return pc
pc.Start()
return pc.C(), func() {
pc.Close()
}
}

type permitChannel struct {
Expand All @@ -95,7 +97,7 @@ func (ch *permitChannel) C() <-chan struct{} {
return ch.c
}

func (ch *permitChannel) start() {
func (ch *permitChannel) Start() {
ch.wg.Add(1)
go func() {
defer ch.wg.Done()
Expand Down
24 changes: 21 additions & 3 deletions internal/worker/resizable_permit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"go.uber.org/goleak"
)

func TestPermit_Simulation(t *testing.T) {
Expand Down Expand Up @@ -86,6 +87,7 @@ func TestPermit_Simulation(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer goleak.VerifyNone(t)
wg := &sync.WaitGroup{}
permit := NewResizablePermit(tt.capacity[0])
wg.Add(1)
Expand Down Expand Up @@ -117,15 +119,15 @@ func TestPermit_Simulation(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
permitChan := permit.AcquireChan(ctx)
permitChan, done := permit.AcquireChan(ctx)
select {
case <-permitChan.C():
case <-permitChan:
time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond)
permit.Release()
case <-ctx.Done():
failures.Inc()
}
permitChan.Close()
done()
}()
}

Expand All @@ -142,3 +144,19 @@ func TestPermit_Simulation(t *testing.T) {
})
}
}

func Test_Permit_AcquireChan(t *testing.T) {
permit := NewResizablePermit(2)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
permitChan, done := permit.AcquireChan(ctx)
select {
case <-permitChan:
assert.Equal(t, 2, permit.Quota())
assert.Equal(t, 1, permit.Count())
case <-ctx.Done():
t.Error("unexpected timeout")
}
done()
permit.Release()
}

0 comments on commit 40cd57c

Please sign in to comment.