From 4ab46d2a142e16d135d5e0445ef506e84168b3c3 Mon Sep 17 00:00:00 2001 From: Ziya Suzen Date: Wed, 7 Feb 2024 14:42:49 +0000 Subject: [PATCH] Reject payloads over the threshold set by server (#378) * Reject payloads over the threshold set by server * format and test fix --- .../Commands/CommandWriter.cs | 13 +++- .../Commands/PriorityCommandWriter.cs | 4 +- src/NATS.Client.Core/NatsConnection.cs | 6 +- tests/NATS.Client.Core.Tests/ProtocolTest.cs | 74 +++++++++++++++++++ tests/NATS.Client.TestUtilities/MockServer.cs | 2 +- 5 files changed, 92 insertions(+), 7 deletions(-) diff --git a/src/NATS.Client.Core/Commands/CommandWriter.cs b/src/NATS.Client.Core/Commands/CommandWriter.cs index 31cc150d1..158eb874b 100644 --- a/src/NATS.Client.Core/Commands/CommandWriter.cs +++ b/src/NATS.Client.Core/Commands/CommandWriter.cs @@ -22,6 +22,7 @@ internal sealed class CommandWriter : IAsyncDisposable private const int MaxSendSize = 16384; private readonly ILogger _logger; + private readonly NatsConnection _connection; private readonly ObjectPool _pool; private readonly int _arrayPoolInitialSize; private readonly object _lock = new(); @@ -42,9 +43,10 @@ internal sealed class CommandWriter : IAsyncDisposable private CancellationTokenSource? _ctsReader; private volatile bool _disposed; - public CommandWriter(ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing, TimeSpan? overrideCommandTimeout = default) + public CommandWriter(NatsConnection connection, ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing, TimeSpan? overrideCommandTimeout = default) { _logger = opts.LoggerFactory.CreateLogger(); + _connection = connection; _pool = pool; // Derive ArrayPool rent size from buffer size to @@ -245,6 +247,12 @@ public ValueTask PublishAsync(string subject, T? value, NatsHeaders? headers, if (value != null) serializer.Serialize(payloadBuffer, value); + var size = payloadBuffer.WrittenMemory.Length + (headersBuffer?.WrittenMemory.Length ?? 0); + if (_connection.ServerInfo is { } info && size > info.MaxPayload) + { + ThrowOnMaxPayload(size, info.MaxPayload); + } + return PublishLockedAsync(subject, replyTo, payloadBuffer, headersBuffer, cancellationToken); } @@ -309,6 +317,9 @@ public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken // only used for internal testing internal bool TestStallFlush() => _channelLock.Writer.TryWrite(1); + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowOnMaxPayload(int size, int max) => throw new NatsException($"Payload size {size} exceeds server's maximum payload size {max}"); + private static async Task ReaderLoopAsync(ILogger logger, ISocketConnection connection, PipeReader pipeReader, Channel channelSize, CancellationToken cancellationToken) { try diff --git a/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs b/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs index 3020e84f4..f6c70427f 100644 --- a/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs +++ b/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs @@ -6,9 +6,9 @@ internal sealed class PriorityCommandWriter : IAsyncDisposable { private int _disposed; - public PriorityCommandWriter(ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing) + public PriorityCommandWriter(NatsConnection connection, ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing) { - CommandWriter = new CommandWriter(pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan); + CommandWriter = new CommandWriter(connection, pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan); CommandWriter.Reset(socketConnection); } diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs index 0a285d285..4ccc8b7d1 100644 --- a/src/NATS.Client.Core/NatsConnection.cs +++ b/src/NATS.Client.Core/NatsConnection.cs @@ -32,7 +32,7 @@ public partial class NatsConnection : INatsConnection public Func<(string Host, int Port), ValueTask<(string Host, int Port)>>? OnConnectingAsync; internal readonly ConnectionStatsCounter Counter; // allow to call from external sources - internal ServerInfo? WritableServerInfo; + internal volatile ServerInfo? WritableServerInfo; internal bool IsDisposed; #pragma warning restore SA1401 @@ -79,7 +79,7 @@ public NatsConnection(NatsOpts opts) _cancellationTimerPool = new CancellationTimerPool(_pool, _disposedCancellationTokenSource.Token); _name = opts.Name; Counter = new ConnectionStatsCounter(); - CommandWriter = new CommandWriter(_pool, Opts, Counter, EnqueuePing); + CommandWriter = new CommandWriter(this, _pool, Opts, Counter, EnqueuePing); InboxPrefix = NewInbox(opts.InboxPrefix); SubscriptionManager = new SubscriptionManager(this, InboxPrefix); _logger = opts.LoggerFactory.CreateLogger(); @@ -431,7 +431,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect) // Authentication _userCredentials?.Authenticate(_clientOpts, WritableServerInfo); - await using (var priorityCommandWriter = new PriorityCommandWriter(_pool, _socket!, Opts, Counter, EnqueuePing)) + await using (var priorityCommandWriter = new PriorityCommandWriter(this, _pool, _socket!, Opts, Counter, EnqueuePing)) { // add CONNECT and PING command to priority lane await priorityCommandWriter.CommandWriter.ConnectAsync(_clientOpts, CancellationToken.None).ConfigureAwait(false); diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs index 6c7228c35..27202246d 100644 --- a/tests/NATS.Client.Core.Tests/ProtocolTest.cs +++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs @@ -415,6 +415,80 @@ public async Task Protocol_parser_under_load(int size) counts.Count.Should().BeGreaterOrEqualTo(3); } + [Fact] + public async Task Proactively_reject_payloads_over_the_threshold_set_by_server() + { + await using var server = NatsServer.Start(); + await using var nats = server.CreateClientConnection(); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + var sync = 0; + var count = 0; + var signal1 = new WaitSignal>(); + var signal2 = new WaitSignal>(); + var subTask = Task.Run( + async () => + { + await foreach (var m in nats.SubscribeAsync("foo.*", cancellationToken: cts.Token)) + { + if (m.Subject == "foo.sync") + { + Interlocked.Exchange(ref sync, 1); + continue; + } + + Interlocked.Increment(ref count); + + if (m.Subject == "foo.signal1") + { + signal1.Pulse(m); + } + else if (m.Subject == "foo.signal2") + { + signal2.Pulse(m); + } + else if (m.Subject == "foo.end") + { + break; + } + } + }, + cancellationToken: cts.Token); + + await Retry.Until( + reason: "subscription is active", + condition: () => Volatile.Read(ref sync) == 1, + action: async () => await nats.PublishAsync("foo.sync", cancellationToken: cts.Token), + retryDelay: TimeSpan.FromSeconds(.3)); + { + var payload = new byte[nats.ServerInfo!.MaxPayload]; + await nats.PublishAsync("foo.signal1", payload, cancellationToken: cts.Token); + var msg1 = await signal1; + Assert.Equal(payload.Length, msg1.Data!.Length); + } + + { + var payload = new byte[nats.ServerInfo!.MaxPayload + 1]; + var exception = await Assert.ThrowsAsync(async () => + await nats.PublishAsync("foo.none", payload, cancellationToken: cts.Token)); + Assert.Matches(@"Payload size \d+ exceeds server's maximum payload size \d+", exception.Message); + } + + { + var payload = new byte[123]; + await nats.PublishAsync("foo.signal2", payload, cancellationToken: cts.Token); + var msg1 = await signal2; + Assert.Equal(payload.Length, msg1.Data!.Length); + } + + await nats.PublishAsync("foo.end", cancellationToken: cts.Token); + + await subTask; + + Assert.Equal(3, Volatile.Read(ref count)); + } + private sealed class NatsSubReconnectTest : NatsSubBase { private readonly Action _callback; diff --git a/tests/NATS.Client.TestUtilities/MockServer.cs b/tests/NATS.Client.TestUtilities/MockServer.cs index d9a0f273d..5ebfa186f 100644 --- a/tests/NATS.Client.TestUtilities/MockServer.cs +++ b/tests/NATS.Client.TestUtilities/MockServer.cs @@ -29,7 +29,7 @@ public MockServer( var stream = client.GetStream(); var sw = new StreamWriter(stream, Encoding.ASCII); - await sw.WriteAsync("INFO {}\r\n"); + await sw.WriteAsync("INFO {\"max_payload\":1048576}\r\n"); await sw.FlushAsync(); var sr = new StreamReader(stream, Encoding.ASCII);