Skip to content

Commit

Permalink
test: fix racy setup
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Dec 19, 2024
1 parent 050d18f commit d42c8ae
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 54 deletions.
2 changes: 1 addition & 1 deletion cmd/status/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestStatusCmd(t *testing.T) {
})

t.Run("case=block", func(t *testing.T) {
ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, time.Millisecond)
ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, 100*time.Millisecond)

l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
Expand Down
90 changes: 48 additions & 42 deletions internal/driver/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error {
innerCtx, cancel := context.WithCancel(ctx)
defer cancel()

serveFuncs := []func(context.Context, chan<- struct{}) error{
serveFuncs := []func(context.Context, chan<- struct{}) func() error{
r.serveRead,
r.serveWrite,
r.serveOPLSyntax,
Expand Down Expand Up @@ -149,88 +149,94 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error {
// We need to separate the setup (invoking the functions that return the serve functions) from running the serve
// functions to mitigate race conditions in the HTTP router.
for _, serve := range serveFuncs {
eg.Go(func() error {
return serve(innerCtx, doneShutdown)
})
eg.Go(serve(innerCtx, doneShutdown))
}

return eg.Wait()
}

func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) error {
func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) func() error {
rt, s := r.ReadRouter(ctx), r.ReadGRPCServer(ctx)

if tracer := r.Tracer(ctx); tracer.IsLoaded() {
rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider()))
}

addr, listenFile := r.Config(ctx).ReadAPIListenOn()
return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done)
return func() error {
addr, listenFile := r.Config(ctx).ReadAPIListenOn()
return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done)
}
}

func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) error {
func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) func() error {
rt, s := r.WriteRouter(ctx), r.WriteGRPCServer(ctx)

if tracer := r.Tracer(ctx); tracer.IsLoaded() {
rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider()))
}

addr, listenFile := r.Config(ctx).WriteAPIListenOn()
return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done)
return func() error {
addr, listenFile := r.Config(ctx).WriteAPIListenOn()
return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done)
}
}

func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) error {
func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) func() error {
rt, s := r.OPLSyntaxRouter(ctx), r.OplGRPCServer(ctx)

if tracer := r.Tracer(ctx); tracer.IsLoaded() {
rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider()))
}

addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn()
return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done)
return func() error {
addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn()
return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done)
}
}

func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) error {
func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) func() error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

addr, listenFile := r.Config(ctx).MetricsListenOn()
l, err := listenAndWriteFile(ctx, addr, listenFile)
if err != nil {
return err
}

//nolint:gosec // graceful.WithDefaults already sets a timeout
s := graceful.WithDefaults(&http.Server{
Handler: r.metricsRouter(ctx),
})

eg := &errgroup.Group{}
return func() error {
defer cancel()
eg := &errgroup.Group{}

eg.Go(func() error {
if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) {
return errors.WithStack(err)
addr, listenFile := r.Config(ctx).MetricsListenOn()
l, err := listenAndWriteFile(ctx, addr, listenFile)
if err != nil {
return err
}
return nil
})
eg.Go(func() (err error) {
defer func() {
l := r.Logger().WithField("endpoint", "metrics")
if err != nil {
l.WithError(err).Error("graceful shutdown failed")
} else {
l.Info("gracefully shutdown server")

eg.Go(func() error {
if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) {
return errors.WithStack(err)
}
done <- struct{}{}
}()
return nil
})
eg.Go(func() (err error) {
defer func() {
l := r.Logger().WithField("endpoint", "metrics")
if err != nil {
l.WithError(err).Error("graceful shutdown failed")
} else {
l.Info("gracefully shutdown server")
}
done <- struct{}{}
}()

<-ctx.Done()
ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout)
defer cancel()
return s.Shutdown(ctx)
})
<-ctx.Done()
ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout)
defer cancel()
return s.Shutdown(ctx)
})

return eg.Wait()
return eg.Wait()
}
}

func multiplexPort(ctx context.Context, log *logrusx.Logger, addr, listenFile string, router http.Handler, grpcS *grpc.Server, done chan<- struct{}) error {
Expand Down
12 changes: 3 additions & 9 deletions internal/driver/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ func TestScrapingEndpoint(t *testing.T) {

eg := errgroup.Group{}
doneShutdown := make(chan struct{})
eg.Go(func() error {
return r.serveWrite(ctx, doneShutdown)
})
eg.Go(func() error {
return r.serveMetrics(ctx, doneShutdown)
})
eg.Go(r.serveWrite(ctx, doneShutdown))
eg.Go(r.serveMetrics(ctx, doneShutdown))

_, writePort, _ := getAddr(t, "write")
_, metricsPort, _ := getAddr(t, "metrics")
Expand Down Expand Up @@ -104,9 +100,7 @@ func TestPanicRecovery(t *testing.T) {

eg := errgroup.Group{}
doneShutdown := make(chan struct{})
eg.Go(func() error {
return r.serveWrite(ctx, doneShutdown)
})
eg.Go(r.serveWrite(ctx, doneShutdown))

_, port, _ := getAddr(t, "write")

Expand Down
2 changes: 0 additions & 2 deletions internal/e2e/full_suit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ const (
func Test(t *testing.T) {
t.Parallel()
for _, dsn := range dbx.GetDSNs(t, false) {
dsn := dsn
t.Run(fmt.Sprintf("dsn=%s", dsn.Name), func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -98,7 +97,6 @@ func Test(t *testing.T) {
syntaxRemote: oplAddr,
},
} {
cl := cl
t.Run(fmt.Sprintf("client=%T", cl), runCases(cl, namespaceTestMgr))

if tc, ok := cl.(transactClient); ok {
Expand Down

0 comments on commit d42c8ae

Please sign in to comment.