From cc558dccd2509bd6596f8f6ab64b5cfbf8d2764a Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 17 Oct 2024 15:36:12 +0200 Subject: [PATCH 1/4] fix: use correct number of servers --- internal/driver/daemon.go | 115 +++++++++++++++---------------- internal/driver/daemon_test.go | 52 ++++++++------ internal/e2e/grpc_client_test.go | 17 ++--- 3 files changed, 93 insertions(+), 91 deletions(-) diff --git a/internal/driver/daemon.go b/internal/driver/daemon.go index 7c78e312e..7df265d86 100644 --- a/internal/driver/daemon.go +++ b/internal/driver/daemon.go @@ -106,7 +106,14 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { innerCtx, cancel := context.WithCancel(ctx) defer cancel() - doneShutdown := make(chan struct{}, 3) + serveFuncs := []func(context.Context, chan<- struct{}) error{ + r.serveRead, + r.serveWrite, + r.serveOPLSyntax, + r.serveMetrics, + } + + doneShutdown := make(chan struct{}, len(serveFuncs)) go func() { osSignals := make(chan os.Signal, 1) @@ -118,18 +125,20 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { case <-innerCtx.Done(): } - ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), graceful.DefaultShutdownTimeout) defer cancel() - nWaitingForShutdown := cap(doneShutdown) - select { - case <-ctx.Done(): - return - case <-doneShutdown: - nWaitingForShutdown-- - if nWaitingForShutdown == 0 { - // graceful shutdown done + nWaitingForShutdown := len(serveFuncs) + for { + select { + case <-ctx.Done(): return + case <-doneShutdown: + nWaitingForShutdown-- + if nWaitingForShutdown == 0 { + // graceful shutdown done + return + } } } }() @@ -137,57 +146,49 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { eg := &errgroup.Group{} // We need to separate the setup (invoking the functions that return the serve functions) from running the serve - // functions to mitigate race contitions in the HTTP router. - for _, serve := range []func() error{ - r.serveRead(innerCtx, doneShutdown), - r.serveWrite(innerCtx, doneShutdown), - r.serveOPLSyntax(innerCtx, doneShutdown), - r.serveMetrics(innerCtx, doneShutdown), - } { - eg.Go(serve) + // functions to mitigate race conditions in the HTTP router. + for _, serve := range serveFuncs { + eg.Go(func() error { + return serve(innerCtx, doneShutdown) + }) } return eg.Wait() } -func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) func() error { +func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) error { rt, s := r.ReadRouter(ctx), r.ReadGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - return func() error { - return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), r.Config(ctx).ReadAPIListenOn(), rt, s, done) - } + return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), r.Config(ctx).ReadAPIListenOn(), rt, s, done) } -func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) func() error { +func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) error { rt, s := r.WriteRouter(ctx), r.WriteGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - return func() error { - return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), r.Config(ctx).WriteAPIListenOn(), rt, s, done) - } + return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), r.Config(ctx).WriteAPIListenOn(), rt, s, done) } -func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) func() error { +func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) error { rt, s := r.OPLSyntaxRouter(ctx), r.OplGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - return func() error { - return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), r.Config(ctx).OPLSyntaxAPIListenOn(), rt, s, done) - } + return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), r.Config(ctx).OPLSyntaxAPIListenOn(), rt, s, done) } -func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) func() error { +func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) error { ctx, cancel := context.WithCancel(ctx) + defer cancel() //nolint:gosec // graceful.WithDefaults already sets a timeout s := graceful.WithDefaults(&http.Server{ @@ -195,36 +196,32 @@ func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{} Addr: r.Config(ctx).MetricsListenOn(), }) - return func() error { - defer cancel() - - eg := &errgroup.Group{} + eg := &errgroup.Group{} - eg.Go(func() error { - if err := s.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - return errors.WithStack(err) + eg.Go(func() error { + if err := s.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + return errors.WithStack(err) + } + return nil + }) + eg.Go(func() (err error) { + defer func() { + l := r.Logger().WithField("endpoint", "metrics") + if err != nil { + l.WithError(err).Error("graceful shutdown failed") + } else { + l.Info("gracefully shutdown server") } - return nil - }) - eg.Go(func() (err error) { - defer func() { - l := r.Logger().WithField("endpoint", "metrics") - if err != nil { - l.WithError(err).Error("graceful shutdown failed") - } else { - l.Info("gracefully shutdown server") - } - done <- struct{}{} - }() + done <- struct{}{} + }() - <-ctx.Done() - ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) - defer cancel() - return s.Shutdown(ctx) - }) + <-ctx.Done() + ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) + defer cancel() + return s.Shutdown(ctx) + }) - return eg.Wait() - } + return eg.Wait() } func multiplexPort(ctx context.Context, log *logrusx.Logger, addr string, router http.Handler, grpcS *grpc.Server, done chan<- struct{}) error { @@ -281,7 +278,7 @@ func multiplexPort(ctx context.Context, log *logrusx.Logger, addr string, router <-ctx.Done() - ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), graceful.DefaultShutdownTimeout) defer cancel() shutdownEg := errgroup.Group{} @@ -304,7 +301,7 @@ func multiplexPort(ctx context.Context, log *logrusx.Logger, addr string, router return nil case <-ctx.Done(): grpcS.Stop() - return errors.New("graceful stop of gRPC server canceled, had to force it") + return errors.New("graceful stop of gRPC server timed out, had to force it") } }) diff --git a/internal/driver/daemon_test.go b/internal/driver/daemon_test.go index c588462e7..fd5ccad56 100644 --- a/internal/driver/daemon_test.go +++ b/internal/driver/daemon_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "testing" + "time" "github.com/phayes/freeport" "github.com/prometheus/common/expfmt" @@ -44,19 +45,25 @@ func TestScrapingEndpoint(t *testing.T) { eg := errgroup.Group{} doneShutdown := make(chan struct{}) - eg.Go(r.serveWrite(ctx, doneShutdown)) - eg.Go(r.serveMetrics(ctx, doneShutdown)) - - conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) - require.NoError(t, err) - defer conn.Close() + eg.Go(func() error { + return r.serveWrite(ctx, doneShutdown) + }) + eg.Go(func() error { + return r.serveMetrics(ctx, doneShutdown) + }) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() - cl := grpcHealthV1.NewHealthClient(conn) - watcher, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{}) - require.NoError(t, err) - require.NoError(t, watcher.CloseSend()) - for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() { - } + cl := grpcHealthV1.NewHealthClient(conn) + watcher, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{}) + require.NoError(t, err) + require.NoError(t, watcher.CloseSend()) + for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() { + } + }, 2*time.Second, 100*time.Millisecond) promresp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", portMetrics) + prometheus.MetricsPrometheusPath) require.NoError(t, err) @@ -101,22 +108,27 @@ func TestPanicRecovery(t *testing.T) { eg := errgroup.Group{} doneShutdown := make(chan struct{}) - eg.Go(r.serveWrite(ctx, doneShutdown)) + eg.Go(func() error { + return r.serveWrite(ctx, doneShutdown) + }) conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) defer conn.Close() - cl := grpcHealthV1.NewHealthClient(conn) + require.EventuallyWithT(t, func(t *assert.CollectT) { + cl := grpcHealthV1.NewHealthClient(conn) - watcher, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{}) - require.NoError(t, err) - require.NoError(t, watcher.CloseSend()) - for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() { - } + watcher, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{}) + require.NoError(t, err) + require.NoError(t, watcher.CloseSend()) + for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() { + } + }, 2*time.Second, 100*time.Millisecond) + cl := grpcHealthV1.NewHealthClient(conn) // we want to ensure the server is still running after the panic - for i := 0; i < 10; i++ { + for range 10 { // Unary call resp, err := cl.Check(ctx, &grpcHealthV1.HealthCheckRequest{}) require.Error(t, err, "%+v", resp) diff --git a/internal/e2e/grpc_client_test.go b/internal/e2e/grpc_client_test.go index 536881bc6..72f9044d8 100644 --- a/internal/e2e/grpc_client_test.go +++ b/internal/e2e/grpc_client_test.go @@ -211,20 +211,13 @@ func (g *grpcClient) expand(t *testing.T, r *ketoapi.SubjectSet, depth int) *ket } func (g *grpcClient) waitUntilLive(t *testing.T) { - c := grpcHealthV1.NewHealthClient(g.read) + require.EventuallyWithT(t, func(t *assert.CollectT) { + c := grpcHealthV1.NewHealthClient(g.read) - for { res, err := c.Check(g.ctx, &grpcHealthV1.HealthCheckRequest{}) - if errors.Is(err, context.Canceled) { - t.Fatalf("timed out waiting for service to be live: %s", err) - } - if err == nil { - require.Equal(t, grpcHealthV1.HealthCheckResponse_SERVING, res.Status) - return - } - t.Logf("waiting for service to be live: %s", err) - time.Sleep(10 * time.Millisecond) - } + require.NoError(t, err) + assert.Equal(t, grpcHealthV1.HealthCheckResponse_SERVING, res.Status) + }, 2*time.Second, 10*time.Millisecond) } func (g *grpcClient) deleteTuple(t *testing.T, r *ketoapi.RelationTuple) { From ba8c4db1ac18fd5730423255d935a1357b790225 Mon Sep 17 00:00:00 2001 From: zepatrik Date: Fri, 18 Oct 2024 10:19:37 +0200 Subject: [PATCH 2/4] feat: write listen files with actual address --- embedx/config.schema.json | 28 +++++++++++ go.mod | 2 +- internal/driver/config/provider.go | 32 +++++++------ internal/driver/config/provider_test.go | 4 +- internal/driver/daemon.go | 37 ++++++++++++--- internal/driver/daemon_test.go | 42 ++++++++--------- internal/driver/testhelpers.go | 63 +++++++++++++++++++++++++ internal/e2e/full_suit_test.go | 47 ++++++++++-------- internal/e2e/grpc_client_test.go | 9 ++-- internal/e2e/helpers.go | 14 ++---- 10 files changed, 199 insertions(+), 79 deletions(-) create mode 100644 internal/driver/testhelpers.go diff --git a/embedx/config.schema.json b/embedx/config.schema.json index be38a0222..8f46f26a3 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -211,6 +211,13 @@ "title": "Host", "description": "The network interface to listen on." }, + "write_listen_file": { + "type": "string", + "title": "Read Listen File", + "description": "The path to a file that will be created when the read API is ready to accept connections. The content of the file is the host:port of the read API. Use this to get the actual port when using port 0. The service might not yet be ready to accept connections when the file is created.", + "format": "uri", + "examples": ["file:///tmp/keto-read-api"] + }, "cors": { "$ref": "#/definitions/cors" }, @@ -239,6 +246,13 @@ "title": "Host", "description": "The network interface to listen on." }, + "write_listen_file": { + "type": "string", + "title": "Write Listen File", + "description": "The path to a file that will be created when the write API is ready to accept connections. The content of the file is the host:port of the write API. Use this to get the actual port when using port 0. The service might not yet be ready to accept connections when the file is created.", + "format": "uri", + "examples": ["file:///tmp/keto-write-api"] + }, "cors": { "$ref": "#/definitions/cors" }, @@ -267,6 +281,13 @@ "title": "Host", "description": "The network interface to listen on." }, + "write_listen_file": { + "type": "string", + "title": "Metrics Listen File", + "description": "The path to a file that will be created when the metrics API is ready to accept connections. The content of the file is the host:port of the metrics API. Use this to get the actual port when using port 0. The service might not yet be ready to accept connections when the file is created.", + "format": "uri", + "examples": ["file:///tmp/keto-metrics-api"] + }, "cors": { "$ref": "#/definitions/cors" }, @@ -295,6 +316,13 @@ "title": "Host", "description": "The network interface to listen on." }, + "write_listen_file": { + "type": "string", + "title": "OPL Listen File", + "description": "The path to a file that will be created when the OPL API is ready to accept connections. The content of the file is the host:port of the OPL API. Use this to get the actual port when using port 0. The service might not yet be ready to accept connections when the file is created.", + "format": "uri", + "examples": ["file:///tmp/keto-opl-api"] + }, "cors": { "$ref": "#/definitions/cors" }, diff --git a/go.mod b/go.mod index b03e92ebf..d2e6cc45c 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,6 @@ require ( github.com/ory/keto/proto v0.13.0-alpha.0 github.com/ory/x v0.0.677 github.com/pelletier/go-toml v1.9.5 - github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/pkg/errors v0.9.1 github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.61.0 @@ -153,6 +152,7 @@ require ( github.com/ory/dockertest/v3 v3.11.0 // indirect github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 // indirect github.com/pkg/profile v1.7.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_golang v1.20.4 // indirect diff --git a/internal/driver/config/provider.go b/internal/driver/config/provider.go index f9e1ca73a..bdeac5563 100644 --- a/internal/driver/config/provider.go +++ b/internal/driver/config/provider.go @@ -43,14 +43,18 @@ const ( KeyBatchCheckMaxBatchSize = "limit.max_batch_check_size" KeyBatchCheckParallelizationLimit = "limit.batch_check_max_parallelization" - KeyReadAPIHost = "serve." + string(EndpointRead) + ".host" - KeyReadAPIPort = "serve." + string(EndpointRead) + ".port" - KeyWriteAPIHost = "serve." + string(EndpointWrite) + ".host" - KeyWriteAPIPort = "serve." + string(EndpointWrite) + ".port" - KeyOPLSyntaxAPIHost = "serve." + string(EndpointOPLSyntax) + ".host" - KeyOPLSyntaxAPIPort = "serve." + string(EndpointOPLSyntax) + ".port" - KeyMetricsHost = "serve." + string(EndpointMetrics) + ".host" - KeyMetricsPort = "serve." + string(EndpointMetrics) + ".port" + KeyReadAPIHost = "serve." + string(EndpointRead) + ".host" + KeyReadAPIPort = "serve." + string(EndpointRead) + ".port" + KeyReadAPIListenFile = "serve." + string(EndpointRead) + ".write_listen_file" + KeyWriteAPIHost = "serve." + string(EndpointWrite) + ".host" + KeyWriteAPIPort = "serve." + string(EndpointWrite) + ".port" + KeyWriteAPIListenFile = "serve." + string(EndpointWrite) + ".write_listen_file" + KeyOPLSyntaxAPIHost = "serve." + string(EndpointOPLSyntax) + ".host" + KeyOPLSyntaxAPIPort = "serve." + string(EndpointOPLSyntax) + ".port" + KeyOPLSyntaxListenFile = "serve." + string(EndpointOPLSyntax) + ".write_listen_file" + KeyMetricsHost = "serve." + string(EndpointMetrics) + ".host" + KeyMetricsPort = "serve." + string(EndpointMetrics) + ".port" + KeyMetricsListenFile = "serve." + string(EndpointMetrics) + ".write_listen_file" KeyNamespaces = "namespaces" KeyNamespacesExperimentalStrictMode = KeyNamespaces + ".experimental_strict_mode" @@ -167,18 +171,18 @@ func (k *Config) Set(key string, v any) error { return nil } -func (k *Config) addressFor(endpoint EndpointType) string { +func (k *Config) addressFor(endpoint EndpointType) (string, string) { return fmt.Sprintf( "%s:%d", k.p.StringF("serve."+string(endpoint)+".host", ""), k.p.IntF("serve."+string(endpoint)+".port", 0), - ) + ), k.p.StringF("serve."+string(endpoint)+".write_listen_file", "") } -func (k *Config) ReadAPIListenOn() string { return k.addressFor(EndpointRead) } -func (k *Config) WriteAPIListenOn() string { return k.addressFor(EndpointWrite) } -func (k *Config) MetricsListenOn() string { return k.addressFor(EndpointMetrics) } -func (k *Config) OPLSyntaxAPIListenOn() string { return k.addressFor(EndpointOPLSyntax) } +func (k *Config) ReadAPIListenOn() (string, string) { return k.addressFor(EndpointRead) } +func (k *Config) WriteAPIListenOn() (string, string) { return k.addressFor(EndpointWrite) } +func (k *Config) MetricsListenOn() (string, string) { return k.addressFor(EndpointMetrics) } +func (k *Config) OPLSyntaxAPIListenOn() (string, string) { return k.addressFor(EndpointOPLSyntax) } func (k *Config) MaxReadDepth() int { return k.p.Int(KeyLimitMaxReadDepth) diff --git a/internal/driver/config/provider_test.go b/internal/driver/config/provider_test.go index 2628b2749..ffba74a55 100644 --- a/internal/driver/config/provider_test.go +++ b/internal/driver/config/provider_test.go @@ -283,5 +283,7 @@ func TestProvider_DefaultReadAPIListenOn(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, ":4466", config.ReadAPIListenOn()) + addr, listenFile := config.ReadAPIListenOn() + assert.Equal(t, ":4466", addr) + assert.Zero(t, listenFile) } diff --git a/internal/driver/daemon.go b/internal/driver/daemon.go index 7df265d86..2f2d020fb 100644 --- a/internal/driver/daemon.go +++ b/internal/driver/daemon.go @@ -5,6 +5,7 @@ package driver import ( "context" + "fmt" "net" "net/http" "os" @@ -163,7 +164,8 @@ func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) e rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), r.Config(ctx).ReadAPIListenOn(), rt, s, done) + addr, listenFile := r.Config(ctx).ReadAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done) } func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) error { @@ -173,7 +175,8 @@ func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), r.Config(ctx).WriteAPIListenOn(), rt, s, done) + addr, listenFile := r.Config(ctx).WriteAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done) } func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) error { @@ -183,23 +186,29 @@ func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), r.Config(ctx).OPLSyntaxAPIListenOn(), rt, s, done) + addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done) } func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + addr, listenFile := r.Config(ctx).MetricsListenOn() + l, err := listenAndWriteFile(ctx, addr, listenFile) + if err != nil { + return err + } + //nolint:gosec // graceful.WithDefaults already sets a timeout s := graceful.WithDefaults(&http.Server{ Handler: r.metricsRouter(ctx), - Addr: r.Config(ctx).MetricsListenOn(), }) eg := &errgroup.Group{} eg.Go(func() error { - if err := s.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) { return errors.WithStack(err) } return nil @@ -224,8 +233,8 @@ func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{} return eg.Wait() } -func multiplexPort(ctx context.Context, log *logrusx.Logger, addr string, router http.Handler, grpcS *grpc.Server, done chan<- struct{}) error { - l, err := (&net.ListenConfig{}).Listen(ctx, "tcp", addr) +func multiplexPort(ctx context.Context, log *logrusx.Logger, addr, listenFile string, router http.Handler, grpcS *grpc.Server, done chan<- struct{}) error { + l, err := listenAndWriteFile(ctx, addr, listenFile) if err != nil { return err } @@ -324,6 +333,20 @@ func (r *RegistryDefault) allHandlers() []Handler { return r.handlers } +func listenAndWriteFile(ctx context.Context, addr, listenFile string) (net.Listener, error) { + l, err := (&net.ListenConfig{}).Listen(ctx, "tcp", addr) + if err != nil { + return nil, errors.WithStack(fmt.Errorf("unable to listen on %q: %w", addr, err)) + } + const filePrefix = "file://" + if strings.HasPrefix(listenFile, filePrefix) { + if err := os.WriteFile(listenFile[len(filePrefix):], []byte(l.Addr().String()), 0600); err != nil { + return nil, errors.WithStack(fmt.Errorf("unable to write listen file %q: %w", listenFile, err)) + } + } + return l, nil +} + func (r *RegistryDefault) ReadRouter(ctx context.Context) http.Handler { n := negroni.New() for _, f := range r.defaultHttpMiddlewares { diff --git a/internal/driver/daemon_test.go b/internal/driver/daemon_test.go index fd5ccad56..e84fb3031 100644 --- a/internal/driver/daemon_test.go +++ b/internal/driver/daemon_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/phayes/freeport" "github.com/prometheus/common/expfmt" "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" @@ -19,8 +18,6 @@ import ( grpcHealthV1 "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" - "github.com/ory/keto/internal/driver/config" - "context" prometheus "github.com/ory/x/prometheusx" @@ -29,19 +26,13 @@ import ( ) func TestScrapingEndpoint(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - port, err := freeport.GetFreePort() - require.NoError(t, err) - r := NewSqliteTestRegistry(t, false) - require.NoError(t, r.Config(ctx).Set(config.KeyWriteAPIPort, port)) - - //metrics port - portMetrics, err := freeport.GetFreePort() - require.NoError(t, err) - require.NoError(t, r.Config(ctx).Set(config.KeyMetricsPort, portMetrics)) + getAddr := UseDynamicPorts(ctx, t, r) eg := errgroup.Group{} doneShutdown := make(chan struct{}) @@ -52,8 +43,13 @@ func TestScrapingEndpoint(t *testing.T) { return r.serveMetrics(ctx, doneShutdown) }) - require.EventuallyWithT(t, func(t *assert.CollectT) { - conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) + _, writePort, _ := getAddr(t, "write") + _, metricsPort, _ := getAddr(t, "metrics") + + t.Logf("write port: %s, metrics port: %s", writePort, metricsPort) + + assert.EventuallyWithT(t, func(t *assert.CollectT) { + conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%s", writePort), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) defer conn.Close() @@ -63,9 +59,9 @@ func TestScrapingEndpoint(t *testing.T) { require.NoError(t, watcher.CloseSend()) for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() { } - }, 2*time.Second, 100*time.Millisecond) + }, 2*time.Second, 10*time.Millisecond) - promresp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", portMetrics) + prometheus.MetricsPrometheusPath) + promresp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s", metricsPort) + prometheus.MetricsPrometheusPath) require.NoError(t, err) require.EqualValues(t, http.StatusOK, promresp.StatusCode) @@ -91,6 +87,8 @@ func TestScrapingEndpoint(t *testing.T) { } func TestPanicRecovery(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -100,11 +98,9 @@ func TestPanicRecovery(t *testing.T) { streamPanicInterceptor := func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) { panic("test panic") } - port, err := freeport.GetFreePort() - require.NoError(t, err) r := NewSqliteTestRegistry(t, false, WithGRPCUnaryInterceptors(unaryPanicInterceptor), WithGRPCUnaryInterceptors(streamPanicInterceptor)) - require.NoError(t, r.Config(ctx).Set(config.KeyWriteAPIPort, port)) + getAddr := UseDynamicPorts(ctx, t, r) eg := errgroup.Group{} doneShutdown := make(chan struct{}) @@ -112,11 +108,13 @@ func TestPanicRecovery(t *testing.T) { return r.serveWrite(ctx, doneShutdown) }) - conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) + _, port, _ := getAddr(t, "write") + + conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%s", port), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) defer conn.Close() - require.EventuallyWithT(t, func(t *assert.CollectT) { + assert.EventuallyWithT(t, func(t *assert.CollectT) { cl := grpcHealthV1.NewHealthClient(conn) watcher, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{}) @@ -124,7 +122,7 @@ func TestPanicRecovery(t *testing.T) { require.NoError(t, watcher.CloseSend()) for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() { } - }, 2*time.Second, 100*time.Millisecond) + }, 2*time.Second, 10*time.Millisecond) cl := grpcHealthV1.NewHealthClient(conn) // we want to ensure the server is still running after the panic diff --git a/internal/driver/testhelpers.go b/internal/driver/testhelpers.go new file mode 100644 index 000000000..3337c71b1 --- /dev/null +++ b/internal/driver/testhelpers.go @@ -0,0 +1,63 @@ +package driver + +import ( + "context" + "fmt" + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/driver/config" +) + +type GetAddr = func(testing.TB, string) (host string, port string, fullAddr string) + +func UseDynamicPorts(ctx context.Context, t testing.TB, r Registry) GetAddr { + t.Helper() + + listenDir := t.TempDir() + readListenFile := fmt.Sprintf("%s/read.addr", listenDir) + writeListenFile := fmt.Sprintf("%s/write.addr", listenDir) + metricsListenFile := fmt.Sprintf("%s/metrics.addr", listenDir) + oplListenFile := fmt.Sprintf("%s/opl.addr", listenDir) + + require.NoError(t, r.Config(ctx).Set(config.KeyReadAPIPort, 0)) + require.NoError(t, r.Config(ctx).Set(config.KeyReadAPIListenFile, "file://"+readListenFile)) + require.NoError(t, r.Config(ctx).Set(config.KeyWriteAPIPort, 0)) + require.NoError(t, r.Config(ctx).Set(config.KeyWriteAPIListenFile, "file://"+writeListenFile)) + require.NoError(t, r.Config(ctx).Set(config.KeyMetricsPort, 0)) + require.NoError(t, r.Config(ctx).Set(config.KeyMetricsListenFile, "file://"+metricsListenFile)) + require.NoError(t, r.Config(ctx).Set(config.KeyOPLSyntaxAPIPort, 0)) + require.NoError(t, r.Config(ctx).Set(config.KeyOPLSyntaxListenFile, "file://"+oplListenFile)) + + return func(t testing.TB, endpoint string) (string, string, string) { + fp := "" + switch endpoint { + case "read": + fp = readListenFile + case "write": + fp = writeListenFile + case "metrics": + fp = metricsListenFile + case "opl": + fp = oplListenFile + default: + t.Fatalf("unknown endpoint: %q", endpoint) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := os.Stat(fp) + require.NoError(t, err) + }, 2*time.Second, 10*time.Millisecond) + + addr, err := os.ReadFile(fp) + require.NoError(t, err) + host, port, err := net.SplitHostPort(string(addr)) + require.NoError(t, err) + return host, port, string(addr) + } +} diff --git a/internal/e2e/full_suit_test.go b/internal/e2e/full_suit_test.go index 6afb7cf78..ab381729c 100644 --- a/internal/e2e/full_suit_test.go +++ b/internal/e2e/full_suit_test.go @@ -57,23 +57,28 @@ func Test(t *testing.T) { t.Run(fmt.Sprintf("dsn=%s", dsn.Name), func(t *testing.T) { t.Parallel() - ctx, reg, namespaceTestMgr := newInitializedReg(t, dsn, nil) + ctx, reg, namespaceTestMgr, getAddr := newInitializedReg(t, dsn, nil) closeServer := startServer(ctx, t, reg) t.Cleanup(closeServer) + _, _, readAddr := getAddr(t, "read") + _, _, writeAddr := getAddr(t, "write") + _, _, oplAddr := getAddr(t, "opl") + _, _, metricsAddr := getAddr(t, "metrics") + // The test cases start here // We execute every test with all clients available for _, cl := range []client{ newGrpcClient(t, ctx, - reg.Config(ctx).ReadAPIListenOn(), - reg.Config(ctx).WriteAPIListenOn(), - reg.Config(ctx).OPLSyntaxAPIListenOn(), + readAddr, + writeAddr, + oplAddr, ), &restClient{ - readURL: "http://" + reg.Config(ctx).ReadAPIListenOn(), - writeURL: "http://" + reg.Config(ctx).WriteAPIListenOn(), - oplSyntaxURL: "http://" + reg.Config(ctx).OPLSyntaxAPIListenOn(), + readURL: "http://" + readAddr, + writeURL: "http://" + writeAddr, + oplSyntaxURL: "http://" + oplAddr, }, &cliClient{c: &cmdx.CommandExecuter{ New: func() *cobra.Command { @@ -81,16 +86,16 @@ func Test(t *testing.T) { }, Ctx: ctx, PersistentArgs: []string{ - "--" + cliclient.FlagReadRemote, reg.Config(ctx).ReadAPIListenOn(), - "--" + cliclient.FlagWriteRemote, reg.Config(ctx).WriteAPIListenOn(), + "--" + cliclient.FlagReadRemote, readAddr, + "--" + cliclient.FlagWriteRemote, writeAddr, "--insecure-disable-transport-security=true", "--" + cmdx.FlagFormat, string(cmdx.FormatJSON), }, }}, &sdkClient{ - readRemote: reg.Config(ctx).ReadAPIListenOn(), - writeRemote: reg.Config(ctx).WriteAPIListenOn(), - syntaxRemote: reg.Config(ctx).OPLSyntaxAPIListenOn(), + readRemote: readAddr, + writeRemote: writeAddr, + syntaxRemote: oplAddr, }, } { cl := cl @@ -104,14 +109,14 @@ func Test(t *testing.T) { t.Run("case=metrics are served", func(t *testing.T) { t.Parallel() newGrpcClient(t, ctx, - reg.Config(ctx).ReadAPIListenOn(), - reg.Config(ctx).WriteAPIListenOn(), - reg.Config(ctx).OPLSyntaxAPIListenOn(), + readAddr, + writeAddr, + oplAddr, ).waitUntilLive(t) t.Run("case=on "+prometheus.MetricsPrometheusPath, func(t *testing.T) { t.Parallel() - resp, err := http.Get(fmt.Sprintf("http://%s%s", reg.Config(ctx).MetricsListenOn(), prometheus.MetricsPrometheusPath)) + resp, err := http.Get(fmt.Sprintf("http://%s%s", metricsAddr, prometheus.MetricsPrometheusPath)) require.NoError(t, err) require.Equal(t, resp.StatusCode, http.StatusOK) body, err := io.ReadAll(resp.Body) @@ -121,7 +126,7 @@ func Test(t *testing.T) { t.Run("case=not on /", func(t *testing.T) { t.Parallel() - resp, err := http.Get(fmt.Sprintf("http://%s", reg.Config(ctx).MetricsListenOn())) + resp, err := http.Get(fmt.Sprintf("http://%s", metricsAddr)) require.NoError(t, err) require.Equal(t, resp.StatusCode, http.StatusNotFound) }) @@ -133,7 +138,7 @@ func Test(t *testing.T) { func TestServeConfig(t *testing.T) { t.Parallel() - ctx, reg, _ := newInitializedReg(t, dbx.GetSqlite(t, dbx.SQLiteMemory), map[string]interface{}{ + ctx, reg, _, getAddr := newInitializedReg(t, dbx.GetSqlite(t, dbx.SQLiteMemory), map[string]interface{}{ "serve.read.cors.enabled": true, "serve.read.cors.debug": true, "serve.read.cors.allowed_methods": []string{http.MethodGet}, @@ -143,13 +148,15 @@ func TestServeConfig(t *testing.T) { closeServer := startServer(ctx, t, reg) t.Cleanup(closeServer) - for !healthReady(t, "http://"+reg.Config(ctx).ReadAPIListenOn()) { + _, _, readAddr := getAddr(t, "read") + + for !healthReady(t, "http://"+readAddr) { t.Log("Waiting for health check to be ready") time.Sleep(10 * time.Millisecond) } t.Log("Health check is ready") - req, err := http.NewRequest(http.MethodOptions, "http://"+reg.Config(ctx).ReadAPIListenOn()+relationtuple.ReadRouteBase, nil) + req, err := http.NewRequest(http.MethodOptions, "http://"+readAddr+relationtuple.ReadRouteBase, nil) require.NoError(t, err) req.Header.Set("Origin", "https://ory.sh") resp, err := http.DefaultClient.Do(req) diff --git a/internal/e2e/grpc_client_test.go b/internal/e2e/grpc_client_test.go index 72f9044d8..8f09994d6 100644 --- a/internal/e2e/grpc_client_test.go +++ b/internal/e2e/grpc_client_test.go @@ -6,10 +6,14 @@ package e2e import ( "context" "encoding/json" - "errors" "testing" "time" + "github.com/ory/keto/ketoapi" + opl "github.com/ory/keto/proto/ory/keto/opl/v1alpha1" + + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" + "github.com/ory/herodot" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,9 +23,6 @@ import ( "google.golang.org/grpc/status" "github.com/ory/keto/internal/x" - "github.com/ory/keto/ketoapi" - opl "github.com/ory/keto/proto/ory/keto/opl/v1alpha1" - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type grpcClient struct { diff --git a/internal/e2e/helpers.go b/internal/e2e/helpers.go index 96bcef55a..cd175f92f 100644 --- a/internal/e2e/helpers.go +++ b/internal/e2e/helpers.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/ory/x/configx" - "github.com/phayes/freeport" "github.com/spf13/pflag" "github.com/ory/keto/internal/driver/config" @@ -58,15 +57,12 @@ func (m *namespaceTestManager) remove(t *testing.T, name string) { require.NoError(t, m.reg.Config(m.ctx).Set(config.KeyNamespaces, m.nspaces)) } -func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]interface{}) (context.Context, driver.Registry, *namespaceTestManager) { +func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]interface{}) (context.Context, driver.Registry, *namespaceTestManager, driver.GetAddr) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(func() { cancel() }) - ports, err := freeport.GetFreePorts(4) - require.NoError(t, err) - flags := pflag.NewFlagSet("", pflag.ContinueOnError) configx.RegisterConfigFlag(flags, nil) @@ -75,13 +71,9 @@ func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]int "log.level": "debug", "log.leak_sensitive_values": true, config.KeyReadAPIHost: "127.0.0.1", - config.KeyReadAPIPort: ports[0], config.KeyWriteAPIHost: "127.0.0.1", - config.KeyWriteAPIPort: ports[1], config.KeyOPLSyntaxAPIHost: "127.0.0.1", - config.KeyOPLSyntaxAPIPort: ports[2], config.KeyMetricsHost: "127.0.0.1", - config.KeyMetricsPort: ports[3], config.KeyNamespaces: []*namespace.Namespace{}, } for k, v := range cfgOverwrites { @@ -94,6 +86,8 @@ func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]int reg, err := driver.NewDefaultRegistry(ctx, flags, true, nil) require.NoError(t, err) + getAddr := driver.UseDynamicPorts(ctx, t, reg) + require.NoError(t, reg.MigrateUp(ctx)) assertMigrated(ctx, t, reg) @@ -101,7 +95,7 @@ func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]int reg: reg, ctx: ctx, nspaces: []*namespace.Namespace{}, - } + }, getAddr } func assertMigrated(ctx context.Context, t testing.TB, r driver.Registry) { From 050d18fba23618c1fb3dd34ba6fc085dca461e4e Mon Sep 17 00:00:00 2001 From: zepatrik Date: Wed, 18 Dec 2024 17:10:45 +0100 Subject: [PATCH 3/4] docs: improve internal API documentation --- internal/driver/config/provider.go | 18 +++++++++++++----- internal/driver/testhelpers.go | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/driver/config/provider.go b/internal/driver/config/provider.go index bdeac5563..ab3be17e4 100644 --- a/internal/driver/config/provider.go +++ b/internal/driver/config/provider.go @@ -171,7 +171,7 @@ func (k *Config) Set(key string, v any) error { return nil } -func (k *Config) addressFor(endpoint EndpointType) (string, string) { +func (k *Config) addressFor(endpoint EndpointType) (addr string, listenFile string) { return fmt.Sprintf( "%s:%d", k.p.StringF("serve."+string(endpoint)+".host", ""), @@ -179,10 +179,18 @@ func (k *Config) addressFor(endpoint EndpointType) (string, string) { ), k.p.StringF("serve."+string(endpoint)+".write_listen_file", "") } -func (k *Config) ReadAPIListenOn() (string, string) { return k.addressFor(EndpointRead) } -func (k *Config) WriteAPIListenOn() (string, string) { return k.addressFor(EndpointWrite) } -func (k *Config) MetricsListenOn() (string, string) { return k.addressFor(EndpointMetrics) } -func (k *Config) OPLSyntaxAPIListenOn() (string, string) { return k.addressFor(EndpointOPLSyntax) } +func (k *Config) ReadAPIListenOn() (addr string, listenFile string) { + return k.addressFor(EndpointRead) +} +func (k *Config) WriteAPIListenOn() (addr string, listenFile string) { + return k.addressFor(EndpointWrite) +} +func (k *Config) MetricsListenOn() (addr string, listenFile string) { + return k.addressFor(EndpointMetrics) +} +func (k *Config) OPLSyntaxAPIListenOn() (addr string, listenFile string) { + return k.addressFor(EndpointOPLSyntax) +} func (k *Config) MaxReadDepth() int { return k.p.Int(KeyLimitMaxReadDepth) diff --git a/internal/driver/testhelpers.go b/internal/driver/testhelpers.go index 3337c71b1..2069ec906 100644 --- a/internal/driver/testhelpers.go +++ b/internal/driver/testhelpers.go @@ -14,7 +14,7 @@ import ( "github.com/ory/keto/internal/driver/config" ) -type GetAddr = func(testing.TB, string) (host string, port string, fullAddr string) +type GetAddr = func(t testing.TB, endpoint string) (host string, port string, fullAddr string) func UseDynamicPorts(ctx context.Context, t testing.TB, r Registry) GetAddr { t.Helper() From d42c8ae61025c0f92072d019a0f085d579edec0e Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 19 Dec 2024 13:10:11 +0100 Subject: [PATCH 4/4] test: fix racy setup --- cmd/status/root_test.go | 2 +- internal/driver/daemon.go | 90 ++++++++++++++++++---------------- internal/driver/daemon_test.go | 12 ++--- internal/e2e/full_suit_test.go | 2 - 4 files changed, 52 insertions(+), 54 deletions(-) diff --git a/cmd/status/root_test.go b/cmd/status/root_test.go index 01d794c65..35edf545d 100644 --- a/cmd/status/root_test.go +++ b/cmd/status/root_test.go @@ -46,7 +46,7 @@ func TestStatusCmd(t *testing.T) { }) t.Run("case=block", func(t *testing.T) { - ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, time.Millisecond) + ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, 100*time.Millisecond) l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/internal/driver/daemon.go b/internal/driver/daemon.go index 2f2d020fb..89fce833a 100644 --- a/internal/driver/daemon.go +++ b/internal/driver/daemon.go @@ -107,7 +107,7 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { innerCtx, cancel := context.WithCancel(ctx) defer cancel() - serveFuncs := []func(context.Context, chan<- struct{}) error{ + serveFuncs := []func(context.Context, chan<- struct{}) func() error{ r.serveRead, r.serveWrite, r.serveOPLSyntax, @@ -149,88 +149,94 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { // We need to separate the setup (invoking the functions that return the serve functions) from running the serve // functions to mitigate race conditions in the HTTP router. for _, serve := range serveFuncs { - eg.Go(func() error { - return serve(innerCtx, doneShutdown) - }) + eg.Go(serve(innerCtx, doneShutdown)) } return eg.Wait() } -func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) func() error { rt, s := r.ReadRouter(ctx), r.ReadGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - addr, listenFile := r.Config(ctx).ReadAPIListenOn() - return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done) + return func() error { + addr, listenFile := r.Config(ctx).ReadAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done) + } } -func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) func() error { rt, s := r.WriteRouter(ctx), r.WriteGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - addr, listenFile := r.Config(ctx).WriteAPIListenOn() - return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done) + return func() error { + addr, listenFile := r.Config(ctx).WriteAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done) + } } -func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) func() error { rt, s := r.OPLSyntaxRouter(ctx), r.OplGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn() - return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done) + return func() error { + addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done) + } } -func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) func() error { ctx, cancel := context.WithCancel(ctx) - defer cancel() - - addr, listenFile := r.Config(ctx).MetricsListenOn() - l, err := listenAndWriteFile(ctx, addr, listenFile) - if err != nil { - return err - } //nolint:gosec // graceful.WithDefaults already sets a timeout s := graceful.WithDefaults(&http.Server{ Handler: r.metricsRouter(ctx), }) - eg := &errgroup.Group{} + return func() error { + defer cancel() + eg := &errgroup.Group{} - eg.Go(func() error { - if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) { - return errors.WithStack(err) + addr, listenFile := r.Config(ctx).MetricsListenOn() + l, err := listenAndWriteFile(ctx, addr, listenFile) + if err != nil { + return err } - return nil - }) - eg.Go(func() (err error) { - defer func() { - l := r.Logger().WithField("endpoint", "metrics") - if err != nil { - l.WithError(err).Error("graceful shutdown failed") - } else { - l.Info("gracefully shutdown server") + + eg.Go(func() error { + if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) { + return errors.WithStack(err) } - done <- struct{}{} - }() + return nil + }) + eg.Go(func() (err error) { + defer func() { + l := r.Logger().WithField("endpoint", "metrics") + if err != nil { + l.WithError(err).Error("graceful shutdown failed") + } else { + l.Info("gracefully shutdown server") + } + done <- struct{}{} + }() - <-ctx.Done() - ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) - defer cancel() - return s.Shutdown(ctx) - }) + <-ctx.Done() + ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) + defer cancel() + return s.Shutdown(ctx) + }) - return eg.Wait() + return eg.Wait() + } } func multiplexPort(ctx context.Context, log *logrusx.Logger, addr, listenFile string, router http.Handler, grpcS *grpc.Server, done chan<- struct{}) error { diff --git a/internal/driver/daemon_test.go b/internal/driver/daemon_test.go index e84fb3031..0b4666d63 100644 --- a/internal/driver/daemon_test.go +++ b/internal/driver/daemon_test.go @@ -36,12 +36,8 @@ func TestScrapingEndpoint(t *testing.T) { eg := errgroup.Group{} doneShutdown := make(chan struct{}) - eg.Go(func() error { - return r.serveWrite(ctx, doneShutdown) - }) - eg.Go(func() error { - return r.serveMetrics(ctx, doneShutdown) - }) + eg.Go(r.serveWrite(ctx, doneShutdown)) + eg.Go(r.serveMetrics(ctx, doneShutdown)) _, writePort, _ := getAddr(t, "write") _, metricsPort, _ := getAddr(t, "metrics") @@ -104,9 +100,7 @@ func TestPanicRecovery(t *testing.T) { eg := errgroup.Group{} doneShutdown := make(chan struct{}) - eg.Go(func() error { - return r.serveWrite(ctx, doneShutdown) - }) + eg.Go(r.serveWrite(ctx, doneShutdown)) _, port, _ := getAddr(t, "write") diff --git a/internal/e2e/full_suit_test.go b/internal/e2e/full_suit_test.go index ab381729c..d16eeac49 100644 --- a/internal/e2e/full_suit_test.go +++ b/internal/e2e/full_suit_test.go @@ -53,7 +53,6 @@ const ( func Test(t *testing.T) { t.Parallel() for _, dsn := range dbx.GetDSNs(t, false) { - dsn := dsn t.Run(fmt.Sprintf("dsn=%s", dsn.Name), func(t *testing.T) { t.Parallel() @@ -98,7 +97,6 @@ func Test(t *testing.T) { syntaxRemote: oplAddr, }, } { - cl := cl t.Run(fmt.Sprintf("client=%T", cl), runCases(cl, namespaceTestMgr)) if tc, ok := cl.(transactClient); ok {