Skip to content

Commit

Permalink
Update to new proto format (#237)
Browse files Browse the repository at this point in the history
## tl;dr

- Companion PR to xmtp/proto#223 which updates everything to the new format
  • Loading branch information
neekolas authored Oct 18, 2024
1 parent d3fe5c1 commit 235dad6
Show file tree
Hide file tree
Showing 31 changed files with 2,341 additions and 1,200 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Nightly Automation
on:
schedule:
- cron: '0 10 * * *'
- cron: "0 10 * * *"
workflow_dispatch:
jobs:
nightly-protos:
Expand All @@ -14,7 +14,7 @@ jobs:
- uses: actions/setup-go@v5
- uses: bufbuild/buf-setup-action@v1.40.1
- name: Generate Protos
run: dev/gen_protos
run: dev/gen-protos
- name: Create Pull Request
uses: peter-evans/create-pull-request@v7
with:
Expand All @@ -26,4 +26,4 @@ jobs:
Auto-generated by [create-pull-request][1]
[1]: https://github.com/peter-evans/create-pull-request
branch: nightly-proto
branch: nightly-proto
File renamed without changes.
2 changes: 1 addition & 1 deletion dev/generate
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

set -euo pipefail

./dev/gen_protos
./dev/gen-protos
sqlc generate
go generate ./...
rm -rf pkg/mocks/*
Expand Down
31 changes: 16 additions & 15 deletions pkg/api/publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
Expand All @@ -20,24 +21,24 @@ func TestPublishEnvelope(t *testing.T) {

payerEnvelope := envelopeTestUtils.CreatePayerEnvelope(t)

resp, err := api.PublishEnvelopes(
resp, err := api.PublishPayerEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{payerEnvelope},
&message_api.PublishPayerEnvelopesRequest{
PayerEnvelopes: []*envelopes.PayerEnvelope{payerEnvelope},
},
)
require.NoError(t, err)
require.NotNil(t, resp)

unsignedEnv := &message_api.UnsignedOriginatorEnvelope{}
unsignedEnv := &envelopes.UnsignedOriginatorEnvelope{}
require.NoError(
t,
proto.Unmarshal(
resp.GetOriginatorEnvelopes()[0].GetUnsignedOriginatorEnvelope(),
unsignedEnv,
),
)
clientEnv := &message_api.ClientEnvelope{}
clientEnv := &envelopes.ClientEnvelope{}
require.NoError(
t,
proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv),
Expand All @@ -56,7 +57,7 @@ func TestPublishEnvelope(t *testing.T) {
return false
}

originatorEnv := &message_api.OriginatorEnvelope{}
originatorEnv := &envelopes.OriginatorEnvelope{}
require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv))
return proto.Equal(originatorEnv, resp.GetOriginatorEnvelopes()[0])
}, 500*time.Millisecond, 50*time.Millisecond)
Expand All @@ -68,10 +69,10 @@ func TestUnmarshalErrorOnPublish(t *testing.T) {

envelope := envelopeTestUtils.CreatePayerEnvelope(t)
envelope.UnsignedClientEnvelope = []byte("invalidbytes")
_, err := api.PublishEnvelopes(
_, err := api.PublishPayerEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{envelope},
&message_api.PublishPayerEnvelopesRequest{
PayerEnvelopes: []*envelopes.PayerEnvelope{envelope},
},
)
require.ErrorContains(t, err, "invalid wire-format data")
Expand All @@ -83,10 +84,10 @@ func TestMismatchingOriginatorOnPublish(t *testing.T) {

clientEnv := envelopeTestUtils.CreateClientEnvelope()
clientEnv.Aad.TargetOriginator = 2
_, err := api.PublishEnvelopes(
_, err := api.PublishPayerEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{
&message_api.PublishPayerEnvelopesRequest{
PayerEnvelopes: []*envelopes.PayerEnvelope{
envelopeTestUtils.CreatePayerEnvelope(t, clientEnv),
},
},
Expand All @@ -100,10 +101,10 @@ func TestMissingTopicOnPublish(t *testing.T) {

clientEnv := envelopeTestUtils.CreateClientEnvelope()
clientEnv.Aad.TargetTopic = nil
_, err := api.PublishEnvelopes(
_, err := api.PublishPayerEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{
&message_api.PublishPayerEnvelopesRequest{
PayerEnvelopes: []*envelopes.PayerEnvelope{
envelopeTestUtils.CreatePayerEnvelope(t, clientEnv),
},
},
Expand Down
11 changes: 6 additions & 5 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
Expand Down Expand Up @@ -145,7 +146,7 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) {
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
LastSeen: &message_api.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 2}},
LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 2}},
},
Limit: 0,
},
Expand All @@ -164,7 +165,7 @@ func TestQueryTopicFromLastSeen(t *testing.T) {
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA")},
LastSeen: &message_api.VectorClock{
LastSeen: &envelopes.VectorClock{
NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1},
},
},
Expand All @@ -185,7 +186,7 @@ func TestQueryMultipleTopicsFromLastSeen(t *testing.T) {
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), db.Topic("topicB")},
LastSeen: &message_api.VectorClock{
LastSeen: &envelopes.VectorClock{
NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1},
},
},
Expand All @@ -206,7 +207,7 @@ func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) {
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
OriginatorNodeIds: []uint32{1, 2},
LastSeen: &message_api.VectorClock{
LastSeen: &envelopes.VectorClock{
NodeIdToSequenceId: map[uint32]uint64{1: 1, 2: 1},
},
},
Expand Down Expand Up @@ -257,7 +258,7 @@ func checkRowsMatchProtos(
t *testing.T,
allRows []queries.InsertGatewayEnvelopeParams,
matchingIndices []int,
protos []*message_api.OriginatorEnvelope,
protos []*envelopes.OriginatorEnvelope,
) {
require.Len(t, protos, len(matchingIndices))
for i, p := range protos {
Expand Down
21 changes: 11 additions & 10 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/utils"
Expand Down Expand Up @@ -131,9 +132,9 @@ func (s *Service) QueryEnvelopes(
return nil, status.Errorf(codes.Internal, "could not select envelopes: %v", err)
}

envs := make([]*message_api.OriginatorEnvelope, 0, len(rows))
envs := make([]*envelopesProto.OriginatorEnvelope, 0, len(rows))
for _, row := range rows {
originatorEnv := &message_api.OriginatorEnvelope{}
originatorEnv := &envelopesProto.OriginatorEnvelope{}
err := proto.Unmarshal(row.OriginatorEnvelope, originatorEnv)
if err != nil {
// We expect to have already validated the envelope when it was inserted
Expand Down Expand Up @@ -213,10 +214,10 @@ func (s *Service) queryReqToDBParams(
return &params, nil
}

func (s *Service) PublishEnvelopes(
func (s *Service) PublishPayerEnvelopes(
ctx context.Context,
req *message_api.PublishEnvelopesRequest,
) (*message_api.PublishEnvelopesResponse, error) {
req *message_api.PublishPayerEnvelopesRequest,
) (*message_api.PublishPayerEnvelopesResponse, error) {
if len(req.GetPayerEnvelopes()) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing payer envelope")
}
Expand All @@ -231,7 +232,7 @@ func (s *Service) PublishEnvelopes(
return nil, err
}
if didPublish {
return &message_api.PublishEnvelopesResponse{}, nil
return &message_api.PublishPayerEnvelopesResponse{}, nil
}

// TODO(rich): Properly support batch publishing
Expand All @@ -257,16 +258,16 @@ func (s *Service) PublishEnvelopes(
return nil, status.Errorf(codes.Internal, "could not sign envelope: %v", err)
}

return &message_api.PublishEnvelopesResponse{
OriginatorEnvelopes: []*message_api.OriginatorEnvelope{originatorEnv},
return &message_api.PublishPayerEnvelopesResponse{
OriginatorEnvelopes: []*envelopesProto.OriginatorEnvelope{originatorEnv},
}, nil
}

func (s *Service) maybePublishToBlockchain(
ctx context.Context,
clientEnv *envelopes.ClientEnvelope,
) (didPublish bool, err error) {
payload, ok := clientEnv.Payload().(*message_api.ClientEnvelope_IdentityUpdate)
payload, ok := clientEnv.Payload().(*envelopesProto.ClientEnvelope_IdentityUpdate)
if ok && payload.IdentityUpdate != nil {
if err = s.publishIdentityUpdate(ctx, payload.IdentityUpdate); err != nil {
s.log.Error("could not publish identity update", zap.Error(err))
Expand Down Expand Up @@ -340,7 +341,7 @@ func (s *Service) GetInboxIds(
}

func (s *Service) validatePayerEnvelope(
rawEnv *message_api.PayerEnvelope,
rawEnv *envelopesProto.PayerEnvelope,
) (*envelopes.PayerEnvelope, error) {
payerEnv, err := envelopes.NewPayerEnvelope(rawEnv)
if err != nil {
Expand Down
15 changes: 8 additions & 7 deletions pkg/api/subscribeWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
Expand All @@ -22,7 +23,7 @@ const (

type listener struct {
ctx context.Context
ch chan<- []*message_api.OriginatorEnvelope
ch chan<- []*envelopesProto.OriginatorEnvelope
closed bool
topics map[string]struct{}
originators map[uint32]struct{}
Expand All @@ -32,7 +33,7 @@ type listener struct {
func newListener(
ctx context.Context,
query *message_api.EnvelopesQuery,
ch chan<- []*message_api.OriginatorEnvelope,
ch chan<- []*envelopesProto.OriginatorEnvelope,
) *listener {
l := &listener{
ctx: ctx,
Expand Down Expand Up @@ -206,7 +207,7 @@ func (s *subscribeWorker) start() {
}

func (s *subscribeWorker) dispatch(row *queries.GatewayEnvelope) {
env := &message_api.OriginatorEnvelope{}
env := &envelopesProto.OriginatorEnvelope{}
err := proto.Unmarshal(row.OriginatorEnvelope, env)
if err != nil {
s.log.Error("Failed to unmarshal envelope", zap.Error(err))
Expand All @@ -222,7 +223,7 @@ func (s *subscribeWorker) dispatch(row *queries.GatewayEnvelope) {

func (s *subscribeWorker) dispatchToListeners(
listeners *listenerSet,
env *message_api.OriginatorEnvelope,
env *envelopesProto.OriginatorEnvelope,
) {
if listeners == nil {
return
Expand All @@ -239,7 +240,7 @@ func (s *subscribeWorker) dispatchToListeners(
s.closeListener(l)
default:
select {
case l.ch <- []*message_api.OriginatorEnvelope{env}:
case l.ch <- []*envelopesProto.OriginatorEnvelope{env}:
default:
s.log.Info("Channel full, removing listener", zap.Any("listener", l.ch))
s.closeListener(l)
Expand Down Expand Up @@ -268,8 +269,8 @@ func (s *subscribeWorker) closeListener(l *listener) {
func (s *subscribeWorker) listen(
ctx context.Context,
query *message_api.EnvelopesQuery,
) <-chan []*message_api.OriginatorEnvelope {
ch := make(chan []*message_api.OriginatorEnvelope, subscriptionBufferSize)
) <-chan []*envelopesProto.OriginatorEnvelope {
ch := make(chan []*envelopesProto.OriginatorEnvelope, subscriptionBufferSize)
l := newListener(ctx, query, ch)

if l.isGlobal {
Expand Down
25 changes: 13 additions & 12 deletions pkg/envelopes/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ package envelopes
import (
"errors"

"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/topic"
"github.com/xmtp/xmtpd/pkg/utils"
"google.golang.org/protobuf/proto"
)

type ClientEnvelope struct {
proto *message_api.ClientEnvelope
proto *envelopesProto.ClientEnvelope
targetTopic topic.Topic
}

func NewClientEnvelope(proto *message_api.ClientEnvelope) (*ClientEnvelope, error) {
func NewClientEnvelope(proto *envelopesProto.ClientEnvelope) (*ClientEnvelope, error) {
if proto == nil {
return nil, errors.New("proto is nil")
}
Expand All @@ -35,11 +36,11 @@ func NewClientEnvelope(proto *message_api.ClientEnvelope) (*ClientEnvelope, erro
}

func NewClientEnvelopeFromBytes(bytes []byte) (*ClientEnvelope, error) {
var message message_api.ClientEnvelope
if err := proto.Unmarshal(bytes, &message); err != nil {
message, err := utils.UnmarshalClientEnvelope(bytes)
if err != nil {
return nil, err
}
return NewClientEnvelope(&message)
return NewClientEnvelope(message)
}

func (c *ClientEnvelope) Bytes() ([]byte, error) {
Expand All @@ -58,11 +59,11 @@ func (c *ClientEnvelope) Payload() interface{} {
return c.proto.Payload
}

func (c *ClientEnvelope) Aad() *message_api.AuthenticatedData {
func (c *ClientEnvelope) Aad() *envelopesProto.AuthenticatedData {
return c.proto.Aad
}

func (c *ClientEnvelope) Proto() *message_api.ClientEnvelope {
func (c *ClientEnvelope) Proto() *envelopesProto.ClientEnvelope {
return c.proto
}

Expand All @@ -72,13 +73,13 @@ func (c *ClientEnvelope) TopicMatchesPayload() bool {
payload := c.proto.Payload

switch payload.(type) {
case *message_api.ClientEnvelope_WelcomeMessage:
case *envelopesProto.ClientEnvelope_WelcomeMessage:
return targetTopicKind == topic.TOPIC_KIND_WELCOME_MESSAGES_V1
case *message_api.ClientEnvelope_GroupMessage:
case *envelopesProto.ClientEnvelope_GroupMessage:
return targetTopicKind == topic.TOPIC_KIND_GROUP_MESSAGES_V1
case *message_api.ClientEnvelope_IdentityUpdate:
case *envelopesProto.ClientEnvelope_IdentityUpdate:
return targetTopicKind == topic.TOPIC_KIND_IDENTITY_UPDATES_V1
case *message_api.ClientEnvelope_UploadKeyPackage:
case *envelopesProto.ClientEnvelope_UploadKeyPackage:
return targetTopicKind == topic.TOPIC_KIND_KEY_PACKAGES_V1
default:
return false
Expand Down
Loading

0 comments on commit 235dad6

Please sign in to comment.