Skip to content

Commit

Permalink
grpc: factor out setup func
Browse files Browse the repository at this point in the history
This uses a pattern that is new to our tests. setup accepts a variadic list of
options, and uses a type switch to make use of those options during setup. This
allows us to pass setup only the options that are relevant to any given test
case, leaving the rest to sensible defaults.
  • Loading branch information
jsha committed Dec 20, 2024
1 parent d6e163c commit 9a1b4e1
Showing 1 changed file with 87 additions and 148 deletions.
235 changes: 87 additions & 148 deletions grpc/interceptors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ func TestWaitForReadyFalse(t *testing.T) {
}
}

// testServer is used to implement TestTimeouts, and will attempt to sleep for
// testTimeoutServer is used to implement TestTimeouts, and will attempt to sleep for
// the given amount of time (unless it hits a timeout or cancel).
type testServer struct {
type testTimeoutServer struct {
test_proto.UnimplementedChillerServer
}

// Chill implements ChillerServer.Chill
func (s *testServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
func (s *testTimeoutServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
start := time.Now()
// Sleep for either the requested amount of time, or the context times out or
// is canceled.
Expand All @@ -175,42 +175,9 @@ func (s *testServer) Chill(ctx context.Context, in *test_proto.Time) (*test_prot
}

func TestTimeouts(t *testing.T) {
// start server
lis, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
port := lis.Addr().(*net.TCPAddr).Port

serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, &testServer{})
go func() {
start := time.Now()
err := s.Serve(lis)
if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Logf("s.Serve: %v after %s", err, time.Since(start))
}
}()
defer s.Stop()

// make client
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clock.NewFake(),
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
c := test_proto.NewChillerClient(conn)
server := new(testTimeoutServer)
client, _, stop := setup(t, server, clock.NewFake())
defer stop()

testCases := []struct {
timeout time.Duration
Expand All @@ -224,7 +191,7 @@ func TestTimeouts(t *testing.T) {
t.Run(tc.timeout.String(), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
defer cancel()
_, err := c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)})
_, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)})
if err == nil {
t.Fatal("Got no error, expected a timeout")
}
Expand All @@ -236,58 +203,22 @@ func TestTimeouts(t *testing.T) {
}

func TestRequestTimeTagging(t *testing.T) {
clk := clock.NewFake()
// Listen for TCP requests on a random system assigned port number
lis, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
// Retrieve the concrete port numberthe system assigned our listener
port := lis.Addr().(*net.TCPAddr).Port

// Create a new ChillerServer
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
server := new(testTimeoutServer)
metrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerMetadataInterceptor(serverMetrics, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, &testServer{})
// Chill until ill
go func() {
start := time.Now()
err := s.Serve(lis)
if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Logf("s.Serve: %v after %s", err, time.Since(start))
}
}()
defer s.Stop()

// Dial the ChillerServer
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clk,
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
// Create a ChillerClient with the connection to the ChillerServer
c := test_proto.NewChillerClient(conn)
client, _, stop := setup(t, server, metrics)
defer stop()

// Make an RPC request with the ChillerClient with a timeout higher than the
// requested ChillerServer delay so that the RPC completes normally
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if _, err := c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil {
if _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil {
t.Fatalf("Unexpected error calling Chill RPC: %s", err)
}

// There should be one histogram sample in the serverInterceptor rpcLag stat
test.AssertMetricWithLabelsEquals(t, si.metrics.rpcLag, prometheus.Labels{}, 1)
test.AssertMetricWithLabelsEquals(t, metrics.rpcLag, prometheus.Labels{}, 1)
}

func TestClockSkew(t *testing.T) {
Expand All @@ -297,32 +228,23 @@ func TestClockSkew(t *testing.T) {
clientClk := clock.NewFake()
clientClk.Set(time.Now())

// Listen for TCP requests on a random system assigned port number
lis, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
port := lis.Addr().(*net.TCPAddr).Port
_, serverPort, stop := setup(t, &testTimeoutServer{}, serverClk)
defer stop()

// Start a gRPC server listening on that port
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerMetadataInterceptor(serverMetrics, serverClk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, &testServer{})
go func() { _ = s.Serve(lis) }()
defer s.Stop()

// Start a gRPC client talking to the server
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientMetadataInterceptor{metrics: clientMetrics, clk: clientClk, timeout: time.Second}
conn, err := grpc.NewClient(
net.JoinHostPort("localhost", strconv.Itoa(port)),
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clientClk,
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(serverPort)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.Unary),
)
test.AssertNotError(t, err, "creating test client")
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}

client := test_proto.NewChillerClient(conn)

// Create a context with plenty of timeout
Expand Down Expand Up @@ -368,18 +290,15 @@ func (s *blockedServer) Chill(_ context.Context, _ *test_proto.Time) (*test_prot
}

func TestInFlightRPCStat(t *testing.T) {
clk := clock.NewFake()
// Listen for TCP requests on a random system assigned port number
lis, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
// Retrieve the concrete port numberthe system assigned our listener
port := lis.Addr().(*net.TCPAddr).Port

// Create a new blockedServer to act as a ChillerServer
server := &blockedServer{}

metrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")

client, _, stop := setup(t, server, metrics)
defer stop()

// Increment the roadblock waitgroup - this will cause all chill RPCs to
// the server to block until we call Done()!
server.roadblock.Add(1)
Expand All @@ -390,43 +309,11 @@ func TestInFlightRPCStat(t *testing.T) {
numRPCs := 5
server.received.Add(numRPCs)

serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerMetadataInterceptor(serverMetrics, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, server)
// Chill until ill
go func() {
start := time.Now()
err := s.Serve(lis)
if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Logf("s.Serve: %v after %s", err, time.Since(start))
}
}()
defer s.Stop()

// Dial the ChillerServer
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clk,
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
// Create a ChillerClient with the connection to the ChillerServer
c := test_proto.NewChillerClient(conn)

// Fire off a few RPCs. They will block on the blockedServer's roadblock wg
for range numRPCs {
go func() {
// Ignore errors, just chilllll.
_, _ = c.Chill(context.Background(), &test_proto.Time{})
_, _ = client.Chill(context.Background(), &test_proto.Time{})
}()
}

Expand All @@ -441,15 +328,15 @@ func TestInFlightRPCStat(t *testing.T) {
}

// We expect the inFlightRPCs gauge for the Chiller.Chill RPCs to be equal to numRPCs.
test.AssertMetricWithLabelsEquals(t, ci.metrics.inFlightRPCs, labels, float64(numRPCs))
test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, float64(numRPCs))

// Unblock the blockedServer to let all of the Chiller.Chill RPCs complete
server.roadblock.Done()
// Sleep for a little bit to let all the RPCs complete
time.Sleep(1 * time.Second)

// Check the gauge value again
test.AssertMetricWithLabelsEquals(t, ci.metrics.inFlightRPCs, labels, 0)
test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, 0)
}

func TestServiceAuthChecker(t *testing.T) {
Expand Down Expand Up @@ -524,3 +411,55 @@ func TestServiceAuthChecker(t *testing.T) {
err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
test.AssertNotError(t, err, "checking allowed cert")
}

// setup creates a server and client, returning the created client, the running server's port, and a stop function.
func setup(t *testing.T, server test_proto.ChillerServer, opts ...any) (test_proto.ChillerClient, int, func()) {
clk := clock.NewFake()
serverMetricsVal, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
clientMetricsVal, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")

for _, opt := range opts {
switch optTyped := opt.(type) {
case clock.FakeClock:
clk = optTyped
case clientMetrics:
clientMetricsVal = optTyped
case serverMetrics:
serverMetricsVal = optTyped
default:
t.Fatalf("setup called with unrecognize option %#v", t)
}
}
lis, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
port := lis.Addr().(*net.TCPAddr).Port

si := newServerMetadataInterceptor(serverMetricsVal, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, server)

go func() {
start := time.Now()
err := s.Serve(lis)
if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Logf("s.Serve: %v after %s", err, time.Since(start))
}
}()

ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetricsVal,
clk: clock.NewFake(),
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
return test_proto.NewChillerClient(conn), port, s.Stop
}

0 comments on commit 9a1b4e1

Please sign in to comment.