Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject payloads over the threshold set by server #378

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/NATS.Client.Core/Commands/CommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal sealed class CommandWriter : IAsyncDisposable
private const int MaxSendSize = 16384;

private readonly ILogger<CommandWriter> _logger;
private readonly NatsConnection _connection;
private readonly ObjectPool _pool;
private readonly int _arrayPoolInitialSize;
private readonly object _lock = new();
Expand All @@ -42,9 +43,10 @@ internal sealed class CommandWriter : IAsyncDisposable
private CancellationTokenSource? _ctsReader;
private volatile bool _disposed;

public CommandWriter(ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing, TimeSpan? overrideCommandTimeout = default)
public CommandWriter(NatsConnection connection, ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing, TimeSpan? overrideCommandTimeout = default)
{
_logger = opts.LoggerFactory.CreateLogger<CommandWriter>();
_connection = connection;
_pool = pool;

// Derive ArrayPool rent size from buffer size to
Expand Down Expand Up @@ -245,6 +247,12 @@ public ValueTask PublishAsync<T>(string subject, T? value, NatsHeaders? headers,
if (value != null)
serializer.Serialize(payloadBuffer, value);

var size = payloadBuffer.WrittenMemory.Length + (headersBuffer?.WrittenMemory.Length ?? 0);
if (_connection.ServerInfo is { } info && size > info.MaxPayload)
{
ThrowOnMaxPayload(size, info.MaxPayload);
}

return PublishLockedAsync(subject, replyTo, payloadBuffer, headersBuffer, cancellationToken);
}

Expand Down Expand Up @@ -309,6 +317,9 @@ public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken
// only used for internal testing
internal bool TestStallFlush() => _channelLock.Writer.TryWrite(1);

[MethodImpl(MethodImplOptions.NoInlining)]
private static void ThrowOnMaxPayload(int size, int max) => throw new NatsException($"Payload size {size} exceeds server's maximum payload size {max}");

private static async Task ReaderLoopAsync(ILogger<CommandWriter> logger, ISocketConnection connection, PipeReader pipeReader, Channel<int> channelSize, CancellationToken cancellationToken)
{
try
Expand Down
4 changes: 2 additions & 2 deletions src/NATS.Client.Core/Commands/PriorityCommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ internal sealed class PriorityCommandWriter : IAsyncDisposable
{
private int _disposed;

public PriorityCommandWriter(ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing)
public PriorityCommandWriter(NatsConnection connection, ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing)
{
CommandWriter = new CommandWriter(pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan);
CommandWriter = new CommandWriter(connection, pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan);
CommandWriter.Reset(socketConnection);
}

Expand Down
6 changes: 3 additions & 3 deletions src/NATS.Client.Core/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public partial class NatsConnection : INatsConnection
public Func<(string Host, int Port), ValueTask<(string Host, int Port)>>? OnConnectingAsync;

internal readonly ConnectionStatsCounter Counter; // allow to call from external sources
internal ServerInfo? WritableServerInfo;
internal volatile ServerInfo? WritableServerInfo;
internal bool IsDisposed;

#pragma warning restore SA1401
Expand Down Expand Up @@ -79,7 +79,7 @@ public NatsConnection(NatsOpts opts)
_cancellationTimerPool = new CancellationTimerPool(_pool, _disposedCancellationTokenSource.Token);
_name = opts.Name;
Counter = new ConnectionStatsCounter();
CommandWriter = new CommandWriter(_pool, Opts, Counter, EnqueuePing);
CommandWriter = new CommandWriter(this, _pool, Opts, Counter, EnqueuePing);
InboxPrefix = NewInbox(opts.InboxPrefix);
SubscriptionManager = new SubscriptionManager(this, InboxPrefix);
_logger = opts.LoggerFactory.CreateLogger<NatsConnection>();
Expand Down Expand Up @@ -431,7 +431,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
// Authentication
_userCredentials?.Authenticate(_clientOpts, WritableServerInfo);

await using (var priorityCommandWriter = new PriorityCommandWriter(_pool, _socket!, Opts, Counter, EnqueuePing))
await using (var priorityCommandWriter = new PriorityCommandWriter(this, _pool, _socket!, Opts, Counter, EnqueuePing))
{
// add CONNECT and PING command to priority lane
await priorityCommandWriter.CommandWriter.ConnectAsync(_clientOpts, CancellationToken.None).ConfigureAwait(false);
Expand Down
74 changes: 74 additions & 0 deletions tests/NATS.Client.Core.Tests/ProtocolTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,80 @@ public async Task Protocol_parser_under_load(int size)
counts.Count.Should().BeGreaterOrEqualTo(3);
}

[Fact]
public async Task Proactively_reject_payloads_over_the_threshold_set_by_server()
{
await using var server = NatsServer.Start();
await using var nats = server.CreateClientConnection();

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));

var sync = 0;
var count = 0;
var signal1 = new WaitSignal<NatsMsg<byte[]>>();
var signal2 = new WaitSignal<NatsMsg<byte[]>>();
var subTask = Task.Run(
async () =>
{
await foreach (var m in nats.SubscribeAsync<byte[]>("foo.*", cancellationToken: cts.Token))
{
if (m.Subject == "foo.sync")
{
Interlocked.Exchange(ref sync, 1);
continue;
}

Interlocked.Increment(ref count);

if (m.Subject == "foo.signal1")
{
signal1.Pulse(m);
}
else if (m.Subject == "foo.signal2")
{
signal2.Pulse(m);
}
else if (m.Subject == "foo.end")
{
break;
}
}
},
cancellationToken: cts.Token);

await Retry.Until(
reason: "subscription is active",
condition: () => Volatile.Read(ref sync) == 1,
action: async () => await nats.PublishAsync("foo.sync", cancellationToken: cts.Token),
retryDelay: TimeSpan.FromSeconds(.3));
{
var payload = new byte[nats.ServerInfo!.MaxPayload];
await nats.PublishAsync("foo.signal1", payload, cancellationToken: cts.Token);
var msg1 = await signal1;
Assert.Equal(payload.Length, msg1.Data!.Length);
}

{
var payload = new byte[nats.ServerInfo!.MaxPayload + 1];
var exception = await Assert.ThrowsAsync<NatsException>(async () =>
await nats.PublishAsync("foo.none", payload, cancellationToken: cts.Token));
Assert.Matches(@"Payload size \d+ exceeds server's maximum payload size \d+", exception.Message);
}

{
var payload = new byte[123];
await nats.PublishAsync("foo.signal2", payload, cancellationToken: cts.Token);
var msg1 = await signal2;
Assert.Equal(payload.Length, msg1.Data!.Length);
}

await nats.PublishAsync("foo.end", cancellationToken: cts.Token);

await subTask;

Assert.Equal(3, Volatile.Read(ref count));
}

private sealed class NatsSubReconnectTest : NatsSubBase
{
private readonly Action<int> _callback;
Expand Down
2 changes: 1 addition & 1 deletion tests/NATS.Client.TestUtilities/MockServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public MockServer(
var stream = client.GetStream();

var sw = new StreamWriter(stream, Encoding.ASCII);
await sw.WriteAsync("INFO {}\r\n");
await sw.WriteAsync("INFO {\"max_payload\":1048576}\r\n");
await sw.FlushAsync();

var sr = new StreamReader(stream, Encoding.ASCII);
Expand Down
Loading