diff --git a/pkg/http/pubsub_test.go b/pkg/http/pubsub_test.go index 40b734d..dd6416a 100644 --- a/pkg/http/pubsub_test.go +++ b/pkg/http/pubsub_test.go @@ -3,6 +3,7 @@ package http_test import ( "context" "fmt" + nethttp "net/http" "testing" "time" @@ -65,16 +66,31 @@ func TestHttpPubSub(t *testing.T) { waitForHTTP(t, sub, time.Second*10) - receivedMessages := make(chan message.Messages) + t.Run("publish a message with invalid metadata", func(t *testing.T) { + req, err := nethttp.NewRequest(nethttp.MethodPost, fmt.Sprintf("http://%s/test", sub.Addr()), nil) + require.NoError(t, err) - go func() { - received, _ := subscriber.BulkRead(msgs, 100, time.Second*10) - receivedMessages <- received - }() + req.Header.Set("Content-Type", "application/json") + req.Header.Set(http.HeaderMetadata, "invalid_metadata") + + resp, err := nethttp.DefaultClient.Do(req) + require.NoError(t, err) + + require.Equal(t, nethttp.StatusBadRequest, resp.StatusCode) + }) + + t.Run("publish correct simple messages", func(t *testing.T) { + receivedMessages := make(chan message.Messages) + + go func() { + received, _ := subscriber.BulkRead(msgs, 100, time.Second*10) + receivedMessages <- received + }() - publishedMessages := tests.PublishSimpleMessages(t, 100, pub, fmt.Sprintf("http://%s/test", sub.Addr())) + publishedMessages := tests.PublishSimpleMessages(t, 100, pub, fmt.Sprintf("http://%s/test", sub.Addr())) - tests.AssertAllMessagesReceived(t, publishedMessages, <-receivedMessages) + tests.AssertAllMessagesReceived(t, publishedMessages, <-receivedMessages) + }) } func waitForHTTP(t *testing.T, sub *http.Subscriber, timeoutTime time.Duration) { diff --git a/pkg/http/subscriber.go b/pkg/http/subscriber.go index 51396f6..d50813e 100644 --- a/pkg/http/subscriber.go +++ b/pkg/http/subscriber.go @@ -116,10 +116,6 @@ func (s *Subscriber) Subscribe(ctx context.Context, url string) (<-chan *message s.config.Router.Post(url, func(w http.ResponseWriter, r *http.Request) { msg, err := s.config.UnmarshalMessageFunc(url, r) - ctx, cancelCtx := context.WithCancel(ctx) - msg.SetContext(ctx) - defer cancelCtx() - if err != nil { s.logger.Info("Cannot unmarshal message", baseLogFields.Add(watermill.LogFields{"err": err})) w.WriteHeader(http.StatusBadRequest) @@ -130,6 +126,11 @@ func (s *Subscriber) Subscribe(ctx context.Context, url string) (<-chan *message w.WriteHeader(http.StatusBadRequest) return } + + ctx, cancelCtx := context.WithCancel(ctx) + msg.SetContext(ctx) + defer cancelCtx() + logFields := baseLogFields.Add(watermill.LogFields{"message_uuid": msg.UUID}) s.logger.Trace("Sending msg", logFields)