diff --git a/cmd/sink-worker/main.go b/cmd/sink-worker/main.go
index 7025589c7..52551aa11 100644
--- a/cmd/sink-worker/main.go
+++ b/cmd/sink-worker/main.go
@@ -5,6 +5,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
 	"os"
 	"syscall"
@@ -39,11 +40,14 @@ import (
 	"github.com/openmeterio/openmeter/pkg/slicesx"
 )
 
+const (
+	defaultShutdownTimeout = 5 * time.Second
+)
+
 var otelName string = "openmeter.io/sink-worker"
 
 func main() {
 	v, flags := viper.New(), pflag.NewFlagSet("OpenMeter", pflag.ExitOnError)
-	ctx := context.Background()
 
 	config.Configure(v, flags)
 
@@ -78,8 +82,12 @@ func main() {
 		panic(err)
 	}
 
+	// Setup main context covering the application lifecycle and ensure that the context is canceled on process exit.
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	extraResources, _ := resource.New(
-		context.Background(),
+		ctx,
 		resource.WithContainer(),
 		resource.WithAttributes(
 			semconv.ServiceName("openmeter-sink-worker"),
@@ -101,13 +109,18 @@ func main() {
 	telemetryRouter.Mount("/debug", middleware.Profiler())
 
 	// Initialize OTel Metrics
-	otelMeterProvider, err := conf.Telemetry.Metrics.NewMeterProvider(context.Background(), res)
+	otelMeterProvider, err := conf.Telemetry.Metrics.NewMeterProvider(ctx, res)
 	if err != nil {
 		logger.Error(err.Error())
 		os.Exit(1)
 	}
 	defer func() {
-		if err := otelMeterProvider.Shutdown(context.Background()); err != nil {
+		// Use dedicated context with timeout for shutdown as parent context might be canceled
+		// by the time the execution reaches this stage.
+		ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout)
+		defer cancel()
+
+		if err := otelMeterProvider.Shutdown(ctx); err != nil {
 			logger.Error("shutting down meter provider: %v", err)
 		}
 	}()
@@ -119,13 +132,18 @@ func main() {
 	}
 
 	// Initialize OTel Tracer
-	otelTracerProvider, err := conf.Telemetry.Trace.NewTracerProvider(context.Background(), res)
+	otelTracerProvider, err := conf.Telemetry.Trace.NewTracerProvider(ctx, res)
 	if err != nil {
 		logger.Error(err.Error())
 		os.Exit(1)
 	}
 	defer func() {
-		if err := otelTracerProvider.Shutdown(context.Background()); err != nil {
+		// Use dedicated context with timeout for shutdown as parent context might be canceled
+		// by the time the execution reaches this stage.
+		ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout)
+		defer cancel()
+
+		if err := otelTracerProvider.Shutdown(ctx); err != nil {
 			logger.Error("shutting down tracer provider", "error", err)
 		}
 	}()
@@ -163,42 +181,51 @@ func main() {
 		logger.Error("failed to initialize sink worker", "error", err)
 		os.Exit(1)
 	}
-
-	var group run.Group
+	defer sink.Close()
 
 	// Set up telemetry server
-	{
-		server := &http.Server{
-			Addr:    conf.Telemetry.Address,
-			Handler: telemetryRouter,
-		}
-		defer server.Close()
-
-		group.Add(
-			func() error { return server.ListenAndServe() },
-			func(err error) { _ = server.Shutdown(ctx) },
-		)
+	server := &http.Server{
+		Addr:    conf.Telemetry.Address,
+		Handler: telemetryRouter,
+		BaseContext: func(_ net.Listener) context.Context {
+			return ctx
+		},
 	}
+	defer server.Close()
 
-	// Starting sink worker
-	{
-		defer sink.Close()
+	var group run.Group
+	// Add sink worker to run group
+	group.Add(
+		func() error { return sink.Run(ctx) },
+		func(err error) { sink.Close() },
+	)
 
-		group.Add(
-			func() error { return sink.Run() },
-			func(err error) { _ = sink.Close() },
-		)
-	}
+	// Add telemetry server to run group
+	group.Add(
+		func() error { return server.ListenAndServe() },
+		func(err error) { _ = server.Shutdown(ctx) },
+	)
 
 	// Setup signal handler
 	group.Add(run.SignalHandler(ctx, syscall.SIGINT, syscall.SIGTERM))
 
+	// Run actors
 	err = group.Run()
+
+	var exitCode int
 	if e := (run.SignalError{}); errors.As(err, &e) {
-		slog.Info("received signal; shutting down", slog.String("signal", e.Signal.String()))
+		logger.Info("received signal: shutting down", slog.String("signal", e.Signal.String()))
+		switch e.Signal {
+		case syscall.SIGTERM:
+		default:
+			exitCode = 130
+		}
 	} else if !errors.Is(err, http.ErrServerClosed) {
 		logger.Error("application stopped due to error", slog.String("error", err.Error()))
+		exitCode = 1
 	}
+
+	os.Exit(exitCode)
 }
 
 func initClickHouseClient(config config.Configuration) (clickhouse.Conn, error) {
diff --git a/internal/sink/sink.go b/internal/sink/sink.go
index 5fbaa8a08..96bb282fb 100644
--- a/internal/sink/sink.go
+++ b/internal/sink/sink.go
@@ -5,11 +5,9 @@ import (
 	"encoding/json"
 	"fmt"
 	"log/slog"
-	"os"
-	"os/signal"
 	"regexp"
 	"sort"
-	"syscall"
+	"sync/atomic"
 	"time"
 
 	"github.com/avast/retry-go/v4"
@@ -35,7 +33,7 @@ type SinkMessage struct {
 
 type Sink struct {
 	config            SinkConfig
-	running           bool
+	isRunning         atomic.Bool
 	buffer            *SinkBuffer
 	flushTimer        *time.Timer
 	flushEventCounter metric.Int64Counter
@@ -404,14 +402,14 @@ func (s *Sink) clearFlushTimer() {
 }
 
 // Run starts the Kafka consumer and sinks the events to Clickhouse
-func (s *Sink) Run() error {
-	ctx := context.TODO()
+func (s *Sink) Run(ctx context.Context) error {
+	if s.isRunning.Load() {
+		return nil
+	}
+
 	logger := s.config.Logger.With("operation", "run")
 	logger.Info("starting sink")
 
-	sigchan := make(chan os.Signal, 1)
-	signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM)
-
 	// Fetch namespaces and meters and subscribe to them
 	err := s.subscribeToNamespaces()
 	if err != nil {
@@ -431,16 +429,16 @@ func (s *Sink) Run() error {
 	s.namespaceRefetch = time.AfterFunc(s.config.NamespaceRefetch, refetch)
 
 	// Reset state
-	s.running = true
+	s.isRunning.Store(true)
 
 	// Start flush timer, this will be cleared and restarted by flush
 	s.setFlushTimer()
 
-	for s.running {
+	for s.isRunning.Load() {
 		select {
-		case sig := <-sigchan:
-			logger.Error("caught signal, terminating", "sig", sig)
-			s.running = false
+		case <-ctx.Done():
+			return fmt.Errorf("context canceled: %w", ctx.Err())
+
 		default:
 			ev := s.config.Consumer.Poll(100)
 			if ev == nil {
@@ -452,7 +450,7 @@ func (s *Sink) Run() error {
 			sinkMessage := SinkMessage{
 				KafkaMessage: e,
 			}
-			namespace, kafkaCloudEvent, err := s.ParseMessage(e)
+			namespace, kafkaCloudEvent, err := s.parseMessage(ctx, e)
 			if err != nil {
 				if perr, ok := err.(*ProcessingError); ok {
 					sinkMessage.Error = perr
@@ -492,8 +490,7 @@ func (s *Sink) Run() error {
 		}
 	}
 
-	logger.Info("closing sink")
-	return s.Close()
+	return nil
 }
 
 func (s *Sink) pause() error {
@@ -593,15 +590,13 @@ func (s *Sink) rebalance(c *kafka.Consumer, event kafka.Event) error {
 		// Remove messages for revoked partitions from buffer
 		s.buffer.RemoveByPartitions(e.Partitions)
 	default:
-		logger.Error("unxpected event type", "event", e)
+		logger.Error("unexpected event type", "event", e)
 	}
 
 	return nil
 }
 
-func (s *Sink) ParseMessage(e *kafka.Message) (string, *serializer.CloudEventsKafkaPayload, error) {
-	ctx := context.TODO()
-
+func (s *Sink) parseMessage(ctx context.Context, e *kafka.Message) (string, *serializer.CloudEventsKafkaPayload, error) {
 	// Get Namespace
 	namespace, err := getNamespace(*e.TopicPartition.Topic)
 	if err != nil {
@@ -639,18 +634,21 @@ func (s *Sink) parseMessage(ctx context.Context, e *kafka.Ka
 	return namespace, &kafkaCloudEvent, err
 }
 
-func (s *Sink) Close() error {
+func (s *Sink) Close() {
+	if !s.isRunning.Load() {
+		return
+	}
+
 	s.config.Logger.Info("closing sink")
+	s.isRunning.Store(false)
 
-	s.running = false
 	if s.namespaceRefetch != nil {
-		s.namespaceRefetch.Stop()
+		_ = s.namespaceRefetch.Stop()
 	}
+
 	if s.flushTimer != nil {
-		s.flushTimer.Stop()
+		_ = s.flushTimer.Stop()
 	}
-
-	return nil
 }
 
 // getNamespace from topic