Skip to content

Commit

Permalink
Fixed consume pending message calculation (#626)
Browse files Browse the repository at this point in the history
* Fixed consume pending message calculation

Implemented a new method `Delivered` to accurately adjust the pending message
size on message delivery. Ensured pending checks and size adjustments are
correctly handled during reconnect and message delivery to maintain a
consistent consumer state. Also updated tests to validate these changes.

* dotnet format
  • Loading branch information
mtmk authored Sep 11, 2024
1 parent dc5fa62 commit 7b6af12
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 54 deletions.
93 changes: 58 additions & 35 deletions src/NATS.Client.JetStream/Internal/NatsJSConsume.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ public ValueTask CallMsgNextAsync(string origin, ConsumerGetnextRequest request,

public void ResetHeartbeatTimer() => _timer.Change(_hbTimeout, _hbTimeout);

public void Delivered(int msgSize)
{
lock (_pendingGate)
{
if (_pendingMsgs > 0)
_pendingMsgs--;

if (_maxBytes > 0)
{
_pendingBytes -= msgSize;
if (_pendingBytes < 0)
_pendingBytes = 0;
}

CheckPending();
}
}

public override async ValueTask DisposeAsync()
{
Interlocked.Exchange(ref _disposed, 1);
Expand All @@ -183,26 +201,48 @@ public override async ValueTask DisposeAsync()
internal override async ValueTask WriteReconnectCommandsAsync(CommandWriter commandWriter, int sid)
{
await base.WriteReconnectCommandsAsync(commandWriter, sid);
ResetPending();

var request = new ConsumerGetnextRequest
{
Batch = _maxMsgs,
MaxBytes = _maxBytes,
IdleHeartbeat = _idle,
Expires = _expires,
};

if (_cancellationToken.IsCancellationRequested)
return;

await commandWriter.PublishAsync(
subject: $"{_context.Opts.Prefix}.CONSUMER.MSG.NEXT.{_stream}.{_consumer}",
value: request,
headers: default,
replyTo: Subject,
serializer: NatsJSJsonSerializer<ConsumerGetnextRequest>.Default,
cancellationToken: CancellationToken.None);
long maxMsgs = 0;
long maxBytes = 0;

// We have to do the pending check here because we can't access
// the publish method here since the connection state is not open yet
// and we're just writing the reconnect commands.
lock (_pendingGate)
{
if (_maxBytes > 0 && _pendingBytes <= _thresholdBytes)
{
maxBytes = _maxBytes - _pendingBytes;
}
else if (_maxBytes == 0 && _pendingMsgs <= _thresholdMsgs && _pendingMsgs < _maxMsgs)
{
maxMsgs = _maxMsgs - _pendingMsgs;
}
}

if (maxMsgs > 0 || maxBytes > 0)
{
var request = new ConsumerGetnextRequest
{
Batch = maxMsgs,
MaxBytes = maxBytes,
IdleHeartbeat = _idle,
Expires = _expires,
};

await commandWriter.PublishAsync(
subject: $"{_context.Opts.Prefix}.CONSUMER.MSG.NEXT.{_stream}.{_consumer}",
value: request,
headers: default,
replyTo: Subject,
serializer: NatsJSJsonSerializer<ConsumerGetnextRequest>.Default,
cancellationToken: CancellationToken.None);

ResetPending();
}
}

protected override async ValueTask ReceiveInternalAsync(
Expand Down Expand Up @@ -323,6 +363,8 @@ protected override async ValueTask ReceiveInternalAsync(
{
throw new NatsJSException("No header found");
}

CheckPending();
}
else
{
Expand All @@ -337,23 +379,6 @@ protected override async ValueTask ReceiveInternalAsync(
_serializer),
_context);

lock (_pendingGate)
{
if (_pendingMsgs > 0)
_pendingMsgs--;
}

if (_maxBytes > 0)
{
if (_debug)
_logger.LogDebug(NatsJSLogEvents.MessageProperty, "Message size {Size}", msg.Size);

lock (_pendingGate)
{
_pendingBytes -= msg.Size;
}
}

// Stop feeding the user if we are disposed.
// We need to exit as soon as possible.
if (Volatile.Read(ref _disposed) == 0)
Expand All @@ -364,8 +389,6 @@ protected override async ValueTask ReceiveInternalAsync(
await _userMsgs.Writer.WriteAsync(msg).ConfigureAwait(false);
}
}

CheckPending();
}

protected override void TryComplete()
Expand Down
1 change: 1 addition & 0 deletions src/NATS.Client.JetStream/NatsJSConsumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public async IAsyncEnumerable<NatsJSMsg<T>> ConsumeAsync<T>(
break;

yield return jsMsg;
cc.Delivered(jsMsg.Size);
}
}
}
Expand Down
123 changes: 104 additions & 19 deletions tests/NATS.Client.JetStream.Tests/ConsumerConsumeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,12 @@ await Assert.ThrowsAnyAsync<ArgumentException>(async () =>
[Fact]
public async Task Consume_msgs_test()
{
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

await using var server = NatsServer.Start(
outputHelper: _output,
opts: new NatsServerOptsBuilder()
.UseTransport(TransportType.Tcp)
.Trace()
.UseJetStream()
.Build());
await using var server = NatsServer.StartJS();
var (nats, proxy) = server.CreateProxiedClientConnection();
var js = new NatsJSContext(nats);

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

await js.CreateStreamAsync("s1", new[] { "s1.*" }, cts.Token);
await js.CreateOrUpdateConsumerAsync("s1", "c1", cancellationToken: cts.Token);

Expand All @@ -76,8 +71,7 @@ public async Task Consume_msgs_test()
var consumerOpts = new NatsJSConsumeOpts { MaxMsgs = 10 };
var consumer = (NatsJSConsumer)await js.GetConsumerAsync("s1", "c1", cts.Token);
var count = 0;
await using var cc = await consumer.ConsumeInternalAsync<TestData>(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token);
await foreach (var msg in cc.Msgs.ReadAllAsync(cts.Token))
await foreach (var msg in consumer.ConsumeAsync(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
Assert.Equal(count, msg.Data!.Test);
Expand All @@ -92,7 +86,7 @@ public async Task Consume_msgs_test()

await Retry.Until(
reason: "received enough pulls",
condition: () => PullCount() > 5,
condition: () => PullCount() > 4,
action: () =>
{
_output.WriteLine($"### PullCount:{PullCount()}");
Expand Down Expand Up @@ -215,12 +209,10 @@ public async Task Consume_reconnect_test()
// Not interested in management messages sent upto this point
await proxy.FlushFramesAsync(nats);

var cc = await consumer.ConsumeInternalAsync<TestData>(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token);

var readerTask = Task.Run(async () =>
{
var count = 0;
await foreach (var msg in cc.Msgs.ReadAllAsync(cts.Token))
await foreach (var msg in consumer.ConsumeAsync<TestData>(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
Assert.Equal(count, msg.Data!.Test);
Expand All @@ -230,6 +222,8 @@ public async Task Consume_reconnect_test()
if (count == 2)
break;
}

return count;
});

// Send a message before reconnect
Expand Down Expand Up @@ -258,11 +252,9 @@ await Retry.Until(
ack.EnsureSuccess();
}

await Retry.Until(
"acked",
() => proxy.ClientFrames.Any(f => f.Message.Contains("CONSUMER.MSG.NEXT")));
var count = await readerTask;
Assert.Equal(2, count);

await readerTask;
await nats.DisposeAsync();
}

Expand Down Expand Up @@ -446,4 +438,97 @@ public async Task Serialization_errors()
break;
}
}

[Fact]
public async Task Consume_right_amount_of_messages()
{
await using var server = NatsServer.StartJS();
await using var nats = server.CreateClientConnection();

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

var js = new NatsJSContext(nats);
await js.CreateStreamAsync("s1", ["s1.*"], cts.Token);

var payload = new byte[1024];
for (var i = 0; i < 50; i++)
{
var ack = await js.PublishAsync("s1.foo", payload, cancellationToken: cts.Token);
ack.EnsureSuccess();
}

// Max messages
{
var consumer = await js.CreateOrUpdateConsumerAsync("s1", "c1", cancellationToken: cts.Token);
var opts = new NatsJSConsumeOpts { MaxMsgs = 10, };
var count = 0;
await foreach (var msg in consumer.ConsumeAsync<byte[]>(opts: opts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
if (++count == 4)
break;
}

await Retry.Until("consumer stats updated", async () =>
{
var info = (await js.GetConsumerAsync("s1", "c1", cts.Token)).Info;
return info is { NumAckPending: 6, NumPending: 40 };
});
}

// Max bytes
{
var consumer = await js.CreateOrUpdateConsumerAsync("s1", "c2", cancellationToken: cts.Token);
var opts = new NatsJSConsumeOpts { MaxBytes = 10 * (1024 + 50), };
var count = 0;
await foreach (var msg in consumer.ConsumeAsync<byte[]>(opts: opts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
if (++count == 4)
break;
}

await Retry.Until("consumer stats updated", async () =>
{
var info = (await js.GetConsumerAsync("s1", "c2", cts.Token)).Info;
return info is { NumAckPending: 6, NumPending: 40 };
});
}
}

[Fact]
public async Task Consume_right_amount_of_messages_when_ack_wait_exceeded()
{
await using var server = NatsServer.StartJS();
await using var nats = server.CreateClientConnection();

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));

var js = new NatsJSContext(nats);
await js.CreateStreamAsync("email-queue", ["email.>"], cts.Token);
await js.PublishAsync("email.queue", "1", cancellationToken: cts.Token);
await js.PublishAsync("email.queue", "2", cancellationToken: cts.Token);
var consumer = await js.CreateOrUpdateConsumerAsync(
stream: "email-queue",
new ConsumerConfig("email-queue-consumer") { AckWait = TimeSpan.FromSeconds(10) },
cancellationToken: cts.Token);
var count = 0;
await foreach (var msg in consumer.ConsumeAsync<string>(opts: new NatsJSConsumeOpts { MaxMsgs = 1 }, cancellationToken: cts.Token))
{
_output.WriteLine($"Received: {msg.Data}");

// Only wait for the first couple of messages
// to get close to the ack wait time
if (count < 2)
await Task.Delay(TimeSpan.FromSeconds(6), cts.Token);

// Since we're pulling one message at a time,
// we should not exceed the ack wait time
await msg.AckAsync(cancellationToken: cts.Token);
count++;
}

// Should not have redeliveries
Assert.Equal(2, count);
}
}

0 comments on commit 7b6af12

Please sign in to comment.