From e58d428becedb70a9c0ab5f72deb7e00a8b66b11 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Mon, 2 Dec 2024 12:05:53 +0100 Subject: [PATCH] [FIXED] Race in FetchBatch in legacy API Signed-off-by: Piotr Piotrowski --- js.go | 20 +++++++++---- test/js_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/js.go b/js.go index e024fae0a..8c0e324b2 100644 --- a/js.go +++ b/js.go @@ -3114,20 +3114,27 @@ type MessageBatch interface { } type messageBatch struct { + sync.Mutex msgs chan *Msg err error done chan struct{} } func (mb *messageBatch) Messages() <-chan *Msg { + mb.Lock() + defer mb.Unlock() return mb.msgs } func (mb *messageBatch) Error() error { + mb.Lock() + defer mb.Unlock() return mb.err } func (mb *messageBatch) Done() <-chan struct{} { + mb.Lock() + defer mb.Unlock() return mb.done } @@ -3302,12 +3309,11 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e } var hbTimer *time.Timer var hbErr error - hbLock := sync.Mutex{} if o.hb > 0 { hbTimer = time.AfterFunc(2*o.hb, func() { - hbLock.Lock() + result.Lock() hbErr = ErrNoHeartbeat - hbLock.Unlock() + result.Unlock() cancel() }) } @@ -3338,21 +3344,25 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e break } if usrMsg { + result.Lock() result.msgs <- msg + result.Unlock() requestMsgs++ } } if err != nil { - hbLock.Lock() + result.Lock() if hbErr != nil { result.err = hbErr } else { result.err = o.checkCtxErr(err) } - hbLock.Unlock() + result.Unlock() } close(result.msgs) + result.Lock() result.done <- struct{}{} + result.Unlock() }() return result, nil } diff --git a/test/js_test.go b/test/js_test.go index db791eb50..91a259b31 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -10788,3 +10788,83 @@ func TestJetStreamTransform(t *testing.T) { } } + +func TestPullConsumerFetchRace(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + + nc, js := jsClient(t, srv) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + for i := 0; i < 3; i++ { + if _, err := js.Publish("FOO.123", []byte(fmt.Sprintf("msg-%d", i))); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + sub, err := js.PullSubscribe("FOO.123", "") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cons, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs, err := sub.FetchBatch(5) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + errCh := make(chan error) + go func() { + for { + err := msgs.Error() + if err != nil { + errCh <- err + return + } + } + }() + deleteErrCh := make(chan error, 1) + go func() { + time.Sleep(100 * time.Millisecond) + if err := js.DeleteConsumer("foo", cons.Name); err != nil { + deleteErrCh <- err + } + close(deleteErrCh) + }() + + var i int + for msg := range msgs.Messages() { + if string(msg.Data) != fmt.Sprintf("msg-%d", i) { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, fmt.Sprintf("msg-%d", i), string(msg.Data)) + } + i++ + } + if i != 3 { + t.Fatalf("Invalid number of messages received; want: %d; got: %d", 5, i) + } + select { + case err := <-errCh: + if !errors.Is(err, nats.ErrConsumerDeleted) { + t.Fatalf("Expected error: %v; got: %v", nats.ErrConsumerDeleted, err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Expected error: %v; got: %v", nats.ErrConsumerDeleted, nil) + } + + // wait until the consumer is deleted, otherwise we may close the connection + // before the consumer delete response is received + select { + case ert, ok := <-deleteErrCh: + if !ok { + break + } + t.Fatalf("Error deleting consumer: %s", ert) + case <-time.After(1 * time.Second): + t.Fatalf("Expected done to be closed") + } +}