Skip to content

Commit

Permalink
Implement backpressure logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
STRATZ-Ken committed Sep 23, 2024
1 parent d4b3573 commit 0dc1a40
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 46 deletions.
52 changes: 38 additions & 14 deletions src/GraphQL.AspNetCore3/WebSockets/AsyncMessagePump.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public AsyncMessagePump(Action<T> callback)
{
if (callback == null)
throw new ArgumentNullException(nameof(callback));
_callback = message => {
_callback = message =>
{
callback(message);
return Task.CompletedTask;
};
Expand All @@ -46,51 +47,75 @@ public AsyncMessagePump(Action<T> callback)
/// <summary>
/// Posts the specified message to the message queue.
/// </summary>
public void Post(T message)
=> Post(new ValueTask<T>(message));
public void Post(T message) => Post(new ValueTask<T>(message));

/// <summary>
/// Posts the result of an asynchronous operation to the message queue.
/// </summary>
public void Post(ValueTask<T> messageTask)
{
bool attach = false;
lock (_queue) {
lock (_queue)
{
_queue.Enqueue(messageTask);
attach = _queue.Count == 1;
}

if (attach) {
if (attach)
{
CompleteAsync();
}
}

/// <summary>
/// Returns the number of messages waiting in the queue.
/// Includes the message currently being processed, if any.
/// </summary>
public int Count
{
get
{
lock (_queue)
{
return _queue.Count;
}
}
}

/// <summary>
/// Processes message in the queue until it is empty.
/// </summary>
private async void CompleteAsync()
{
// grab the message at the start of the queue, but don't remove it from the queue
ValueTask<T> messageTask;
lock (_queue) {
lock (_queue)
{
// should always successfully peek from the queue here
#pragma warning disable CA2012 // Use ValueTasks correctly
messageTask = _queue.Peek();
#pragma warning restore CA2012 // Use ValueTasks correctly
}
while (true) {
while (true)
{
// process the message
try {
try
{
var message = await messageTask.ConfigureAwait(false);
await _callback(message).ConfigureAwait(false);
} catch (Exception ex) {
try {
}
catch (Exception ex)
{
try
{
await HandleErrorAsync(ex);
} catch { }
}
catch { }
}

// once the message has been passed along, dequeue it
lock (_queue) {
lock (_queue)
{
#pragma warning disable CA2012 // Use ValueTasks correctly
_ = _queue.Dequeue();
#pragma warning restore CA2012 // Use ValueTasks correctly
Expand All @@ -105,6 +130,5 @@ private async void CompleteAsync()
/// <summary>
/// Handles exceptions that occur within the asynchronous message delegate or the callback.
/// </summary>
protected virtual Task HandleErrorAsync(Exception exception)
=> Task.CompletedTask;
protected virtual Task HandleErrorAsync(Exception exception) => Task.CompletedTask;
}
7 changes: 7 additions & 0 deletions src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,11 @@ public class GraphQLWebSocketOptions
/// Disconnects a subscription from the client there are any GraphQL errors during a subscription.
/// </summary>
public bool DisconnectAfterAnyError { get; set; }

/// <summary>
/// To help prevent backpressure from slower internet speeds, this will prevent the queue from expanding
/// beyond the max length.
/// The default is null (no limit). Value must be greater than 0.
/// </summary>
public int? MaxSendQueueThreshold { get; set; }
}
5 changes: 5 additions & 0 deletions src/GraphQL.AspNetCore3/WebSockets/IWebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public interface IWebSocketConnection : IDisposable
/// </summary>
Task SendMessageAsync(OperationMessage message);

/// <summary>
/// Sends a message. Option to ignoreMaxSendQueueThreshold and force a message.
/// </summary>
Task SendMessageAsync(OperationMessage message, bool ignoreMaxSendQueueThreshold = false);

/// <summary>
/// Closes the WebSocket connection, and
/// prevents further incoming messages from being dispatched through <see cref="IOperationMessageProcessor"/>.
Expand Down
135 changes: 103 additions & 32 deletions src/GraphQL.AspNetCore3/WebSockets/WebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class WebSocketConnection : IWebSocketConnection
private readonly WebSocketWriterStream _stream;
private readonly TaskCompletionSource<bool> _outputClosed = new();
private readonly int _closeTimeoutMs;
private readonly int? _maxSendQueueThreshold;
private volatile bool _closeRequested;
private int _executed;

Expand All @@ -44,18 +45,39 @@ public class WebSocketConnection : IWebSocketConnection
/// <summary>
/// Initializes an instance with the specified parameters.
/// </summary>
public WebSocketConnection(HttpContext httpContext, WebSocket webSocket, IGraphQLSerializer serializer, GraphQLWebSocketOptions options, CancellationToken cancellationToken)
public WebSocketConnection(
HttpContext httpContext,
WebSocket webSocket,
IGraphQLSerializer serializer,
GraphQLWebSocketOptions options,
CancellationToken cancellationToken
)
{
HttpContext = httpContext ?? throw new ArgumentNullException(nameof(httpContext));
if (options == null)
throw new ArgumentNullException(nameof(options));
if (options.DisconnectionTimeout.HasValue) {
if ((options.DisconnectionTimeout.Value != Timeout.InfiniteTimeSpan && options.DisconnectionTimeout.Value.TotalMilliseconds < 0) || options.DisconnectionTimeout.Value.TotalMilliseconds > int.MaxValue)
if (options.DisconnectionTimeout.HasValue)
{
if (
(
options.DisconnectionTimeout.Value != Timeout.InfiniteTimeSpan
&& options.DisconnectionTimeout.Value.TotalMilliseconds < 0
)
|| options.DisconnectionTimeout.Value.TotalMilliseconds > int.MaxValue
)
#pragma warning disable CA2208 // Instantiate argument exceptions correctly
throw new ArgumentOutOfRangeException(nameof(options) + "." + nameof(GraphQLWebSocketOptions.DisconnectionTimeout));
throw new ArgumentOutOfRangeException(
nameof(options) + "." + nameof(GraphQLWebSocketOptions.DisconnectionTimeout)
);
#pragma warning restore CA2208 // Instantiate argument exceptions correctly
}
_closeTimeoutMs = (int)(options.DisconnectionTimeout ?? DefaultDisconnectionTimeout).TotalMilliseconds;
_closeTimeoutMs = (int)
(options.DisconnectionTimeout ?? DefaultDisconnectionTimeout).TotalMilliseconds;
_maxSendQueueThreshold = options.MaxSendQueueThreshold;
if (_maxSendQueueThreshold != null && _maxSendQueueThreshold <= 0)
throw new ArgumentOutOfRangeException(
nameof(options) + "." + nameof(GraphQLWebSocketOptions.MaxSendQueueThreshold)
);
_webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
_stream = new(webSocket);
_serializer = serializer ?? throw new ArgumentNullException(nameof(serializer));
Expand All @@ -64,8 +86,7 @@ public WebSocketConnection(HttpContext httpContext, WebSocket webSocket, IGraphQ
}

/// <inheritdoc/>
public virtual void Dispose()
=> GC.SuppressFinalize(this);
public virtual void Dispose() => GC.SuppressFinalize(this);

/// <summary>
/// Listens to incoming messages on the WebSocket specified in the constructor,
Expand All @@ -77,8 +98,11 @@ public virtual async Task ExecuteAsync(IOperationMessageProcessor operationMessa
if (operationMessageProcessor == null)
throw new ArgumentNullException(nameof(operationMessageProcessor));
if (Interlocked.Exchange(ref _executed, 1) == 1)
throw new InvalidOperationException($"{nameof(ExecuteAsync)} may only be called once per instance.");
try {
throw new InvalidOperationException(
$"{nameof(ExecuteAsync)} may only be called once per instance."
);
try
{
await operationMessageProcessor.InitializeConnectionAsync();
// set up a buffer in case a message is longer than one block
var receiveStream = new MemoryStream();
Expand All @@ -93,19 +117,23 @@ public virtual async Task ExecuteAsync(IOperationMessageProcessor operationMessa
// prep a reader stream
var bufferStream = new ReusableMemoryReaderStream(buffer);
// read messages until an exception occurs, the cancellation token is signaled, or a 'close' message is received
while (!RequestAborted.IsCancellationRequested) {
while (!RequestAborted.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(bufferMemory, RequestAborted);
if (result.MessageType == WebSocketMessageType.Close) {
if (result.MessageType == WebSocketMessageType.Close)
{
// prevent any more messages from being queued
operationMessageProcessor.Dispose();
// send a close request if none was sent yet
if (!_outputClosed.Task.IsCompleted) {
if (!_outputClosed.Task.IsCompleted)
{
// queue the closure
_ = CloseAsync();
// wait until the close has been sent
await Task.WhenAny(
_outputClosed.Task,
Task.Delay(_closeTimeoutMs, RequestAborted));
Task.Delay(_closeTimeoutMs, RequestAborted)
);
}
// quit
return;
Expand All @@ -114,39 +142,53 @@ await Task.WhenAny(
if (_closeRequested)
continue;
// if this is the last block terminating a message
if (result.EndOfMessage) {
if (result.EndOfMessage)
{
// if only one block of data was sent for this message
if (receiveStream.Length == 0) {
if (receiveStream.Length == 0)
{
// if the message is empty, skip to the next message
if (result.Count == 0)
continue;
// read the message
bufferStream.ResetLength(result.Count);
var message = await _serializer.ReadAsync<OperationMessage>(bufferStream, RequestAborted);
var message = await _serializer.ReadAsync<OperationMessage>(
bufferStream,
RequestAborted
);
// dispatch the message
if (message != null)
await OnDispatchMessageAsync(operationMessageProcessor, message);
} else {
}
else
{
// if there is any data in this block, add it to the buffer
if (result.Count > 0)
receiveStream.Write(buffer, 0, result.Count);
// read the message from the buffer
receiveStream.Position = 0;
var message = await _serializer.ReadAsync<OperationMessage>(receiveStream, RequestAborted);
var message = await _serializer.ReadAsync<OperationMessage>(
receiveStream,
RequestAborted
);
// clear the buffer
receiveStream.SetLength(0);
// dispatch the message
if (message != null)
await OnDispatchMessageAsync(operationMessageProcessor, message);
}
} else {
}
else
{
// if there is any data in this block, add it to the buffer
if (result.Count > 0)
receiveStream.Write(buffer, 0, result.Count);
}
}
} catch (WebSocketException) {
} finally {
}
catch (WebSocketException) { }
finally
{
// prevent any more messages from being sent
_outputClosed.TrySetResult(false);
// prevent any more messages from attempting to send
Expand All @@ -155,21 +197,39 @@ await Task.WhenAny(
}

/// <inheritdoc/>
public Task CloseAsync()
=> CloseAsync(1000, null);
public Task CloseAsync() => CloseAsync(1000, null);

/// <inheritdoc/>
public Task CloseAsync(int eventId, string? description)
{
_closeRequested = true;
_pump.Post(new Message { CloseStatus = (WebSocketCloseStatus)eventId, CloseDescription = description });
_pump.Post(
new Message
{
CloseStatus = (WebSocketCloseStatus)eventId,
CloseDescription = description
}
);
return Task.CompletedTask;
}

/// <inheritdoc/>
public Task SendMessageAsync(OperationMessage message)
{
_pump.Post(new Message { OperationMessage = message });
if (_maxSendQueueThreshold == null || _maxSendQueueThreshold.Value > _pump.Count)
_pump.Post(new Message { OperationMessage = message });
return Task.CompletedTask;
}

/// <inheritdoc/>
public Task SendMessageAsync(OperationMessage message, bool ignoreMaxSendQueueThreshold = false)
{
if (
ignoreMaxSendQueueThreshold
|| _maxSendQueueThreshold == null
|| _maxSendQueueThreshold.Value > _pump.Count
)
_pump.Post(new Message { OperationMessage = message });
return Task.CompletedTask;
}

Expand All @@ -186,9 +246,12 @@ private async Task HandleMessageAsync(Message message)
if (_outputClosed.Task.IsCompleted)
return;
LastMessageSentAt = DateTime.UtcNow;
if (message.OperationMessage != null) {
if (message.OperationMessage != null)
{
await OnSendMessageAsync(message.OperationMessage);
} else {
}
else
{
await OnCloseOutputAsync(message.CloseStatus, message.CloseDescription);
_outputClosed.TrySetResult(true);
}
Expand All @@ -200,8 +263,10 @@ private async Task HandleMessageAsync(Message message)
/// <br/><br/>
/// This method is synchronized and will wait until completion before dispatching another message.
/// </summary>
protected virtual Task OnDispatchMessageAsync(IOperationMessageProcessor operationMessageProcessor, OperationMessage message)
=> operationMessageProcessor.OnMessageReceivedAsync(message);
protected virtual Task OnDispatchMessageAsync(
IOperationMessageProcessor operationMessageProcessor,
OperationMessage message
) => operationMessageProcessor.OnMessageReceivedAsync(message);

/// <summary>
/// Sends the specified message to the underlying <see cref="WebSocket"/>.
Expand All @@ -221,14 +286,20 @@ protected virtual async Task OnSendMessageAsync(OperationMessage message)
/// <br/><br/>
/// This method is synchronized and will wait until completion before sending another message or closing the output stream.
/// </summary>
protected virtual Task OnCloseOutputAsync(WebSocketCloseStatus closeStatus, string? closeDescription)
=> _webSocket.CloseOutputAsync(closeStatus, closeDescription, RequestAborted);
protected virtual Task OnCloseOutputAsync(
WebSocketCloseStatus closeStatus,
string? closeDescription
) => _webSocket.CloseOutputAsync(closeStatus, closeDescription, RequestAborted);

/// <summary>
/// A queue entry; see <see cref="HandleMessageAsync(Message)"/>.
/// </summary>
/// <param name="OperationMessage">The message to send, if set; if it is null then this is a closure message.</param>
/// <param name="CloseStatus">The close status.</param>
/// <param name="CloseDescription">The close description.</param>
private record struct Message(OperationMessage? OperationMessage, WebSocketCloseStatus CloseStatus, string? CloseDescription);
private record struct Message(
OperationMessage? OperationMessage,
WebSocketCloseStatus CloseStatus,
string? CloseDescription
);
}

0 comments on commit 0dc1a40

Please sign in to comment.