diff --git a/bidistreaming_test.go b/bidistreaming_test.go new file mode 100644 index 0000000..6612c86 --- /dev/null +++ b/bidistreaming_test.go @@ -0,0 +1,96 @@ +package grpcstub + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "testing" + + "github.com/k1LoW/grpcstub/testdata/routeguide" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestBidiStreaming(t *testing.T) { + ctx := context.Background() + ts := NewServer(t, "testdata/route_guide.proto") + t.Cleanup(func() { + ts.Close() + }) + ts.Method("RouteChat").Match(func(r *Request) bool { + m, ok := r.Message["message"] + if !ok { + return false + } + return strings.Contains(m.(string), "hello from client[0]") + }).Header("hello", "header"). + Response(map[string]any{"location": nil, "message": "hello from server[0]"}) + ts.Method("RouteChat"). + Header("hello", "header"). + Handler(func(r *Request) *Response { + res := NewResponse() + m, ok := r.Message["message"] + if !ok { + res.Status = status.New(codes.Unknown, codes.Unknown.String()) + return res + } + mes := Message{} + mes["message"] = strings.Replace(m.(string), "client", "server", 1) + res.Messages = []Message{mes} + return res + }) + + client := routeguide.NewRouteGuideClient(ts.Conn()) + stream, err := client.RouteChat(ctx) + if err != nil { + t.Fatal(err) + } + max := 5 + c := 0 + recvCount := 0 + var sendEnd, recvEnd bool + for !(sendEnd && recvEnd) { + if !sendEnd { + if err := stream.SendMsg(&routeguide.RouteNote{ + Message: fmt.Sprintf("hello from client[%d]", c), + }); err != nil { + t.Error(err) + sendEnd = true + } + c++ + if c == max { + sendEnd = true + if err := stream.CloseSend(); err != nil { + t.Error(err) + } + } + } + + if !recvEnd { + if res, err := stream.Recv(); err != nil { + if !errors.Is(err, io.EOF) { + t.Error(err) + } + recvEnd = true + } else { + recvCount++ + got := res.Message + if want := fmt.Sprintf("hello from server[%d]", recvCount-1); got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + } + } + if recvCount != max { + t.Errorf("got %v\nwant %v", recvCount, max) + } + + { + got := len(ts.Requests()) + if want := max; got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } +} diff --git a/clientstreaming_test.go b/clientstreaming_test.go new file mode 100644 index 0000000..cf5c528 --- /dev/null +++ b/clientstreaming_test.go @@ -0,0 +1,50 @@ +package grpcstub + +import ( + "context" + "testing" + + "github.com/k1LoW/grpcstub/testdata/routeguide" +) + +func TestClientStreaming(t *testing.T) { + ctx := context.Background() + ts := NewServer(t, "testdata/route_guide.proto") + t.Cleanup(func() { + ts.Close() + }) + ts.Method("RecordRoute").Response(map[string]any{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345}) + + client := routeguide.NewRouteGuideClient(ts.Conn()) + stream, err := client.RecordRoute(ctx) + if err != nil { + t.Fatal(err) + } + c := 2 + for i := 0; i < c; i++ { + if err := stream.Send(&routeguide.Point{ + Latitude: int32(i + 10), + Longitude: int32(i * i * 2), + }); err != nil { + t.Fatal(err) + } + } + res, err := stream.CloseAndRecv() + if err != nil { + t.Fatal(err) + } + + { + got := res.PointCount + if want := int32(2); got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + + { + got := len(ts.Requests()) + if want := 2; got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } +} diff --git a/grpcstub.go b/grpcstub.go index e274956..8bc70bd 100644 --- a/grpcstub.go +++ b/grpcstub.go @@ -528,13 +528,7 @@ func (s *Server) createUnaryHandler(md protoreflect.MethodDescriptor) func(srv a var mes *dynamicpb.Message for _, m := range s.matchers { - match := true - for _, fn := range m.matchFuncs { - if !fn(r) { - match = false - } - } - if !match { + if !m.matchRequest(r) { continue } m.mu.Lock() @@ -571,6 +565,9 @@ func (s *Server) createUnaryHandler(md protoreflect.MethodDescriptor) func(srv a return mes, nil } + s.mu.Lock() + s.unmatchedRequests = append(s.unmatchedRequests, r) + s.mu.Unlock() return mes, status.Error(codes.NotFound, codes.NotFound.String()) } } @@ -582,7 +579,7 @@ func (s *Server) createStreamHandler(md protoreflect.MethodDescriptor) func(srv case md.IsStreamingClient() && !md.IsStreamingServer(): return s.createClientStreamingHandler(md) case md.IsStreamingClient() && md.IsStreamingServer(): - return s.createBiStreamingHandler(md) + return s.createBidiStreamingHandler(md) default: return func(srv any, stream grpc.ServerStream) error { return nil @@ -613,13 +610,7 @@ func (s *Server) createServerStreamingHandler(md protoreflect.MethodDescriptor) s.requests = append(s.requests, r) s.mu.Unlock() for _, m := range s.matchers { - match := true - for _, fn := range m.matchFuncs { - if !fn(r) { - match = false - } - } - if !match { + if !m.matchRequest(r) { continue } m.mu.Lock() @@ -656,8 +647,12 @@ func (s *Server) createServerStreamingHandler(md protoreflect.MethodDescriptor) } } } + return nil } - return nil + s.mu.Lock() + s.unmatchedRequests = append(s.unmatchedRequests, r) + s.mu.Unlock() + return status.Error(codes.NotFound, codes.NotFound.String()) } } @@ -685,59 +680,59 @@ func (s *Server) createClientStreamingHandler(md protoreflect.MethodDescriptor) s.requests = append(s.requests, r) s.mu.Unlock() rs = append(rs, r) + continue } - if err == io.EOF { - var mes *dynamicpb.Message - for _, r := range rs { - for _, m := range s.matchers { - match := true - for _, fn := range m.matchFuncs { - if !fn(r) { - match = false - } - } - if !match { - continue - } - m.mu.Lock() - m.requests = append(m.requests, r) - m.mu.Unlock() - res := m.handler(r, md) - if res.Status != nil && res.Status.Err() != nil { - return res.Status.Err() - } - mes = dynamicpb.NewMessage(md.Output()) - if len(res.Messages) > 0 { - b, err := json.Marshal(res.Messages[0]) - if err != nil { - return err - } - if err := (protojson.UnmarshalOptions{}).Unmarshal(b, mes); err != nil { - return err - } - } - for k, v := range res.Headers { - for _, vv := range v { - if err := stream.SendHeader(metadata.Pairs(k, vv)); err != nil { - return err - } - } - } - for k, v := range res.Trailers { - for _, vv := range v { - stream.SetTrailer((metadata.Pairs(k, vv))) - } + + if err != io.EOF { + return err + } + + var mes *dynamicpb.Message + for _, m := range s.matchers { + if !m.matchRequest(rs...) { + continue + } + m.mu.Lock() + m.requests = append(m.requests, rs...) + m.mu.Unlock() + last := rs[len(rs)-1] + res := m.handler(last, md) + if res.Status != nil && res.Status.Err() != nil { + return res.Status.Err() + } + mes = dynamicpb.NewMessage(md.Output()) + if len(res.Messages) > 0 { + b, err := json.Marshal(res.Messages[0]) + if err != nil { + return err + } + if err := (protojson.UnmarshalOptions{}).Unmarshal(b, mes); err != nil { + return err + } + } + for k, v := range res.Headers { + for _, vv := range v { + if err := stream.SendHeader(metadata.Pairs(k, vv)); err != nil { + return err } - return stream.SendMsg(mes) } } - return status.Error(codes.NotFound, codes.NotFound.String()) + for k, v := range res.Trailers { + for _, vv := range v { + stream.SetTrailer((metadata.Pairs(k, vv))) + } + } + return stream.SendMsg(mes) } + s.mu.Lock() + s.unmatchedRequests = append(s.unmatchedRequests, rs...) + s.mu.Unlock() + return status.Error(codes.NotFound, codes.NotFound.String()) } } } -func (s *Server) createBiStreamingHandler(md protoreflect.MethodDescriptor) func(srv any, stream grpc.ServerStream) error { +func (s *Server) createBidiStreamingHandler(md protoreflect.MethodDescriptor) func(srv any, stream grpc.ServerStream) error { return func(srv any, stream grpc.ServerStream) error { headerSent := false L: @@ -767,13 +762,7 @@ func (s *Server) createBiStreamingHandler(md protoreflect.MethodDescriptor) func s.requests = append(s.requests, r) s.mu.Unlock() for _, m := range s.matchers { - match := true - for _, fn := range m.matchFuncs { - if !fn(r) { - match = false - } - } - if !match { + if !m.matchRequest(r) { continue } m.mu.Lock() @@ -815,11 +804,25 @@ func (s *Server) createBiStreamingHandler(md protoreflect.MethodDescriptor) func } continue L } + s.mu.Lock() + s.unmatchedRequests = append(s.unmatchedRequests, r) + s.mu.Unlock() return status.Error(codes.NotFound, codes.NotFound.String()) } } } +func (m *matcher) matchRequest(rs ...*Request) bool { + for _, r := range rs { + for _, fn := range m.matchFuncs { + if !fn(r) { + return false + } + } + } + return true +} + func serviceMatchFunc(service string) matchFunc { return func(r *Request) bool { return r.Service == strings.TrimPrefix(service, "/") diff --git a/grpcstub_test.go b/grpcstub_test.go index d9253ab..618a8f6 100644 --- a/grpcstub_test.go +++ b/grpcstub_test.go @@ -2,9 +2,7 @@ package grpcstub import ( "context" - "errors" "fmt" - "io" "os" "strings" "testing" @@ -22,235 +20,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -func TestUnary(t *testing.T) { - ctx := context.Background() - ts := NewServer(t, "testdata/route_guide.proto") - t.Cleanup(func() { - ts.Close() - }) - ts.Method("GetFeature").Response(map[string]any{"name": "hello", "location": map[string]any{"latitude": 10, "longitude": 13}}) - ts.Method("GetFeature").Response(map[string]any{"name": "hello", "location": map[string]any{"latitude": 99, "longitude": 99}}) - - client := routeguide.NewRouteGuideClient(ts.Conn()) - res, err := client.GetFeature(ctx, &routeguide.Point{ - Latitude: 10, - Longitude: 13, - }) - if err != nil { - t.Fatal(err) - } - { - got := res.Name - if want := "hello"; got != want { - t.Errorf("got %v\nwant %v", got, want) - return - } - } - { - got := res.Location.Latitude - if want := int32(10); got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } - - { - got := len(ts.Requests()) - if want := 1; got != want { - t.Errorf("got %v\nwant %v", got, want) - return - } - } - - req := ts.Requests()[0] - { - got := int32(req.Message["longitude"].(float64)) - if want := int32(13); got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } -} - -func TestServerStreaming(t *testing.T) { - ctx := context.Background() - ts := NewServer(t, "testdata/route_guide.proto") - t.Cleanup(func() { - ts.Close() - }) - ts.Method("ListFeatures").Response(map[string]any{"name": "hello"}).Response(map[string]any{"name": "world"}) - - client := routeguide.NewRouteGuideClient(ts.Conn()) - stream, err := client.ListFeatures(ctx, &routeguide.Rectangle{ - Lo: &routeguide.Point{ - Latitude: int32(10), - Longitude: int32(2), - }, - Hi: &routeguide.Point{ - Latitude: int32(20), - Longitude: int32(7), - }, - }) - if err != nil { - t.Fatal(err) - } - - c := 0 - for { - res, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - t.Fatal(err) - } - switch c { - case 0: - got := res.Name - if want := "hello"; got != want { - t.Errorf("got %v\nwant %v", got, want) - } - case 1: - got := res.Name - if want := "world"; got != want { - t.Errorf("got %v\nwant %v", got, want) - } - default: - t.Errorf("recv messages got %v\nwant %v", c+1, 2) - } - c++ - } - - { - got := len(ts.Requests()) - if want := 1; got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } -} - -func TestClientStreaming(t *testing.T) { - ctx := context.Background() - ts := NewServer(t, "testdata/route_guide.proto") - t.Cleanup(func() { - ts.Close() - }) - ts.Method("RecordRoute").Response(map[string]any{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345}) - - client := routeguide.NewRouteGuideClient(ts.Conn()) - stream, err := client.RecordRoute(ctx) - if err != nil { - t.Fatal(err) - } - c := 2 - for i := 0; i < c; i++ { - if err := stream.Send(&routeguide.Point{ - Latitude: int32(i + 10), - Longitude: int32(i * i * 2), - }); err != nil { - t.Fatal(err) - } - } - res, err := stream.CloseAndRecv() - if err != nil { - t.Fatal(err) - } - - { - got := res.PointCount - if want := int32(2); got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } - - { - got := len(ts.Requests()) - if want := 2; got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } -} - -func TestBiStreaming(t *testing.T) { - ctx := context.Background() - ts := NewServer(t, "testdata/route_guide.proto") - t.Cleanup(func() { - ts.Close() - }) - ts.Method("RouteChat").Match(func(r *Request) bool { - m, ok := r.Message["message"] - if !ok { - return false - } - return strings.Contains(m.(string), "hello from client[0]") - }).Header("hello", "header"). - Response(map[string]any{"location": nil, "message": "hello from server[0]"}) - ts.Method("RouteChat"). - Header("hello", "header"). - Handler(func(r *Request) *Response { - res := NewResponse() - m, ok := r.Message["message"] - if !ok { - res.Status = status.New(codes.Unknown, codes.Unknown.String()) - return res - } - mes := Message{} - mes["message"] = strings.Replace(m.(string), "client", "server", 1) - res.Messages = []Message{mes} - return res - }) - - client := routeguide.NewRouteGuideClient(ts.Conn()) - stream, err := client.RouteChat(ctx) - if err != nil { - t.Fatal(err) - } - max := 5 - c := 0 - recvCount := 0 - var sendEnd, recvEnd bool - for !(sendEnd && recvEnd) { - if !sendEnd { - if err := stream.SendMsg(&routeguide.RouteNote{ - Message: fmt.Sprintf("hello from client[%d]", c), - }); err != nil { - t.Error(err) - sendEnd = true - } - c++ - if c == max { - sendEnd = true - if err := stream.CloseSend(); err != nil { - t.Error(err) - } - } - } - - if !recvEnd { - if res, err := stream.Recv(); err != nil { - if !errors.Is(err, io.EOF) { - t.Error(err) - } - recvEnd = true - } else { - recvCount++ - got := res.Message - if want := fmt.Sprintf("hello from server[%d]", recvCount-1); got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } - } - } - if recvCount != max { - t.Errorf("got %v\nwant %v", recvCount, max) - } - - { - got := len(ts.Requests()) - if want := max; got != want { - t.Errorf("got %v\nwant %v", got, want) - } - } -} - func TestAddr(t *testing.T) { ts := NewServer(t, "testdata/route_guide.proto") t.Cleanup(func() { diff --git a/serverstreaming_test.go b/serverstreaming_test.go new file mode 100644 index 0000000..10df9ee --- /dev/null +++ b/serverstreaming_test.go @@ -0,0 +1,67 @@ +package grpcstub + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/k1LoW/grpcstub/testdata/routeguide" +) + +func TestServerStreaming(t *testing.T) { + ctx := context.Background() + ts := NewServer(t, "testdata/route_guide.proto") + t.Cleanup(func() { + ts.Close() + }) + ts.Method("ListFeatures").Response(map[string]any{"name": "hello"}).Response(map[string]any{"name": "world"}) + + client := routeguide.NewRouteGuideClient(ts.Conn()) + stream, err := client.ListFeatures(ctx, &routeguide.Rectangle{ + Lo: &routeguide.Point{ + Latitude: int32(10), + Longitude: int32(2), + }, + Hi: &routeguide.Point{ + Latitude: int32(20), + Longitude: int32(7), + }, + }) + if err != nil { + t.Fatal(err) + } + + c := 0 + for { + res, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatal(err) + } + switch c { + case 0: + got := res.Name + if want := "hello"; got != want { + t.Errorf("got %v\nwant %v", got, want) + } + case 1: + got := res.Name + if want := "world"; got != want { + t.Errorf("got %v\nwant %v", got, want) + } + default: + t.Errorf("recv messages got %v\nwant %v", c+1, 2) + } + c++ + } + + { + got := len(ts.Requests()) + if want := 1; got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } +} diff --git a/unary_test.go b/unary_test.go new file mode 100644 index 0000000..ee784af --- /dev/null +++ b/unary_test.go @@ -0,0 +1,56 @@ +package grpcstub + +import ( + "context" + "testing" + + "github.com/k1LoW/grpcstub/testdata/routeguide" +) + +func TestUnary(t *testing.T) { + ctx := context.Background() + ts := NewServer(t, "testdata/route_guide.proto") + t.Cleanup(func() { + ts.Close() + }) + ts.Method("GetFeature").Response(map[string]any{"name": "hello", "location": map[string]any{"latitude": 10, "longitude": 13}}) + ts.Method("GetFeature").Response(map[string]any{"name": "hello", "location": map[string]any{"latitude": 99, "longitude": 99}}) + + client := routeguide.NewRouteGuideClient(ts.Conn()) + res, err := client.GetFeature(ctx, &routeguide.Point{ + Latitude: 10, + Longitude: 13, + }) + if err != nil { + t.Fatal(err) + } + { + got := res.Name + if want := "hello"; got != want { + t.Errorf("got %v\nwant %v", got, want) + return + } + } + { + got := res.Location.Latitude + if want := int32(10); got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + + { + got := len(ts.Requests()) + if want := 1; got != want { + t.Errorf("got %v\nwant %v", got, want) + return + } + } + + req := ts.Requests()[0] + { + got := int32(req.Message["longitude"].(float64)) + if want := int32(13); got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } +}