Skip to content

Commit

Permalink
Requests() returns matched requests only
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Sep 22, 2023
1 parent 1a6eee7 commit 845d064
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 12 deletions.
40 changes: 40 additions & 0 deletions bidistreaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,43 @@ func TestBidiStreaming(t *testing.T) {
}
}
}

func TestBidiStreamingUnmatched(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RouteChat").Match(func(r *Request) bool {
return false
}).Header("hello", "header").
Response(map[string]any{"location": nil, "message": "hello from server[0]"})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RouteChat(ctx)
if err != nil {
t.Fatal(err)
}
if err := stream.SendMsg(&routeguide.RouteNote{
Message: fmt.Sprintf("hello from client[%d]", 0),
}); err != nil {
t.Error("want error")
}
if _, err := stream.Recv(); err == nil {
t.Error("want error")
}

{
got := len(ts.Requests())
if want := 0; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.UnmatchedRequests())
if want := 1; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}
43 changes: 43 additions & 0 deletions clientstreaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,46 @@ func TestClientStreaming(t *testing.T) {
}
}
}

func TestClientStreamingUnmatched(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RecordRoute").Match(func(r *Request) bool {
return false
}).Response(map[string]any{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RecordRoute(ctx)
if err != nil {
t.Fatal(err)
}
c := 2
for i := 0; i < c; i++ {
if err := stream.Send(&routeguide.Point{
Latitude: int32(i + 10),
Longitude: int32(i * i * 2),
}); err != nil {
t.Fatal(err)
}
}
if _, err := stream.CloseAndRecv(); err == nil {
t.Error("want error")
}

{
got := len(ts.Requests())
if want := 0; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.UnmatchedRequests())
if want := 2; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}
34 changes: 22 additions & 12 deletions grpcstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ func (s *Server) Requests() []*Request {
return s.requests
}

// UnmatchedRequests returns []*grpcstub.Request received but not matched by router.
func (s *Server) UnmatchedRequests() []*Request {
s.mu.RLock()
defer s.mu.RUnlock()
return s.unmatchedRequests
}

// Requests returns []*grpcstub.Request received by matcher.
func (m *matcher) Requests() []*Request {
m.mu.RLock()
Expand Down Expand Up @@ -522,15 +529,15 @@ func (s *Server) createUnaryHandler(md protoreflect.MethodDescriptor) func(srv a
if ok {
r.Headers = h
}
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()

var mes *dynamicpb.Message
for _, m := range s.matchers {
if !m.matchRequest(r) {
continue
}
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()
m.mu.Lock()
m.requests = append(m.requests, r)
m.mu.Unlock()
Expand Down Expand Up @@ -606,16 +613,16 @@ func (s *Server) createServerStreamingHandler(md protoreflect.MethodDescriptor)
if ok {
r.Headers = h
}
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()
for _, m := range s.matchers {
if !m.matchRequest(r) {
continue
}
m.mu.Lock()
m.requests = append(m.requests, r)
m.mu.Unlock()
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()
res := m.handler(r, md)
for k, v := range res.Headers {
for _, vv := range v {
Expand Down Expand Up @@ -676,14 +683,14 @@ func (s *Server) createClientStreamingHandler(md protoreflect.MethodDescriptor)
if ok {
r.Headers = h
}
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()
rs = append(rs, r)
continue
}

if err != io.EOF {
s.mu.Lock()
s.unmatchedRequests = append(s.unmatchedRequests, rs...)
s.mu.Unlock()
return err
}

Expand All @@ -692,6 +699,9 @@ func (s *Server) createClientStreamingHandler(md protoreflect.MethodDescriptor)
if !m.matchRequest(rs...) {
continue
}
s.mu.Lock()
s.requests = append(s.requests, rs...)
s.mu.Unlock()
m.mu.Lock()
m.requests = append(m.requests, rs...)
m.mu.Unlock()
Expand Down Expand Up @@ -758,13 +768,13 @@ func (s *Server) createBidiStreamingHandler(md protoreflect.MethodDescriptor) fu
if ok {
r.Headers = h
}
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()
for _, m := range s.matchers {
if !m.matchRequest(r) {
continue
}
s.mu.Lock()
s.requests = append(s.requests, r)
s.mu.Unlock()
m.mu.Lock()
m.requests = append(m.requests, r)
m.mu.Unlock()
Expand Down
50 changes: 50 additions & 0 deletions serverstreaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,53 @@ func TestServerStreaming(t *testing.T) {
}
}
}

func TestServerStreamingUnmatched(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("ListFeatures").Match(func(r *Request) bool {
return false
}).Response(map[string]any{"name": "hello"}).Response(map[string]any{"name": "world"})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.ListFeatures(ctx, &routeguide.Rectangle{
Lo: &routeguide.Point{
Latitude: int32(10),
Longitude: int32(2),
},
Hi: &routeguide.Point{
Latitude: int32(20),
Longitude: int32(7),
},
})
if err != nil {
t.Fatal(err)
}

for {
_, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err == nil {
t.Error("want error")
}
break

Check failure on line 102 in serverstreaming_test.go

View workflow job for this annotation

GitHub Actions / Test

SA4004: the surrounding loop is unconditionally terminated (staticcheck)

Check failure on line 102 in serverstreaming_test.go

View workflow job for this annotation

GitHub Actions / golangci

[golangci] serverstreaming_test.go#L102

SA4004: the surrounding loop is unconditionally terminated (staticcheck)
Raw output
serverstreaming_test.go:102:3: SA4004: the surrounding loop is unconditionally terminated (staticcheck)
		break
		^
}

{
got := len(ts.Requests())
if want := 0; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := len(ts.UnmatchedRequests())
if want := 1; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}
33 changes: 33 additions & 0 deletions unary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,36 @@ func TestUnary(t *testing.T) {
}
}
}

func TestUnaryUnmatched(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("GetFeature").Match(func(r *Request) bool {
return false
}).Response(map[string]any{"name": "hello", "location": map[string]any{"latitude": 10, "longitude": 13}})

client := routeguide.NewRouteGuideClient(ts.Conn())
_, err := client.GetFeature(ctx, &routeguide.Point{
Latitude: 10,
Longitude: 13,
})
if err == nil {
t.Error("want error")
}

{
got := len(ts.Requests())
if want := 0; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := len(ts.UnmatchedRequests())
if want := 1; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

0 comments on commit 845d064

Please sign in to comment.