Skip to content

Commit

Permalink
Merge pull request #56 from k1LoW/unmatched-requests
Browse files Browse the repository at this point in the history
Requests() returns matched requests only and UnmatchedRequests() returns unmatched requests
  • Loading branch information
k1LoW authored Sep 22, 2023
2 parents bb0723b + 47b10fb commit 8322fdd
Show file tree
Hide file tree
Showing 6 changed files with 519 additions and 309 deletions.
136 changes: 136 additions & 0 deletions bidistreaming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package grpcstub

import (
"context"
"errors"
"fmt"
"io"
"strings"
"testing"

"github.com/k1LoW/grpcstub/testdata/routeguide"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func TestBidiStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RouteChat").Match(func(r *Request) bool {
m, ok := r.Message["message"]
if !ok {
return false
}
return strings.Contains(m.(string), "hello from client[0]")
}).Header("hello", "header").
Response(map[string]any{"location": nil, "message": "hello from server[0]"})
ts.Method("RouteChat").
Header("hello", "header").
Handler(func(r *Request) *Response {
res := NewResponse()
m, ok := r.Message["message"]
if !ok {
res.Status = status.New(codes.Unknown, codes.Unknown.String())
return res
}
mes := Message{}
mes["message"] = strings.Replace(m.(string), "client", "server", 1)
res.Messages = []Message{mes}
return res
})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RouteChat(ctx)
if err != nil {
t.Fatal(err)
}
max := 5
c := 0
recvCount := 0
var sendEnd, recvEnd bool
for !(sendEnd && recvEnd) {
if !sendEnd {
if err := stream.SendMsg(&routeguide.RouteNote{
Message: fmt.Sprintf("hello from client[%d]", c),
}); err != nil {
t.Error(err)
sendEnd = true
}
c++
if c == max {
sendEnd = true
if err := stream.CloseSend(); err != nil {
t.Error(err)
}
}
}

if !recvEnd {
if res, err := stream.Recv(); err != nil {
if !errors.Is(err, io.EOF) {
t.Error(err)
}
recvEnd = true
} else {
recvCount++
got := res.Message
if want := fmt.Sprintf("hello from server[%d]", recvCount-1); got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}
}
if recvCount != max {
t.Errorf("got %v\nwant %v", recvCount, max)
}

{
got := len(ts.Requests())
if want := max; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestBidiStreamingUnmatched(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RouteChat").Match(func(r *Request) bool {
return false
}).Header("hello", "header").
Response(map[string]any{"location": nil, "message": "hello from server[0]"})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RouteChat(ctx)
if err != nil {
t.Fatal(err)
}
if err := stream.SendMsg(&routeguide.RouteNote{
Message: fmt.Sprintf("hello from client[%d]", 0),
}); err != nil {
t.Error("want error")
}
if _, err := stream.Recv(); err == nil {
t.Error("want error")
}

{
got := len(ts.Requests())
if want := 0; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.UnmatchedRequests())
if want := 1; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}
93 changes: 93 additions & 0 deletions clientstreaming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package grpcstub

import (
"context"
"testing"

"github.com/k1LoW/grpcstub/testdata/routeguide"
)

func TestClientStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RecordRoute").Response(map[string]any{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RecordRoute(ctx)
if err != nil {
t.Fatal(err)
}
c := 2
for i := 0; i < c; i++ {
if err := stream.Send(&routeguide.Point{
Latitude: int32(i + 10),
Longitude: int32(i * i * 2),
}); err != nil {
t.Fatal(err)
}
}
res, err := stream.CloseAndRecv()
if err != nil {
t.Fatal(err)
}

{
got := res.PointCount
if want := int32(2); got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.Requests())
if want := 2; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestClientStreamingUnmatched(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RecordRoute").Match(func(r *Request) bool {
return false
}).Response(map[string]any{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RecordRoute(ctx)
if err != nil {
t.Fatal(err)
}
c := 2
for i := 0; i < c; i++ {
if err := stream.Send(&routeguide.Point{
Latitude: int32(i + 10),
Longitude: int32(i * i * 2),
}); err != nil {
t.Fatal(err)
}
}
if _, err := stream.CloseAndRecv(); err == nil {
t.Error("want error")
}

{
got := len(ts.Requests())
if want := 0; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.UnmatchedRequests())
if want := 2; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}
Loading

0 comments on commit 8322fdd

Please sign in to comment.