Skip to content

Commit

Permalink
Add MarshalProtoMessage / UnmarshalProtoMessage to grpcstub
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Jul 24, 2024
1 parent 564153b commit fb06667
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 41 deletions.
74 changes: 33 additions & 41 deletions grpcstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,15 +607,10 @@ func (s *Server) createUnaryHandler(md protoreflect.MethodDescriptor) func(srv a
if err := dec(in); err != nil {
return nil, err
}
b, err := protojson.MarshalOptions{UseProtoNames: true, UseEnumNumbers: true, EmitUnpopulated: true}.Marshal(in)
m, err := MarshalProtoMessage(in)
if err != nil {
return nil, err
}
m := Message{}
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}

req := newRequest(md, m)
h, ok := metadata.FromIncomingContext(ctx)
if ok {
Expand Down Expand Up @@ -653,11 +648,7 @@ func (s *Server) createUnaryHandler(md protoreflect.MethodDescriptor) func(srv a
}
mes = dynamicpb.NewMessage(md.Output())
if len(res.Messages) > 0 {
b, err := json.Marshal(res.Messages[0])
if err != nil {
return nil, err
}
if err := (protojson.UnmarshalOptions{}).Unmarshal(b, mes); err != nil {
if err := UnmarshalProtoMessage(res.Messages[0], mes); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -692,14 +683,10 @@ func (s *Server) createServerStreamingHandler(md protoreflect.MethodDescriptor)
if err := stream.RecvMsg(in); err != nil {
return err
}
b, err := protojson.MarshalOptions{UseProtoNames: true, UseEnumNumbers: true, EmitUnpopulated: true}.Marshal(in)
m, err := MarshalProtoMessage(in)
if err != nil {
return err
}
m := Message{}
if err := json.Unmarshal(b, &m); err != nil {
return err
}
r := newRequest(md, m)
h, ok := metadata.FromIncomingContext(stream.Context())
if ok {
Expand Down Expand Up @@ -734,11 +721,7 @@ func (s *Server) createServerStreamingHandler(md protoreflect.MethodDescriptor)
if len(res.Messages) > 0 {
for _, resm := range res.Messages {
mes := dynamicpb.NewMessage(md.Output())
b, err := json.Marshal(resm)
if err != nil {
return err
}
if err := (protojson.UnmarshalOptions{}).Unmarshal(b, mes); err != nil {
if err := UnmarshalProtoMessage(resm, mes); err != nil {
return err
}
if err := stream.SendMsg(mes); err != nil {
Expand All @@ -762,14 +745,10 @@ func (s *Server) createClientStreamingHandler(md protoreflect.MethodDescriptor)
in := dynamicpb.NewMessage(md.Input())
err := stream.RecvMsg(in)
if err == nil {
b, err := protojson.MarshalOptions{UseProtoNames: true, UseEnumNumbers: true, EmitUnpopulated: true}.Marshal(in)
m, err := MarshalProtoMessage(in)
if err != nil {
return err
}
m := Message{}
if err := json.Unmarshal(b, &m); err != nil {
return err
}
r := newRequest(md, m)
h, ok := metadata.FromIncomingContext(stream.Context())
if ok {
Expand Down Expand Up @@ -804,11 +783,7 @@ func (s *Server) createClientStreamingHandler(md protoreflect.MethodDescriptor)
}
mes = dynamicpb.NewMessage(md.Output())
if len(res.Messages) > 0 {
b, err := json.Marshal(res.Messages[0])
if err != nil {
return err
}
if err := (protojson.UnmarshalOptions{}).Unmarshal(b, mes); err != nil {
if err := UnmarshalProtoMessage(res.Messages[0], mes); err != nil {
return err
}
}
Expand Down Expand Up @@ -847,14 +822,10 @@ func (s *Server) createBidiStreamingHandler(md protoreflect.MethodDescriptor) fu
if err != nil {
return err
}
b, err := protojson.MarshalOptions{UseProtoNames: true, UseEnumNumbers: true, EmitUnpopulated: true}.Marshal(in)
m, err := MarshalProtoMessage(in)
if err != nil {
return err
}
m := Message{}
if err := json.Unmarshal(b, &m); err != nil {
return err
}
r := newRequest(md, m)
h, ok := metadata.FromIncomingContext(stream.Context())
if ok {
Expand Down Expand Up @@ -892,11 +863,7 @@ func (s *Server) createBidiStreamingHandler(md protoreflect.MethodDescriptor) fu
if len(res.Messages) > 0 {
for _, resm := range res.Messages {
mes := dynamicpb.NewMessage(md.Output())
b, err := json.Marshal(resm)
if err != nil {
return err
}
if err := (protojson.UnmarshalOptions{}).Unmarshal(b, mes); err != nil {
if err := UnmarshalProtoMessage(resm, mes); err != nil {
return err
}
if err := stream.SendMsg(mes); err != nil {
Expand All @@ -914,6 +881,31 @@ func (s *Server) createBidiStreamingHandler(md protoreflect.MethodDescriptor) fu
}
}

// MarshalProtoMessage marshals [proto.Message] to [Message].
func MarshalProtoMessage(pm protoreflect.ProtoMessage) (Message, error) {
b, err := protojson.MarshalOptions{UseProtoNames: true, UseEnumNumbers: true, EmitUnpopulated: true}.Marshal(pm)
if err != nil {
return nil, err
}
m := Message{}
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
return m, nil
}

// UnmarshalProtoMessage unmarshals [Message] to [proto.Message].
func UnmarshalProtoMessage(m Message, pm protoreflect.ProtoMessage) error {
b, err := json.Marshal(m)
if err != nil {
return err
}
if err := (protojson.UnmarshalOptions{}).Unmarshal(b, pm); err != nil {
return err
}
return nil
}

func (m *matcher) matchRequest(rs ...*Request) bool {
for _, r := range rs {
for _, fn := range m.matchFuncs {
Expand Down
37 changes: 37 additions & 0 deletions grpcstub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,40 @@ func TestWithConnectClient(t *testing.T) {
t.Errorf("got %v\nwant %v", res.Msg.GetMessage(), want)
}
}

func TestUnmarshalProtoMessage(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Match(func(req *Request) bool {
return req.Method == "GetFeature"
}).Handler(func(req *Request) *Response {
m := &routeguide.Point{}
if err := UnmarshalProtoMessage(req.Message, m); err != nil {
t.Fatal(err)
}
if m.Latitude != 10 || m.Longitude != 13 {
t.Errorf("got %v\nwant %v", m, &routeguide.Point{Latitude: 10, Longitude: 13})
}
return &Response{
Messages: []Message{
{"name": "hello"},
},
}
})

client := routeguide.NewRouteGuideClient(ts.Conn())
res, err := client.GetFeature(ctx, &routeguide.Point{
Latitude: 10,
Longitude: 13,
})
if err != nil {
t.Fatal(err)
}
got := res.Name
if want := "hello"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}

0 comments on commit fb06667

Please sign in to comment.