diff --git a/grpc/interceptors_test.go b/grpc/interceptors_test.go index 1b5415fedcd..4d992c2c73f 100644 --- a/grpc/interceptors_test.go +++ b/grpc/interceptors_test.go @@ -28,6 +28,7 @@ import ( "github.com/letsencrypt/boulder/grpc/test_proto" "github.com/letsencrypt/boulder/metrics" "github.com/letsencrypt/boulder/test" + "github.com/letsencrypt/boulder/web" ) var fc = clock.NewFake() @@ -524,3 +525,73 @@ func TestServiceAuthChecker(t *testing.T) { err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") test.AssertNotError(t, err, "checking allowed cert") } + +// testUserAgentServer stores the last value it saw in the user agent field of its context. +type testUserAgentServer struct { + test_proto.UnimplementedChillerServer + + lastSeenUA string +} + +// Chill implements ChillerServer.Chill +func (s *testUserAgentServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) { + s.lastSeenUA = web.UserAgent(ctx) + return nil, nil +} + +func TestUserAgentMetadata(t *testing.T) { + server := new(testUserAgentServer) + client, stop := setup(t, server) + defer stop() + + testUA := "test UA" + ctx := web.WithUserAgent(context.Background(), testUA) + + _, err := client.Chill(ctx, &test_proto.Time{}) + if err != nil { + t.Fatalf("calling c.Chill: %s", err) + } + + if server.lastSeenUA != testUA { + t.Errorf("last seen User-Agent on server side was %q, want %q", server.lastSeenUA, testUA) + } +} + +func setup(t *testing.T, server test_proto.ChillerServer) (test_proto.ChillerClient, func()) { + clk := clock.NewFake() + serverMetricsVal, err := newServerMetrics(metrics.NoopRegisterer) + test.AssertNotError(t, err, "creating server metrics") + clientMetricsVal, err := newClientMetrics(metrics.NoopRegisterer) + test.AssertNotError(t, err, "creating client metrics") + + lis, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + port := lis.Addr().(*net.TCPAddr).Port + + si := newServerMetadataInterceptor(serverMetricsVal, clk) + s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) + test_proto.RegisterChillerServer(s, server) + + go func() { + start := time.Now() + err := s.Serve(lis) + if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Logf("s.Serve: %v after %s", err, time.Since(start)) + } + }() + + ci := &clientMetadataInterceptor{ + timeout: 30 * time.Second, + metrics: clientMetricsVal, + clk: clock.NewFake(), + } + conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUnaryInterceptor(ci.Unary)) + if err != nil { + t.Fatalf("did not connect: %v", err) + } + return test_proto.NewChillerClient(conn), s.Stop +}