From cc85898a58305219c4c7bcd1440ccfb86442bcb7 Mon Sep 17 00:00:00 2001 From: Amir Y <83904651+amirylm@users.noreply.github.com> Date: Sat, 9 Sep 2023 01:41:54 +0300 Subject: [PATCH] Thread control utility (#10560) * utility for managing group of goroutines * refactor context to StopChan * remove limits * leftovers * leftovers round #2 * lint --- core/utils/thread_control.go | 44 +++++++++++++++++++++++++++++++ core/utils/thread_control_test.go | 27 +++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 core/utils/thread_control.go create mode 100644 core/utils/thread_control_test.go diff --git a/core/utils/thread_control.go b/core/utils/thread_control.go new file mode 100644 index 00000000000..8f7fff42496 --- /dev/null +++ b/core/utils/thread_control.go @@ -0,0 +1,44 @@ +package utils + +import ( + "context" + "sync" +) + +var _ ThreadControl = &threadControl{} + +// ThreadControl is a helper for managing a group of goroutines. +type ThreadControl interface { + // Go starts a goroutine and tracks the lifetime of the goroutine. + Go(fn func(context.Context)) + // Close cancels the goroutines and waits for all of them to exit. + Close() +} + +func NewThreadControl() *threadControl { + tc := &threadControl{ + stop: make(chan struct{}), + } + + return tc +} + +type threadControl struct { + threadsWG sync.WaitGroup + stop StopChan +} + +func (tc *threadControl) Go(fn func(context.Context)) { + tc.threadsWG.Add(1) + go func() { + defer tc.threadsWG.Done() + ctx, cancel := tc.stop.NewCtx() + defer cancel() + fn(ctx) + }() +} + +func (tc *threadControl) Close() { + close(tc.stop) + tc.threadsWG.Wait() +} diff --git a/core/utils/thread_control_test.go b/core/utils/thread_control_test.go new file mode 100644 index 00000000000..9001ca7241c --- /dev/null +++ b/core/utils/thread_control_test.go @@ -0,0 +1,27 @@ +package utils + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestThreadControl_Close(t *testing.T) { + n := 10 + tc := NewThreadControl() + + finished := atomic.Int32{} + + for i := 0; i < n; i++ { + tc.Go(func(ctx context.Context) { + <-ctx.Done() + finished.Add(1) + }) + } + + tc.Close() + + require.Equal(t, int32(n), finished.Load()) +}