Skip to content

Commit

Permalink
[FIXED] Race in FetchBatch in legacy API
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Dec 2, 2024
1 parent d1614ea commit e58d428
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
20 changes: 15 additions & 5 deletions js.go
Original file line number Diff line number Diff line change
Expand Up @@ -3114,20 +3114,27 @@ type MessageBatch interface {
}

type messageBatch struct {
sync.Mutex
msgs chan *Msg
err error
done chan struct{}
}

func (mb *messageBatch) Messages() <-chan *Msg {
mb.Lock()
defer mb.Unlock()
return mb.msgs
}

func (mb *messageBatch) Error() error {
mb.Lock()
defer mb.Unlock()
return mb.err
}

func (mb *messageBatch) Done() <-chan struct{} {
mb.Lock()
defer mb.Unlock()
return mb.done
}

Expand Down Expand Up @@ -3302,12 +3309,11 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
}
var hbTimer *time.Timer
var hbErr error
hbLock := sync.Mutex{}
if o.hb > 0 {
hbTimer = time.AfterFunc(2*o.hb, func() {
hbLock.Lock()
result.Lock()
hbErr = ErrNoHeartbeat
hbLock.Unlock()
result.Unlock()
cancel()
})
}
Expand Down Expand Up @@ -3338,21 +3344,25 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
break
}
if usrMsg {
result.Lock()
result.msgs <- msg
result.Unlock()
requestMsgs++
}
}
if err != nil {
hbLock.Lock()
result.Lock()
if hbErr != nil {
result.err = hbErr
} else {
result.err = o.checkCtxErr(err)
}
hbLock.Unlock()
result.Unlock()
}
close(result.msgs)
result.Lock()
result.done <- struct{}{}
result.Unlock()
}()
return result, nil
}
Expand Down
80 changes: 80 additions & 0 deletions test/js_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10788,3 +10788,83 @@ func TestJetStreamTransform(t *testing.T) {
}

}

func TestPullConsumerFetchRace(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

nc, js := jsClient(t, srv)
defer nc.Close()

_, err := js.AddStream(&nats.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

for i := 0; i < 3; i++ {
if _, err := js.Publish("FOO.123", []byte(fmt.Sprintf("msg-%d", i))); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
}
sub, err := js.PullSubscribe("FOO.123", "")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
cons, err := sub.ConsumerInfo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := sub.FetchBatch(5)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
errCh := make(chan error)
go func() {
for {
err := msgs.Error()
if err != nil {
errCh <- err
return
}
}
}()
deleteErrCh := make(chan error, 1)
go func() {
time.Sleep(100 * time.Millisecond)
if err := js.DeleteConsumer("foo", cons.Name); err != nil {
deleteErrCh <- err
}
close(deleteErrCh)
}()

var i int
for msg := range msgs.Messages() {
if string(msg.Data) != fmt.Sprintf("msg-%d", i) {
t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, fmt.Sprintf("msg-%d", i), string(msg.Data))
}
i++
}
if i != 3 {
t.Fatalf("Invalid number of messages received; want: %d; got: %d", 5, i)
}
select {
case err := <-errCh:
if !errors.Is(err, nats.ErrConsumerDeleted) {
t.Fatalf("Expected error: %v; got: %v", nats.ErrConsumerDeleted, err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Expected error: %v; got: %v", nats.ErrConsumerDeleted, nil)
}

// wait until the consumer is deleted, otherwise we may close the connection
// before the consumer delete response is received
select {
case ert, ok := <-deleteErrCh:
if !ok {
break
}
t.Fatalf("Error deleting consumer: %s", ert)
case <-time.After(1 * time.Second):
t.Fatalf("Expected done to be closed")
}
}

0 comments on commit e58d428

Please sign in to comment.