From d1614eab64f1685aea5318ab8b1d4a28bba68701 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Mon, 2 Dec 2024 11:16:34 +0100 Subject: [PATCH 1/2] [FIXED] Race in Fetch when accessing the error Signed-off-by: Piotr Piotrowski --- jetstream/ordered.go | 13 +- jetstream/pull.go | 25 +++- jetstream/test/ordered_test.go | 216 ++++++++++++++++----------------- jetstream/test/pull_test.go | 85 +++++++++++++ 4 files changed, 217 insertions(+), 122 deletions(-) diff --git a/jetstream/ordered.go b/jetstream/ordered.go index 5fe656e9b..ed5ef6ac4 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -393,26 +393,26 @@ func (s *orderedSubscription) Closed() <-chan struct{} { // reset the consumer for each subsequent Fetch call. // Consider using [Consumer.Consume] or [Consumer.Messages] instead. func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { + c.Lock() if c.consumerType == consumerTypeConsume { + c.Unlock() return nil, ErrOrderConsumerUsedAsConsume } - c.currentConsumer.Lock() if c.runningFetch != nil { - if !c.runningFetch.done { - c.currentConsumer.Unlock() + if !c.runningFetch.closed() { return nil, ErrOrderedConsumerConcurrentRequests } if c.runningFetch.sseq != 0 { c.cursor.streamSeq = c.runningFetch.sseq } } - c.currentConsumer.Unlock() c.consumerType = consumerTypeFetch sub := orderedSubscription{ consumer: c, done: make(chan struct{}), } c.subscription = &sub + c.Unlock() err := c.reset() if err != nil { return nil, err @@ -433,11 +433,13 @@ func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, erro // reset the consumer for each subsequent Fetch call. // Consider using [Consumer.Consume] or [Consumer.Messages] instead. func (c *orderedConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) { + c.Lock() if c.consumerType == consumerTypeConsume { + c.Unlock() return nil, ErrOrderConsumerUsedAsConsume } if c.runningFetch != nil { - if !c.runningFetch.done { + if !c.runningFetch.closed() { return nil, ErrOrderedConsumerConcurrentRequests } if c.runningFetch.sseq != 0 { @@ -450,6 +452,7 @@ func (c *orderedConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBat done: make(chan struct{}), } c.subscription = &sub + c.Unlock() err := c.reset() if err != nil { return nil, err diff --git a/jetstream/pull.go b/jetstream/pull.go index 386968108..764bf2a1d 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -144,6 +144,7 @@ type ( } fetchResult struct { + sync.Mutex msgs chan Msg err error done bool @@ -780,7 +781,7 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { for { select { case msg := <-msgs: - p.Lock() + res.Lock() if hbTimer != nil { hbTimer.Reset(2 * req.Heartbeat) } @@ -791,11 +792,11 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { res.err = err } res.done = true - p.Unlock() + res.Unlock() return } if !userMsg { - p.Unlock() + res.Unlock() continue } res.msgs <- p.jetStream.toJSMsg(msg) @@ -810,16 +811,20 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { } if receivedMsgs == req.Batch || (req.MaxBytes != 0 && receivedBytes >= req.MaxBytes) { res.done = true - p.Unlock() + res.Unlock() return } - p.Unlock() + res.Unlock() case err := <-sub.errs: + res.Lock() res.err = err res.done = true + res.Unlock() return case <-time.After(req.Expires + 1*time.Second): + res.Lock() res.done = true + res.Unlock() return } } @@ -828,13 +833,23 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { } func (fr *fetchResult) Messages() <-chan Msg { + fr.Lock() + defer fr.Unlock() return fr.msgs } func (fr *fetchResult) Error() error { + fr.Lock() + defer fr.Unlock() return fr.err } +func (fr *fetchResult) closed() bool { + fr.Lock() + defer fr.Unlock() + return fr.done +} + // Next is used to retrieve the next message from the stream. This // method will block until the message is retrieved or timeout is // reached. diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index 5a6231b2d..a40867ac5 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -580,131 +580,123 @@ func TestOrderedConsumerConsume(t *testing.T) { }) t.Run("wait for closed after drain", func(t *testing.T) { - for i := 0; i < 10; i++ { - t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { - srv := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, srv) - nc, err := nats.Connect(srv.ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - js, err := jetstream.New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - msgs := make([]jetstream.Msg, 0) - lock := sync.Mutex{} - publishTestMsgs(t, js) - cc, err := c.Consume(func(msg jetstream.Msg) { - time.Sleep(50 * time.Millisecond) - msg.Ack() - lock.Lock() - msgs = append(msgs, msg) - lock.Unlock() - }) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - closed := cc.Closed() - time.Sleep(100 * time.Millisecond) - if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { - t.Fatalf("Unexpected error: %v", err) - } - publishTestMsgs(t, js) - - // wait for the consumer to be recreated before calling drain - for i := 0; i < 5; i++ { - _, err = c.Info(ctx) - if err != nil { - if errors.Is(err, jetstream.ErrConsumerNotFound) { - time.Sleep(100 * time.Millisecond) - continue - } - t.Fatalf("Unexpected error: %v", err) - } - break + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + closed := cc.Closed() + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, js) + + // wait for the consumer to be recreated before calling drain + for i := 0; i < 5; i++ { + _, err = c.Info(ctx) + if err != nil { + if errors.Is(err, jetstream.ErrConsumerNotFound) { + time.Sleep(100 * time.Millisecond) + continue } + t.Fatalf("Unexpected error: %v", err) + } + break + } - cc.Drain() + cc.Drain() - select { - case <-closed: - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for consume to be closed") - } + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } - if len(msgs) != 2*len(testMsgs) { - t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs)) - } - }) + if len(msgs) != 2*len(testMsgs) { + t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs)) } }) t.Run("wait for closed on already closed consume", func(t *testing.T) { - for i := 0; i < 10; i++ { - t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { - srv := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, srv) - nc, err := nats.Connect(srv.ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - js, err := jetstream.New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - msgs := make([]jetstream.Msg, 0) - lock := sync.Mutex{} - publishTestMsgs(t, js) - cc, err := c.Consume(func(msg jetstream.Msg) { - time.Sleep(50 * time.Millisecond) - msg.Ack() - lock.Lock() - msgs = append(msgs, msg) - lock.Unlock() - }) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - time.Sleep(100 * time.Millisecond) - if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { - t.Fatalf("Unexpected error: %v", err) - } + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } - cc.Stop() + cc.Stop() - time.Sleep(100 * time.Millisecond) + time.Sleep(100 * time.Millisecond) - select { - case <-cc.Closed(): - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for consume to be closed") - } - }) + select { + case <-cc.Closed(): + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") } }) } diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index 4042e52f5..dcca6d1c5 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -477,6 +477,91 @@ func TestPullConsumerFetch(t *testing.T) { }) } +func TestPullConsumerFetchRace(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + for i := 0; i < 3; i++ { + if _, err := js.Publish(context.Background(), "FOO.123", []byte(fmt.Sprintf("msg-%d", i))); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + msgs, err := c.Fetch(5) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + errCh := make(chan error) + go func() { + for { + err := msgs.Error() + if err != nil { + errCh <- err + return + } + } + }() + deleteErrCh := make(chan error, 1) + go func() { + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil { + deleteErrCh <- err + } + close(deleteErrCh) + }() + + var i int + for msg := range msgs.Messages() { + if string(msg.Data()) != fmt.Sprintf("msg-%d", i) { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, fmt.Sprintf("msg-%d", i), string(msg.Data())) + } + i++ + } + if i != 3 { + t.Fatalf("Invalid number of messages received; want: %d; got: %d", 5, i) + } + select { + case err := <-errCh: + if !errors.Is(err, jetstream.ErrConsumerDeleted) { + t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, nil) + } + + // wait until the consumer is deleted, otherwise we may close the connection + // before the consumer delete response is received + select { + case ert, ok := <-deleteErrCh: + if !ok { + break + } + t.Fatalf("Error deleting consumer: %s", ert) + case <-time.After(1 * time.Second): + t.Fatalf("Expected done to be closed") + } +} + func TestPullConsumerFetchBytes(t *testing.T) { testSubject := "FOO.123" msg := [10]byte{} From e58d428becedb70a9c0ab5f72deb7e00a8b66b11 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Mon, 2 Dec 2024 12:05:53 +0100 Subject: [PATCH 2/2] [FIXED] Race in FetchBatch in legacy API Signed-off-by: Piotr Piotrowski --- js.go | 20 +++++++++---- test/js_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/js.go b/js.go index e024fae0a..8c0e324b2 100644 --- a/js.go +++ b/js.go @@ -3114,20 +3114,27 @@ type MessageBatch interface { } type messageBatch struct { + sync.Mutex msgs chan *Msg err error done chan struct{} } func (mb *messageBatch) Messages() <-chan *Msg { + mb.Lock() + defer mb.Unlock() return mb.msgs } func (mb *messageBatch) Error() error { + mb.Lock() + defer mb.Unlock() return mb.err } func (mb *messageBatch) Done() <-chan struct{} { + mb.Lock() + defer mb.Unlock() return mb.done } @@ -3302,12 +3309,11 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e } var hbTimer *time.Timer var hbErr error - hbLock := sync.Mutex{} if o.hb > 0 { hbTimer = time.AfterFunc(2*o.hb, func() { - hbLock.Lock() + result.Lock() hbErr = ErrNoHeartbeat - hbLock.Unlock() + result.Unlock() cancel() }) } @@ -3338,21 +3344,25 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e break } if usrMsg { + result.Lock() result.msgs <- msg + result.Unlock() requestMsgs++ } } if err != nil { - hbLock.Lock() + result.Lock() if hbErr != nil { result.err = hbErr } else { result.err = o.checkCtxErr(err) } - hbLock.Unlock() + result.Unlock() } close(result.msgs) + result.Lock() result.done <- struct{}{} + result.Unlock() }() return result, nil } diff --git a/test/js_test.go b/test/js_test.go index db791eb50..91a259b31 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -10788,3 +10788,83 @@ func TestJetStreamTransform(t *testing.T) { } } + +func TestPullConsumerFetchRace(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + + nc, js := jsClient(t, srv) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + for i := 0; i < 3; i++ { + if _, err := js.Publish("FOO.123", []byte(fmt.Sprintf("msg-%d", i))); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + sub, err := js.PullSubscribe("FOO.123", "") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cons, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs, err := sub.FetchBatch(5) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + errCh := make(chan error) + go func() { + for { + err := msgs.Error() + if err != nil { + errCh <- err + return + } + } + }() + deleteErrCh := make(chan error, 1) + go func() { + time.Sleep(100 * time.Millisecond) + if err := js.DeleteConsumer("foo", cons.Name); err != nil { + deleteErrCh <- err + } + close(deleteErrCh) + }() + + var i int + for msg := range msgs.Messages() { + if string(msg.Data) != fmt.Sprintf("msg-%d", i) { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, fmt.Sprintf("msg-%d", i), string(msg.Data)) + } + i++ + } + if i != 3 { + t.Fatalf("Invalid number of messages received; want: %d; got: %d", 5, i) + } + select { + case err := <-errCh: + if !errors.Is(err, nats.ErrConsumerDeleted) { + t.Fatalf("Expected error: %v; got: %v", nats.ErrConsumerDeleted, err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Expected error: %v; got: %v", nats.ErrConsumerDeleted, nil) + } + + // wait until the consumer is deleted, otherwise we may close the connection + // before the consumer delete response is received + select { + case ert, ok := <-deleteErrCh: + if !ok { + break + } + t.Fatalf("Error deleting consumer: %s", ert) + case <-time.After(1 * time.Second): + t.Fatalf("Expected done to be closed") + } +}