diff --git a/pkg/internal/runtime/runtime.go b/pkg/internal/runtime/runtime.go index 79165084..9c414c08 100644 --- a/pkg/internal/runtime/runtime.go +++ b/pkg/internal/runtime/runtime.go @@ -52,6 +52,10 @@ type Args struct { TLSCertFile string // Absolute path to the TLS key file. TLSKeyFile string + // HttpEndpoint is the address of the metrics server + HttpEndpoint string + // MetricsPath is the HTTP path at which metrics are served + MetricsPath string } func (args *Args) Validate() error { @@ -169,7 +173,9 @@ func (rt *Runtime) kubeConnect(kubeconfig string, kubeAPIQPS float32, kubeAPIBur func (rt *Runtime) csiConnect(csiAddress string) error { ctx := context.Background() - metricsManager := metrics.NewCSIMetricsManagerForSidecar("" /* driverName */) + metricsManager := metrics.NewCSIMetricsManagerWithOptions("", + metrics.WithSubsystem(SubSystem), + metrics.WithLabelNames(LabelTargetSnapshotName, LabelBaseSnapshotName)) csiConn, err := connection.Connect( ctx, csiAddress, diff --git a/pkg/internal/runtime/test_harness.go b/pkg/internal/runtime/test_harness.go index 5aa04b9f..eba6475c 100644 --- a/pkg/internal/runtime/test_harness.go +++ b/pkg/internal/runtime/test_harness.go @@ -71,6 +71,7 @@ type TestHarness struct { MockCSIIdentityServer *driver.MockIdentityServer MockCSISnapshotMetadataServer *driver.MockSnapshotMetadataServer MockCSIDriverConn *grpc.ClientConn + MetricsManager metrics.CSIMetricsManager FakeCSIDriver *driver.CSIDriver @@ -120,6 +121,8 @@ func (th *TestHarness) RuntimeArgs() Args { GRPCPort: th.rtaPortNumber, TLSCertFile: th.tlsCertFile, TLSKeyFile: th.tlsKeyFile, + HttpEndpoint: "localhost:8081", + MetricsPath: "/metrics", } } @@ -205,7 +208,6 @@ func (th *TestHarness) WithMockCSIDriver(t *testing.T) *TestHarness { mockController := gomock.NewController(t) identityServer := driver.NewMockIdentityServer(mockController) snapshotMetadataServer := driver.NewMockSnapshotMetadataServer(mockController) - metricsManager := 
metrics.NewCSIMetricsManagerForSidecar("" /* driverName */) drv := driver.NewMockCSIDriver(&driver.MockCSIDriverServers{ Identity: identityServer, SnapshotMetadata: snapshotMetadataServer, @@ -215,6 +217,9 @@ func (th *TestHarness) WithMockCSIDriver(t *testing.T) *TestHarness { // Create a client connection to it addr := drv.Address() + metricsManager := metrics.NewCSIMetricsManagerWithOptions("", + metrics.WithSubsystem(SubSystem), + metrics.WithLabelNames(LabelTargetSnapshotName, LabelBaseSnapshotName)) csiConn, err := connection.Connect(context.Background(), addr, metricsManager) if err != nil { t.Fatal("Connect", err) @@ -226,6 +231,7 @@ func (th *TestHarness) WithMockCSIDriver(t *testing.T) *TestHarness { th.MockCSIIdentityServer = identityServer th.MockCSISnapshotMetadataServer = snapshotMetadataServer th.driverName = "mock-csi-driver" + th.MetricsManager = metricsManager return th } diff --git a/pkg/internal/runtime/util_metrics.go b/pkg/internal/runtime/util_metrics.go new file mode 100644 index 00000000..f38fdaf2 --- /dev/null +++ b/pkg/internal/runtime/util_metrics.go @@ -0,0 +1,54 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package runtime + +import ( + "time" + + "k8s.io/klog/v2" +) + +const ( + LabelTargetSnapshotName = "target_snapshot" + LabelBaseSnapshotName = "base_snapshot" + SubSystem = "snapshot_metadata_controller" + + // MetadataAllocatedOperationName is the operation that tracks how long the controller takes to get the allocated blocks for a snapshot. + // Specifically, the operation metric is emitted based on the following timestamps: + // - Start_time: controller notices the first time that there is a GetMetadataAllocated RPC call to fetch the allocated blocks of metadata + // - End_time: controller notices that the RPC call is finished and the allocated blocks are streamed back to the client + MetadataAllocatedOperationName = "MetadataAllocated" + + // MetadataDeltaOperationName is the operation that tracks how long the controller takes to get the changed blocks between 2 snapshots. + // Specifically, the operation metric is emitted based on the following timestamps: + // - Start_time: controller notices the first time that there is a GetMetadataDelta RPC call to fetch the changed blocks between 2 snapshots + // - End_time: controller notices that the RPC call is finished and the changed blocks are streamed back to the client + MetadataDeltaOperationName = "MetadataDelta" +) + +// RecordMetricsWithLabels is a wrapper on the csi-lib-utils RecordMetrics function: it applies the given +// label values to the runtime's MetricsManager and records the operation's error and the time elapsed since startTime. 
+func (rt *Runtime) RecordMetricsWithLabels(opLabel map[string]string, opName string, startTime time.Time, opErr error) { + metricsWithLabel, err := rt.MetricsManager.WithLabelValues(opLabel) + if err != nil { + klog.ErrorS(err, "failed to add labels to metrics") + return + } + + opDuration := time.Since(startTime) + metricsWithLabel.RecordMetrics(opName, opErr, opDuration) +} diff --git a/pkg/internal/server/grpc/common_test.go b/pkg/internal/server/grpc/common_test.go index be4877b2..2ef5da9b 100644 --- a/pkg/internal/server/grpc/common_test.go +++ b/pkg/internal/server/grpc/common_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/kubernetes-csi/csi-lib-utils/metrics" snapshotv1 "github.com/kubernetes-csi/external-snapshotter/client/v8/apis/volumesnapshot/v1" fakesnapshot "github.com/kubernetes-csi/external-snapshotter/client/v8/clientset/versioned/fake" snapshotutils "github.com/kubernetes-csi/external-snapshotter/v8/pkg/utils" @@ -112,6 +113,7 @@ func (th *testHarness) Runtime() *runtime.Runtime { SnapshotClient: th.FakeSnapshotClient, DriverName: th.DriverName, CSIConn: th.mockCSIDriverConn, + MetricsManager: metrics.NewCSIMetricsManagerWithOptions(th.DriverName, metrics.WithSubsystem(runtime.SubSystem), metrics.WithLabelNames(runtime.LabelTargetSnapshotName, runtime.LabelBaseSnapshotName)), } } diff --git a/pkg/internal/server/grpc/get_metadata_allocated.go b/pkg/internal/server/grpc/get_metadata_allocated.go index 35f84e7a..588757cb 100644 --- a/pkg/internal/server/grpc/get_metadata_allocated.go +++ b/pkg/internal/server/grpc/get_metadata_allocated.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "strings" + "time" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" @@ -28,15 +29,24 @@ import ( "k8s.io/klog/v2" "github.com/kubernetes-csi/external-snapshot-metadata/pkg/api" + "github.com/kubernetes-csi/external-snapshot-metadata/pkg/internal/runtime" ) -func (s *Server) 
GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stream api.SnapshotMetadata_GetMetadataAllocatedServer) error { +func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stream api.SnapshotMetadata_GetMetadataAllocatedServer) (err error) { // Create a timeout context so that failure in either sending to the client or // receiving from the CSI driver will ultimately abort the handler session. // The context could also get canceled by the client. ctx, cancelFn := context.WithTimeout(s.getMetadataAllocatedContextWithLogger(req, stream), s.config.MaxStreamDur) defer cancelFn() + // Record metrics when the operation ends + defer func(startTime time.Time) { + opLabel := map[string]string{ + runtime.LabelTargetSnapshotName: fmt.Sprintf("%s/%s", req.Namespace, req.SnapshotName), + } + s.config.Runtime.RecordMetricsWithLabels(opLabel, runtime.MetadataAllocatedOperationName, startTime, err) + }(time.Now()) + if err := s.validateGetMetadataAllocatedRequest(req); err != nil { klog.FromContext(ctx).Error(err, "validation failed") return err @@ -63,7 +73,8 @@ func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stre return err } - return s.streamGetMetadataAllocatedResponse(ctx, stream, csiStream) + err = s.streamGetMetadataAllocatedResponse(ctx, stream, csiStream) + return err } // getMetadataAllocatedContextWithLogger returns the stream context with an embedded diff --git a/pkg/internal/server/grpc/get_metadata_allocated_test.go b/pkg/internal/server/grpc/get_metadata_allocated_test.go index c1a708d8..e56ed8db 100644 --- a/pkg/internal/server/grpc/get_metadata_allocated_test.go +++ b/pkg/internal/server/grpc/get_metadata_allocated_test.go @@ -151,6 +151,31 @@ func TestGetMetadataAllocatedViaGRPCClient(t *testing.T) { } else if errStream != nil { assert.ErrorIs(t, errStream, io.EOF) } + + // Validate metrics are recorded correctly + metrics, _ := grpcServer.config.Runtime.MetricsManager.GetRegistry().Gather() + statusFound := 
0 + snapshotFound := 0 + + // Validate that both gauge and controller metrics is recorded + assert.GreaterOrEqual(t, len(metrics), 2) + assert.Equal(t, "process_start_time_seconds", *metrics[0].Name) + assert.Equal(t, "snapshot_metadata_controller_operations_seconds", *metrics[1].Name) + + // Validate grpc_status_code and target_snapshot name + for _, metric := range metrics[1].Metric { + for _, labels := range metric.Label { + expTargetSnapshotName := fmt.Sprintf("%s/%s", tc.req.Namespace, tc.req.SnapshotName) + if *labels.Name == "grpc_status_code" && *labels.Value == tc.expStatusCode.String() { + statusFound = 1 + } + if *labels.Name == "target_snapshot" && *labels.Value == expTargetSnapshotName { + snapshotFound = 1 + } + } + } + assert.Equal(t, 1, statusFound) + assert.Equal(t, 1, snapshotFound) }) } } diff --git a/pkg/internal/server/grpc/get_metadata_delta.go b/pkg/internal/server/grpc/get_metadata_delta.go index 42b1422a..4769be35 100644 --- a/pkg/internal/server/grpc/get_metadata_delta.go +++ b/pkg/internal/server/grpc/get_metadata_delta.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "strings" + "time" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" @@ -28,15 +29,25 @@ import ( "k8s.io/klog/v2" "github.com/kubernetes-csi/external-snapshot-metadata/pkg/api" + "github.com/kubernetes-csi/external-snapshot-metadata/pkg/internal/runtime" ) -func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) error { +func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) (err error) { // Create a timeout context so that failure in either sending to the client or // receiving from the CSI driver will ultimately abort the handler session. // The context could also get canceled by the client. 
ctx, cancelFn := context.WithTimeout(s.getMetadataDeltaContextWithLogger(req, stream), s.config.MaxStreamDur) defer cancelFn() + // Record metrics when the operation ends + defer func(startTime time.Time) { + opLabel := map[string]string{ + runtime.LabelTargetSnapshotName: fmt.Sprintf("%s/%s", req.Namespace, req.TargetSnapshotName), + runtime.LabelBaseSnapshotName: fmt.Sprintf("%s/%s", req.Namespace, req.BaseSnapshotName), + } + s.config.Runtime.RecordMetricsWithLabels(opLabel, runtime.MetadataDeltaOperationName, startTime, err) + }(time.Now()) + if err := s.validateGetMetadataDeltaRequest(req); err != nil { klog.FromContext(ctx).Error(err, "validation failed") return err @@ -63,7 +74,8 @@ func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.S return err } - return s.streamGetMetadataDeltaResponse(ctx, stream, csiStream) + err = s.streamGetMetadataDeltaResponse(ctx, stream, csiStream) + return err } func (s *Server) getMetadataDeltaContextWithLogger(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) context.Context { diff --git a/pkg/internal/server/grpc/get_metadata_delta_test.go b/pkg/internal/server/grpc/get_metadata_delta_test.go index 3c0848ea..34a2ef52 100644 --- a/pkg/internal/server/grpc/get_metadata_delta_test.go +++ b/pkg/internal/server/grpc/get_metadata_delta_test.go @@ -156,6 +156,37 @@ func TestGetMetadataDeltaViaGRPCClient(t *testing.T) { } else if errStream != nil { assert.ErrorIs(t, errStream, io.EOF) } + + // Validate metrics are recorded correctly + metrics, _ := grpcServer.config.Runtime.MetricsManager.GetRegistry().Gather() + statusFound := 0 + targetSnapshotFound := 0 + baseSnapshotFound := 0 + + // Validate that both gauge and controller metrics is recorded + assert.GreaterOrEqual(t, len(metrics), 2) + assert.Equal(t, "process_start_time_seconds", *metrics[0].Name) + assert.Equal(t, "snapshot_metadata_controller_operations_seconds", *metrics[1].Name) + + // Validate 
grpc_status_code and target_snapshot name + for _, metric := range metrics[1].Metric { + for _, labels := range metric.Label { + expTargetSnapshotName := fmt.Sprintf("%s/%s", tc.req.Namespace, tc.req.TargetSnapshotName) + expBaseSnapshotName := fmt.Sprintf("%s/%s", tc.req.Namespace, tc.req.BaseSnapshotName) + if *labels.Name == "grpc_status_code" && *labels.Value == tc.expStatusCode.String() { + statusFound = 1 + } + if *labels.Name == "target_snapshot" && *labels.Value == expTargetSnapshotName { + targetSnapshotFound = 1 + } + if *labels.Name == "base_snapshot" && *labels.Value == expBaseSnapshotName { + baseSnapshotFound = 1 + } + } + } + assert.Equal(t, 1, statusFound) + assert.Equal(t, 1, targetSnapshotFound) + assert.Equal(t, 1, baseSnapshotFound) }) } } diff --git a/pkg/sidecar/sidecar.go b/pkg/sidecar/sidecar.go index e07dd25b..9f63d5e9 100644 --- a/pkg/sidecar/sidecar.go +++ b/pkg/sidecar/sidecar.go @@ -19,6 +19,7 @@ package sidecar import ( "flag" "fmt" + "net/http" "os" "os/signal" "strconv" @@ -86,16 +87,25 @@ func Run(argv []string, version string) int { klog.Infof("CSI driver name: %q", rt.DriverName) - // TBD May need to exposed metric HTTP end point - // here because the wait for the CSI driver is open ended. - grpcServer, err := startGRPCServerAndValidateCSIDriver(s.createServerConfig(rt)) if err != nil { klog.Error(err) return 1 } - // TODO: Start the HTTP metrics server here. 
+ // start listening & serving http endpoint, if set + mux := http.NewServeMux() + if *s.httpEndpoint != "" { + rt.MetricsManager.RegisterToServer(mux, *s.metricsPath) + rt.MetricsManager.SetDriverName(rt.DriverName) + go func() { + klog.Infof("ServeMux listening at %q", *s.httpEndpoint) + err := http.ListenAndServe(*s.httpEndpoint, mux) + if err != nil { + klog.Fatalf("Failed to start HTTP server at specified address (%q) and metrics path (%q): %s", *s.httpEndpoint, *s.metricsPath, err) + } + }() + } shutdownOnTerminationSignal(grpcServer) @@ -172,7 +182,6 @@ func (s *sidecarFlagSet) parseFlagsAndHandleShowVersion(args []string) (handledS } func (s *sidecarFlagSet) runtimeArgsFromFlags() runtime.Args { - // TODO: set the HTTP server properties. return runtime.Args{ CSIAddress: *s.csiAddress, CSITimeout: *s.csiTimeout, @@ -182,6 +191,8 @@ func (s *sidecarFlagSet) runtimeArgsFromFlags() runtime.Args { GRPCPort: *s.grpcPort, TLSCertFile: *s.tlsCert, TLSKeyFile: *s.tlsKey, + HttpEndpoint: *s.httpEndpoint, + MetricsPath: *s.metricsPath, } } @@ -222,6 +233,14 @@ func (s *sidecarFlagSet) runtimeArgsToArgv(progName string, rta runtime.Args) [] argv = append(argv, "-"+flagKubeAPIQPS, strconv.FormatFloat(float64(rta.KubeAPIQPS), 'f', -1, 32)) } + if rta.HttpEndpoint != "" { + argv = append(argv, "-"+flagHTTPEndpoint, rta.HttpEndpoint) + } + + if rta.MetricsPath != defaultMetricsPath { + argv = append(argv, "-"+flagMetricsPath, rta.MetricsPath) + } + return argv } diff --git a/pkg/sidecar/sidecar_test.go b/pkg/sidecar/sidecar_test.go index d036fe9b..1b9bbca1 100644 --- a/pkg/sidecar/sidecar_test.go +++ b/pkg/sidecar/sidecar_test.go @@ -21,8 +21,10 @@ import ( "fmt" "io" "math/rand/v2" + "net/http" "os" "regexp" + "strings" "sync" "syscall" "testing" @@ -110,6 +112,45 @@ func TestSidecarFlagSet(t *testing.T) { GRPCPort: defaultGRPCPort, TLSCertFile: expTLSCertFile, TLSKeyFile: expTLSKeyFile, + MetricsPath: defaultMetricsPath, + } + + assert.Equal(t, expRTA, rta) + + rt := 
&runtime.Runtime{} + config := sfs.createServerConfig(rt) + assert.Equal(t, rt, config.Runtime) + assert.Equal(t, time.Duration(defaultMaxStreamingDurationMin)*time.Minute, config.MaxStreamDur) + }) + + t.Run("http-endpoint-and-metrics-flag", func(t *testing.T) { + defer saveAndResetGlobalState()() + + expTLSCertFile := "/tls/certFile" + t.Setenv(tlsCertEnvVar, expTLSCertFile) + expTLSKeyFile := "/tls/keyFile" + t.Setenv(tlsKeyEnvVar, expTLSKeyFile) + + argv := []string{"progName", "-http-endpoint=localhost:8080", "-metrics-path=/metPath"} + sfs := newSidecarFlagSet(argv[0], "version") + + hsv, err := sfs.parseFlagsAndHandleShowVersion(argv[1:]) + assert.NoError(t, err) + assert.False(t, hsv) + + rta := sfs.runtimeArgsFromFlags() + + expRTA := runtime.Args{ + CSIAddress: defaultCSISocket, + CSITimeout: defaultCSITimeout, + KubeAPIBurst: defaultKubeAPIBurst, + KubeAPIQPS: defaultKubeAPIQPS, + Kubeconfig: defaultKubeconfig, + GRPCPort: defaultGRPCPort, + TLSCertFile: expTLSCertFile, + TLSKeyFile: expTLSKeyFile, + HttpEndpoint: "localhost:8080", + MetricsPath: "/metPath", } assert.Equal(t, expRTA, rta) @@ -230,6 +271,71 @@ func TestRun(t *testing.T) { assert.Equal(t, 0, rc) }) + + t.Run("launch-and-terminate-with-http-server", func(t *testing.T) { + proc, err := os.FindProcess(syscall.Getpid()) + assert.NoError(t, err) + + // Specifying a fake snapshot metadata server to WithFakeCSIDriver() + // makes the fake identity server advertise the needed capabilities. + sms := &testSnapshotMetadataServer{} + rth := runtime.NewTestHarness().WithTestTLSFiles(t).WithFakeKubeConfig(t).WithFakeCSIDriver(t, sms) + defer rth.RemoveTestTLSFiles(t) + defer rth.RemoveFakeKubeConfig(t) + defer rth.TerminateFakeCSIDriver(t) + + // Still need to add a response to the fake identity server Probe. 
+ rth.FakeProbeResponse = &csi.ProbeResponse{Ready: wrapperspb.Bool(true)} + + rt := rth.RuntimeForFakeCSIDriver(t) + rt.Args.HttpEndpoint = "localhost:8082" + rt.Args.MetricsPath = defaultMetricsPath + + sfs := &sidecarFlagSet{} + argv := sfs.runtimeArgsToArgv("progName", rt.Args) + argv = append(argv, flagMaxStreamingDurationMin, fmt.Sprintf("%d", defaultMaxStreamingDurationMin+1)) + + // invoke Run() in a goroutine so as not to block. + wg := sync.WaitGroup{} + wg.Add(1) + startedChan := make(chan int) + + rc := -1 // this will track the return value of Run(). + + go func() { + close(startedChan) + rc = Run(argv, "version") + srvAddr := "http://" + rt.Args.HttpEndpoint + rt.Args.MetricsPath + rsp, err := http.Get(srvAddr) + if err != nil || rsp.StatusCode != http.StatusOK { + t.Errorf("failed to get response from server %v, %v", err, rsp) + } + r, err := io.ReadAll(rsp.Body) + if err != nil { + t.Errorf("failed to read response body %v", err) + } + // Validate that the metrics contains "snapshot_metadata_controller_operations_seconds" type histogram + if !strings.Contains(string(r), "snapshot_metadata_controller_operations_seconds") { + t.Errorf("didn't find expected type in metrics[%s]", string(r)) + } + wg.Done() + }() + + <-startedChan + + // Send a termination signal to the server after a brief delay. + // As there are multiple possible termination signals we randomly + // select one each invocation. + go func() { + time.Sleep(time.Millisecond * 100) + termSigIdx := rand.IntN(len(terminationSignals)) + proc.Signal(terminationSignals[termSigIdx]) + }() + + wg.Wait() + + assert.Equal(t, 0, rc) + }) } func saveAndResetGlobalState() func() {