diff --git a/src/GraphQL.AspNetCore3/WebSockets/AsyncMessagePump.cs b/src/GraphQL.AspNetCore3/WebSockets/AsyncMessagePump.cs index 681b312..f0e0779 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/AsyncMessagePump.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/AsyncMessagePump.cs @@ -37,7 +37,8 @@ public AsyncMessagePump(Action callback) { if (callback == null) throw new ArgumentNullException(nameof(callback)); - _callback = message => { + _callback = message => + { callback(message); return Task.CompletedTask; }; @@ -46,8 +47,7 @@ public AsyncMessagePump(Action callback) /// /// Posts the specified message to the message queue. /// - public void Post(T message) - => Post(new ValueTask(message)); + public void Post(T message) => Post(new ValueTask(message)); /// /// Posts the result of an asynchronous operation to the message queue. @@ -55,16 +55,33 @@ public void Post(T message) public void Post(ValueTask messageTask) { bool attach = false; - lock (_queue) { + lock (_queue) + { _queue.Enqueue(messageTask); attach = _queue.Count == 1; } - if (attach) { + if (attach) + { CompleteAsync(); } } + /// + /// Returns the number of messages waiting in the queue. + /// Includes the message currently being processed, if any. + /// + public int Count + { + get + { + lock (_queue) + { + return _queue.Count; + } + } + } + /// /// Processes message in the queue until it is empty. /// @@ -72,25 +89,33 @@ private async void CompleteAsync() { // grab the message at the start of the queue, but don't remove it from the queue ValueTask messageTask; - lock (_queue) { + lock (_queue) + { // should always successfully peek from the queue here #pragma warning disable CA2012 // Use ValueTasks correctly messageTask = _queue.Peek(); #pragma warning restore CA2012 // Use ValueTasks correctly } - while (true) { + while (true) + { // process the message - try { + try + { var message = await messageTask.ConfigureAwait(false); await _callback(message).ConfigureAwait(false); - } catch (Exception ex) { - try { + } + catch (Exception ex) + { + try + { await HandleErrorAsync(ex); - } catch { } + } + catch { } } // once the message has been passed along, dequeue it - lock (_queue) { + lock (_queue) + { #pragma warning disable CA2012 // Use ValueTasks correctly _ = _queue.Dequeue(); #pragma warning restore CA2012 // Use ValueTasks correctly @@ -105,6 +130,5 @@ private async void CompleteAsync() /// /// Handles exceptions that occur within the asynchronous message delegate or the callback. /// - protected virtual Task HandleErrorAsync(Exception exception) - => Task.CompletedTask; + protected virtual Task HandleErrorAsync(Exception exception) => Task.CompletedTask; } diff --git a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs index a671bd4..83d0902 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs @@ -38,4 +38,11 @@ public class GraphQLWebSocketOptions /// Disconnects a subscription from the client there are any GraphQL errors during a subscription. /// public bool DisconnectAfterAnyError { get; set; } + + /// + /// To help prevent backpressure from slower internet speeds, this will prevent the queue from expanding + /// beyond the max length. + /// The default is null (no limit). Value must be greater than 0. + /// + public int? MaxSendQueueThreshold { get; set; } } diff --git a/src/GraphQL.AspNetCore3/WebSockets/IWebSocketConnection.cs b/src/GraphQL.AspNetCore3/WebSockets/IWebSocketConnection.cs index ae95e0d..a85406d 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/IWebSocketConnection.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/IWebSocketConnection.cs @@ -22,6 +22,11 @@ public interface IWebSocketConnection : IDisposable /// Task SendMessageAsync(OperationMessage message); + /// + /// Sends a message. Option to ignoreMaxSendQueueThreshold and force a message. + /// + Task SendMessageAsync(OperationMessage message, bool ignoreMaxSendQueueThreshold = false); + /// /// Closes the WebSocket connection, and /// prevents further incoming messages from being dispatched through . diff --git a/src/GraphQL.AspNetCore3/WebSockets/WebSocketConnection.cs b/src/GraphQL.AspNetCore3/WebSockets/WebSocketConnection.cs index 2885bbb..5cbd397 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/WebSocketConnection.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/WebSocketConnection.cs @@ -23,6 +23,7 @@ public class WebSocketConnection : IWebSocketConnection private readonly WebSocketWriterStream _stream; private readonly TaskCompletionSource _outputClosed = new(); private readonly int _closeTimeoutMs; + private readonly int? _maxSendQueueThreshold; private volatile bool _closeRequested; private int _executed; @@ -44,18 +45,39 @@ public class WebSocketConnection : IWebSocketConnection /// /// Initializes an instance with the specified parameters. /// - public WebSocketConnection(HttpContext httpContext, WebSocket webSocket, IGraphQLSerializer serializer, GraphQLWebSocketOptions options, CancellationToken cancellationToken) + public WebSocketConnection( + HttpContext httpContext, + WebSocket webSocket, + IGraphQLSerializer serializer, + GraphQLWebSocketOptions options, + CancellationToken cancellationToken + ) { HttpContext = httpContext ?? throw new ArgumentNullException(nameof(httpContext)); if (options == null) throw new ArgumentNullException(nameof(options)); - if (options.DisconnectionTimeout.HasValue) { - if ((options.DisconnectionTimeout.Value != Timeout.InfiniteTimeSpan && options.DisconnectionTimeout.Value.TotalMilliseconds < 0) || options.DisconnectionTimeout.Value.TotalMilliseconds > int.MaxValue) + if (options.DisconnectionTimeout.HasValue) + { + if ( + ( + options.DisconnectionTimeout.Value != Timeout.InfiniteTimeSpan + && options.DisconnectionTimeout.Value.TotalMilliseconds < 0 + ) + || options.DisconnectionTimeout.Value.TotalMilliseconds > int.MaxValue + ) #pragma warning disable CA2208 // Instantiate argument exceptions correctly - throw new ArgumentOutOfRangeException(nameof(options) + "." + nameof(GraphQLWebSocketOptions.DisconnectionTimeout)); + throw new ArgumentOutOfRangeException( + nameof(options) + "." + nameof(GraphQLWebSocketOptions.DisconnectionTimeout) + ); #pragma warning restore CA2208 // Instantiate argument exceptions correctly } - _closeTimeoutMs = (int)(options.DisconnectionTimeout ?? DefaultDisconnectionTimeout).TotalMilliseconds; + _closeTimeoutMs = (int) + (options.DisconnectionTimeout ?? DefaultDisconnectionTimeout).TotalMilliseconds; + _maxSendQueueThreshold = options.MaxSendQueueThreshold; + if (_maxSendQueueThreshold != null && _maxSendQueueThreshold <= 0) + throw new ArgumentOutOfRangeException( + nameof(options) + "." + nameof(GraphQLWebSocketOptions.MaxSendQueueThreshold) + ); _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); _stream = new(webSocket); _serializer = serializer ?? throw new ArgumentNullException(nameof(serializer)); @@ -64,8 +86,7 @@ public WebSocketConnection(HttpContext httpContext, WebSocket webSocket, IGraphQ } /// - public virtual void Dispose() - => GC.SuppressFinalize(this); + public virtual void Dispose() => GC.SuppressFinalize(this); /// /// Listens to incoming messages on the WebSocket specified in the constructor, @@ -77,8 +98,11 @@ public virtual async Task ExecuteAsync(IOperationMessageProcessor operationMessa if (operationMessageProcessor == null) throw new ArgumentNullException(nameof(operationMessageProcessor)); if (Interlocked.Exchange(ref _executed, 1) == 1) - throw new InvalidOperationException($"{nameof(ExecuteAsync)} may only be called once per instance."); - try { + throw new InvalidOperationException( + $"{nameof(ExecuteAsync)} may only be called once per instance." + ); + try + { await operationMessageProcessor.InitializeConnectionAsync(); // set up a buffer in case a message is longer than one block var receiveStream = new MemoryStream(); @@ -93,19 +117,23 @@ public virtual async Task ExecuteAsync(IOperationMessageProcessor operationMessa // prep a reader stream var bufferStream = new ReusableMemoryReaderStream(buffer); // read messages until an exception occurs, the cancellation token is signaled, or a 'close' message is received - while (!RequestAborted.IsCancellationRequested) { + while (!RequestAborted.IsCancellationRequested) + { var result = await _webSocket.ReceiveAsync(bufferMemory, RequestAborted); - if (result.MessageType == WebSocketMessageType.Close) { + if (result.MessageType == WebSocketMessageType.Close) + { // prevent any more messages from being queued operationMessageProcessor.Dispose(); // send a close request if none was sent yet - if (!_outputClosed.Task.IsCompleted) { + if (!_outputClosed.Task.IsCompleted) + { // queue the closure _ = CloseAsync(); // wait until the close has been sent await Task.WhenAny( _outputClosed.Task, - Task.Delay(_closeTimeoutMs, RequestAborted)); + Task.Delay(_closeTimeoutMs, RequestAborted) + ); } // quit return; @@ -114,39 +142,53 @@ await Task.WhenAny( if (_closeRequested) continue; // if this is the last block terminating a message - if (result.EndOfMessage) { + if (result.EndOfMessage) + { // if only one block of data was sent for this message - if (receiveStream.Length == 0) { + if (receiveStream.Length == 0) + { // if the message is empty, skip to the next message if (result.Count == 0) continue; // read the message bufferStream.ResetLength(result.Count); - var message = await _serializer.ReadAsync(bufferStream, RequestAborted); + var message = await _serializer.ReadAsync( + bufferStream, + RequestAborted + ); // dispatch the message if (message != null) await OnDispatchMessageAsync(operationMessageProcessor, message); - } else { + } + else + { // if there is any data in this block, add it to the buffer if (result.Count > 0) receiveStream.Write(buffer, 0, result.Count); // read the message from the buffer receiveStream.Position = 0; - var message = await _serializer.ReadAsync(receiveStream, RequestAborted); + var message = await _serializer.ReadAsync( + receiveStream, + RequestAborted + ); // clear the buffer receiveStream.SetLength(0); // dispatch the message if (message != null) await OnDispatchMessageAsync(operationMessageProcessor, message); } - } else { + } + else + { // if there is any data in this block, add it to the buffer if (result.Count > 0) receiveStream.Write(buffer, 0, result.Count); } } - } catch (WebSocketException) { - } finally { + } + catch (WebSocketException) { } + finally + { // prevent any more messages from being sent _outputClosed.TrySetResult(false); // prevent any more messages from attempting to send @@ -155,21 +197,39 @@ await Task.WhenAny( } /// - public Task CloseAsync() - => CloseAsync(1000, null); + public Task CloseAsync() => CloseAsync(1000, null); /// public Task CloseAsync(int eventId, string? description) { _closeRequested = true; - _pump.Post(new Message { CloseStatus = (WebSocketCloseStatus)eventId, CloseDescription = description }); + _pump.Post( + new Message + { + CloseStatus = (WebSocketCloseStatus)eventId, + CloseDescription = description + } + ); return Task.CompletedTask; } /// public Task SendMessageAsync(OperationMessage message) { - _pump.Post(new Message { OperationMessage = message }); + if (_maxSendQueueThreshold == null || _maxSendQueueThreshold.Value > _pump.Count) + _pump.Post(new Message { OperationMessage = message }); + return Task.CompletedTask; + } + + /// + public Task SendMessageAsync(OperationMessage message, bool ignoreMaxSendQueueThreshold = false) + { + if ( + ignoreMaxSendQueueThreshold + || _maxSendQueueThreshold == null + || _maxSendQueueThreshold.Value > _pump.Count + ) + _pump.Post(new Message { OperationMessage = message }); return Task.CompletedTask; } @@ -186,9 +246,12 @@ private async Task HandleMessageAsync(Message message) if (_outputClosed.Task.IsCompleted) return; LastMessageSentAt = DateTime.UtcNow; - if (message.OperationMessage != null) { + if (message.OperationMessage != null) + { await OnSendMessageAsync(message.OperationMessage); - } else { + } + else + { await OnCloseOutputAsync(message.CloseStatus, message.CloseDescription); _outputClosed.TrySetResult(true); } @@ -200,8 +263,10 @@ private async Task HandleMessageAsync(Message message) ///

/// This method is synchronized and will wait until completion before dispatching another message. ///
- protected virtual Task OnDispatchMessageAsync(IOperationMessageProcessor operationMessageProcessor, OperationMessage message) - => operationMessageProcessor.OnMessageReceivedAsync(message); + protected virtual Task OnDispatchMessageAsync( + IOperationMessageProcessor operationMessageProcessor, + OperationMessage message + ) => operationMessageProcessor.OnMessageReceivedAsync(message); /// /// Sends the specified message to the underlying . @@ -221,8 +286,10 @@ protected virtual async Task OnSendMessageAsync(OperationMessage message) ///

/// This method is synchronized and will wait until completion before sending another message or closing the output stream. ///
- protected virtual Task OnCloseOutputAsync(WebSocketCloseStatus closeStatus, string? closeDescription) - => _webSocket.CloseOutputAsync(closeStatus, closeDescription, RequestAborted); + protected virtual Task OnCloseOutputAsync( + WebSocketCloseStatus closeStatus, + string? closeDescription + ) => _webSocket.CloseOutputAsync(closeStatus, closeDescription, RequestAborted); /// /// A queue entry; see . @@ -230,5 +297,9 @@ protected virtual Task OnCloseOutputAsync(WebSocketCloseStatus closeStatus, stri /// The message to send, if set; if it is null then this is a closure message. /// The close status. /// The close description. - private record struct Message(OperationMessage? OperationMessage, WebSocketCloseStatus CloseStatus, string? CloseDescription); + private record struct Message( + OperationMessage? OperationMessage, + WebSocketCloseStatus CloseStatus, + string? CloseDescription + ); }