diff --git a/src/NATS.Client.JetStream/Internal/NatsJSConsume.cs b/src/NATS.Client.JetStream/Internal/NatsJSConsume.cs index c83f7e0b6..9b384587d 100644 --- a/src/NATS.Client.JetStream/Internal/NatsJSConsume.cs +++ b/src/NATS.Client.JetStream/Internal/NatsJSConsume.cs @@ -164,6 +164,24 @@ public ValueTask CallMsgNextAsync(string origin, ConsumerGetnextRequest request, public void ResetHeartbeatTimer() => _timer.Change(_hbTimeout, _hbTimeout); + public void Delivered(int msgSize) + { + lock (_pendingGate) + { + if (_pendingMsgs > 0) + _pendingMsgs--; + + if (_maxBytes > 0) + { + _pendingBytes -= msgSize; + if (_pendingBytes < 0) + _pendingBytes = 0; + } + + CheckPending(); + } + } + public override async ValueTask DisposeAsync() { Interlocked.Exchange(ref _disposed, 1); @@ -183,26 +201,48 @@ public override async ValueTask DisposeAsync() internal override async ValueTask WriteReconnectCommandsAsync(CommandWriter commandWriter, int sid) { await base.WriteReconnectCommandsAsync(commandWriter, sid); - ResetPending(); - - var request = new ConsumerGetnextRequest - { - Batch = _maxMsgs, - MaxBytes = _maxBytes, - IdleHeartbeat = _idle, - Expires = _expires, - }; if (_cancellationToken.IsCancellationRequested) return; - await commandWriter.PublishAsync( - subject: $"{_context.Opts.Prefix}.CONSUMER.MSG.NEXT.{_stream}.{_consumer}", - value: request, - headers: default, - replyTo: Subject, - serializer: NatsJSJsonSerializer.Default, - cancellationToken: CancellationToken.None); + long maxMsgs = 0; + long maxBytes = 0; + + // We have to do the pending check here because we can't access + // the publish method here since the connection state is not open yet + // and we're just writing the reconnect commands. + lock (_pendingGate) + { + if (_maxBytes > 0 && _pendingBytes <= _thresholdBytes) + { + maxBytes = _maxBytes - _pendingBytes; + } + else if (_maxBytes == 0 && _pendingMsgs <= _thresholdMsgs && _pendingMsgs < _maxMsgs) + { + maxMsgs = _maxMsgs - _pendingMsgs; + } + } + + if (maxMsgs > 0 || maxBytes > 0) + { + var request = new ConsumerGetnextRequest + { + Batch = maxMsgs, + MaxBytes = maxBytes, + IdleHeartbeat = _idle, + Expires = _expires, + }; + + await commandWriter.PublishAsync( + subject: $"{_context.Opts.Prefix}.CONSUMER.MSG.NEXT.{_stream}.{_consumer}", + value: request, + headers: default, + replyTo: Subject, + serializer: NatsJSJsonSerializer.Default, + cancellationToken: CancellationToken.None); + + ResetPending(); + } } protected override async ValueTask ReceiveInternalAsync( @@ -323,6 +363,8 @@ protected override async ValueTask ReceiveInternalAsync( { throw new NatsJSException("No header found"); } + + CheckPending(); } else { @@ -337,23 +379,6 @@ protected override async ValueTask ReceiveInternalAsync( _serializer), _context); - lock (_pendingGate) - { - if (_pendingMsgs > 0) - _pendingMsgs--; - } - - if (_maxBytes > 0) - { - if (_debug) - _logger.LogDebug(NatsJSLogEvents.MessageProperty, "Message size {Size}", msg.Size); - - lock (_pendingGate) - { - _pendingBytes -= msg.Size; - } - } - // Stop feeding the user if we are disposed. // We need to exit as soon as possible. if (Volatile.Read(ref _disposed) == 0) @@ -364,8 +389,6 @@ protected override async ValueTask ReceiveInternalAsync( await _userMsgs.Writer.WriteAsync(msg).ConfigureAwait(false); } } - - CheckPending(); } protected override void TryComplete() diff --git a/src/NATS.Client.JetStream/NatsJSConsumer.cs b/src/NATS.Client.JetStream/NatsJSConsumer.cs index e1c7b400d..de84b7f72 100644 --- a/src/NATS.Client.JetStream/NatsJSConsumer.cs +++ b/src/NATS.Client.JetStream/NatsJSConsumer.cs @@ -96,6 +96,7 @@ public async IAsyncEnumerable> ConsumeAsync( break; yield return jsMsg; + cc.Delivered(jsMsg.Size); } } } diff --git a/tests/NATS.Client.JetStream.Tests/ConsumerConsumeTest.cs b/tests/NATS.Client.JetStream.Tests/ConsumerConsumeTest.cs index d01d6831a..bcb464256 100644 --- a/tests/NATS.Client.JetStream.Tests/ConsumerConsumeTest.cs +++ b/tests/NATS.Client.JetStream.Tests/ConsumerConsumeTest.cs @@ -53,17 +53,12 @@ await Assert.ThrowsAnyAsync(async () => [Fact] public async Task Consume_msgs_test() { - var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); - - await using var server = NatsServer.Start( - outputHelper: _output, - opts: new NatsServerOptsBuilder() - .UseTransport(TransportType.Tcp) - .Trace() - .UseJetStream() - .Build()); + await using var server = NatsServer.StartJS(); var (nats, proxy) = server.CreateProxiedClientConnection(); var js = new NatsJSContext(nats); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + await js.CreateStreamAsync("s1", new[] { "s1.*" }, cts.Token); await js.CreateOrUpdateConsumerAsync("s1", "c1", cancellationToken: cts.Token); @@ -76,8 +71,7 @@ public async Task Consume_msgs_test() var consumerOpts = new NatsJSConsumeOpts { MaxMsgs = 10 }; var consumer = (NatsJSConsumer)await js.GetConsumerAsync("s1", "c1", cts.Token); var count = 0; - await using var cc = await consumer.ConsumeInternalAsync(serializer: TestDataJsonSerializer.Default, consumerOpts, cancellationToken: cts.Token); - await foreach (var msg in cc.Msgs.ReadAllAsync(cts.Token)) + await foreach (var msg in consumer.ConsumeAsync(serializer: TestDataJsonSerializer.Default, consumerOpts, cancellationToken: cts.Token)) { await msg.AckAsync(cancellationToken: cts.Token); Assert.Equal(count, msg.Data!.Test); @@ -92,7 +86,7 @@ public async Task Consume_msgs_test() await Retry.Until( reason: "received enough pulls", - condition: () => PullCount() > 5, + condition: () => PullCount() > 4, action: () => { _output.WriteLine($"### PullCount:{PullCount()}"); @@ -215,12 +209,10 @@ public async Task Consume_reconnect_test() // Not interested in management messages sent upto this point await proxy.FlushFramesAsync(nats); - var cc = await consumer.ConsumeInternalAsync(serializer: TestDataJsonSerializer.Default, consumerOpts, cancellationToken: cts.Token); - var readerTask = Task.Run(async () => { var count = 0; - await foreach (var msg in cc.Msgs.ReadAllAsync(cts.Token)) + await foreach (var msg in consumer.ConsumeAsync(serializer: TestDataJsonSerializer.Default, consumerOpts, cancellationToken: cts.Token)) { await msg.AckAsync(cancellationToken: cts.Token); Assert.Equal(count, msg.Data!.Test); @@ -230,6 +222,8 @@ public async Task Consume_reconnect_test() if (count == 2) break; } + + return count; }); // Send a message before reconnect @@ -258,11 +252,9 @@ await Retry.Until( ack.EnsureSuccess(); } - await Retry.Until( - "acked", - () => proxy.ClientFrames.Any(f => f.Message.Contains("CONSUMER.MSG.NEXT"))); + var count = await readerTask; + Assert.Equal(2, count); - await readerTask; await nats.DisposeAsync(); } @@ -446,4 +438,97 @@ public async Task Serialization_errors() break; } } + + [Fact] + public async Task Consume_right_amount_of_messages() + { + await using var server = NatsServer.StartJS(); + await using var nats = server.CreateClientConnection(); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + var js = new NatsJSContext(nats); + await js.CreateStreamAsync("s1", ["s1.*"], cts.Token); + + var payload = new byte[1024]; + for (var i = 0; i < 50; i++) + { + var ack = await js.PublishAsync("s1.foo", payload, cancellationToken: cts.Token); + ack.EnsureSuccess(); + } + + // Max messages + { + var consumer = await js.CreateOrUpdateConsumerAsync("s1", "c1", cancellationToken: cts.Token); + var opts = new NatsJSConsumeOpts { MaxMsgs = 10, }; + var count = 0; + await foreach (var msg in consumer.ConsumeAsync(opts: opts, cancellationToken: cts.Token)) + { + await msg.AckAsync(cancellationToken: cts.Token); + if (++count == 4) + break; + } + + await Retry.Until("consumer stats updated", async () => + { + var info = (await js.GetConsumerAsync("s1", "c1", cts.Token)).Info; + return info is { NumAckPending: 6, NumPending: 40 }; + }); + } + + // Max bytes + { + var consumer = await js.CreateOrUpdateConsumerAsync("s1", "c2", cancellationToken: cts.Token); + var opts = new NatsJSConsumeOpts { MaxBytes = 10 * (1024 + 50), }; + var count = 0; + await foreach (var msg in consumer.ConsumeAsync(opts: opts, cancellationToken: cts.Token)) + { + await msg.AckAsync(cancellationToken: cts.Token); + if (++count == 4) + break; + } + + await Retry.Until("consumer stats updated", async () => + { + var info = (await js.GetConsumerAsync("s1", "c2", cts.Token)).Info; + return info is { NumAckPending: 6, NumPending: 40 }; + }); + } + } + + [Fact] + public async Task Consume_right_amount_of_messages_when_ack_wait_exceeded() + { + await using var server = NatsServer.StartJS(); + await using var nats = server.CreateClientConnection(); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20)); + + var js = new NatsJSContext(nats); + await js.CreateStreamAsync("email-queue", ["email.>"], cts.Token); + await js.PublishAsync("email.queue", "1", cancellationToken: cts.Token); + await js.PublishAsync("email.queue", "2", cancellationToken: cts.Token); + var consumer = await js.CreateOrUpdateConsumerAsync( + stream: "email-queue", + new ConsumerConfig("email-queue-consumer") { AckWait = TimeSpan.FromSeconds(10) }, + cancellationToken: cts.Token); + var count = 0; + await foreach (var msg in consumer.ConsumeAsync(opts: new NatsJSConsumeOpts { MaxMsgs = 1 }, cancellationToken: cts.Token)) + { + _output.WriteLine($"Received: {msg.Data}"); + + // Only wait for the first couple of messages + // to get close to the ack wait time + if (count < 2) + await Task.Delay(TimeSpan.FromSeconds(6), cts.Token); + + // Since we're pulling one message at a time, + // we should not exceed the ack wait time + await msg.AckAsync(cancellationToken: cts.Token); + count++; + } + + // Should not have redeliveries + Assert.Equal(2, count); + } }