Skip to content

Commit

Permalink
Merge pull request #1 from k1LoW/status
Browse files Browse the repository at this point in the history
Add response status handling
  • Loading branch information
k1LoW authored Jul 2, 2022
2 parents 8cf7e7d + d4d809b commit a9f9fb2
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 42 deletions.
33 changes: 33 additions & 0 deletions grpcstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,27 @@ func (m *matcher) ResponseString(message string) *matcher {
return m.Response(mes)
}

// Status set status which return response.
func (m *matcher) Status(s *status.Status) *matcher {
var fn handlerFunc
if m.handler == nil {
fn = func(r *Request) *Response {
res := NewResponse()
res.Status = s
return res
}
} else {
prev := m.handler
fn = func(r *Request) *Response {
res := prev(r)
res.Status = s
return res
}
}
m.handler = fn
return m
}

// Requests returns []*grpcstub.Request received by router.
func (s *Server) Requests() []*Request {
s.mu.RLock()
Expand Down Expand Up @@ -406,6 +427,9 @@ func (s *Server) createUnaryHandler(md *desc.MethodDescriptor) func(srv interfac
}
}
}
if res.Status != nil && res.Status.Err() != nil {
return nil, res.Status.Err()
}
mes = msgFactory.NewMessage(md.GetOutputType())
if len(res.Messages) > 0 {
b, err := json.Marshal(res.Messages[0])
Expand Down Expand Up @@ -486,6 +510,9 @@ func (s *Server) createServerStreamingHandler(md *desc.MethodDescriptor) func(sr
stream.SetTrailer(metadata.Pairs(k, vv))
}
}
if res.Status != nil && res.Status.Err() != nil {
return res.Status.Err()
}
if len(res.Messages) > 0 {
for _, resm := range res.Messages {
mes := msgFactory.NewMessage(md.GetOutputType())
Expand Down Expand Up @@ -548,6 +575,9 @@ func (s *Server) createClientStreamingHandler(md *desc.MethodDescriptor) func(sr
m.requests = append(m.requests, r)
m.mu.Unlock()
res := m.handler(r)
if res.Status != nil && res.Status.Err() != nil {
return res.Status.Err()
}
mes = msgFactory.NewMessage(md.GetOutputType())
if len(res.Messages) > 0 {
b, err := json.Marshal(res.Messages[0])
Expand Down Expand Up @@ -636,6 +666,9 @@ func (s *Server) createBiStreamingHandler(md *desc.MethodDescriptor) func(srv in
stream.SetTrailer(metadata.Pairs(k, vv))
}
}
if res.Status != nil && res.Status.Err() != nil {
return res.Status.Err()
}
if len(res.Messages) > 0 {
for _, resm := range res.Messages {
mes := msgFactory.NewMessage(md.GetOutputType())
Expand Down
255 changes: 213 additions & 42 deletions grpcstub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,48 +52,6 @@ func TestUnary(t *testing.T) {
}
}

func TestClientStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RecordRoute").Response(map[string]interface{}{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RecordRoute(ctx)
if err != nil {
t.Fatal(err)
}
c := 2
for i := 0; i < c; i++ {
if err := stream.Send(&routeguide.Point{
Latitude: int32(i + 10),
Longitude: int32(i * i * 2),
}); err != nil {
t.Fatal(err)
}
}
res, err := stream.CloseAndRecv()
if err != nil {
t.Fatal(err)
}

{
got := res.PointCount
if want := int32(2); got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.Requests())
if want := 2; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestServerStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
Expand Down Expand Up @@ -151,6 +109,48 @@ func TestServerStreaming(t *testing.T) {
}
}

func TestClientStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RecordRoute").Response(map[string]interface{}{"point_count": 2, "feature_count": 2, "distance": 10, "elapsed_time": 345})

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RecordRoute(ctx)
if err != nil {
t.Fatal(err)
}
c := 2
for i := 0; i < c; i++ {
if err := stream.Send(&routeguide.Point{
Latitude: int32(i + 10),
Longitude: int32(i * i * 2),
}); err != nil {
t.Fatal(err)
}
}
res, err := stream.CloseAndRecv()
if err != nil {
t.Fatal(err)
}

{
got := res.PointCount
if want := int32(2); got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

{
got := len(ts.Requests())
if want := 2; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestBiStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
Expand Down Expand Up @@ -417,3 +417,174 @@ func TestResponseHeader(t *testing.T) {
t.Errorf("got %v\nwant %v", got[0], want)
}
}

func TestStatusUnary(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("GetFeature").Status(status.New(codes.Aborted, "aborted"))
client := routeguide.NewRouteGuideClient(ts.Conn())

_, err := client.GetFeature(ctx, &routeguide.Point{})
if err == nil {
t.Error("want error")
return
}

s, ok := status.FromError(err)
if !ok {
t.Error("want status.Status")
return
}
{
got := s.Code()
if want := codes.Aborted; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := s.Message()
if want := "aborted"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestStatusServerStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("ListFeatures").Status(status.New(codes.Aborted, "aborted"))

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.ListFeatures(ctx, &routeguide.Rectangle{
Lo: &routeguide.Point{
Latitude: int32(10),
Longitude: int32(2),
},
Hi: &routeguide.Point{
Latitude: int32(20),
Longitude: int32(7),
},
})
if err != nil {
t.Fatal(err)
}

_, err = stream.Recv()
if err == nil {
t.Error("want error")
}
s, ok := status.FromError(err)
if !ok {
t.Error("want status.Status")
return
}
{
got := s.Code()
if want := codes.Aborted; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := s.Message()
if want := "aborted"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestStatusClientStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RecordRoute").Status(status.New(codes.Aborted, "aborted"))

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RecordRoute(ctx)
if err != nil {
t.Fatal(err)
}

c := 2
for i := 0; i < c; i++ {
if err := stream.Send(&routeguide.Point{
Latitude: int32(i + 10),
Longitude: int32(i * i * 2),
}); err != nil {
t.Fatal(err)
}
}
_, err = stream.CloseAndRecv()
if err == nil {
t.Error("want error")
return
}

s, ok := status.FromError(err)
if !ok {
t.Error("want status.Status")
return
}
{
got := s.Code()
if want := codes.Aborted; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := s.Message()
if want := "aborted"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

func TestStatusBiStreaming(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, []string{}, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Method("RouteChat").Status(status.New(codes.Aborted, "aborted"))

client := routeguide.NewRouteGuideClient(ts.Conn())
stream, err := client.RouteChat(ctx)
if err != nil {
t.Fatal(err)
}
if err := stream.SendMsg(&routeguide.RouteNote{
Message: "hello from client",
}); err != nil {
t.Fatal(err)
}
_, err = stream.Recv()
if err == nil {
t.Error("want error")
return
}

s, ok := status.FromError(err)
if !ok {
t.Error("want status.Status")
return
}
{
got := s.Code()
if want := codes.Aborted; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := s.Message()
if want := "aborted"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

0 comments on commit a9f9fb2

Please sign in to comment.