From 0c7efec09666babeebea1f608c91757b8faccca0 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Wed, 14 Feb 2024 18:02:10 +0000 Subject: [PATCH] [release/8.0] Fix HTTP/2 WebSocket Abort --- .../SocketsHttpHandler/Http2Connection.cs | 8 + .../Http/SocketsHttpHandler/Http2Stream.cs | 97 +++++-- ...etsHttpHandlerTest.Http2ExtendedConnect.cs | 157 ++++++++++- .../tests/AbortTest.Loopback.cs | 246 ++++++++++++++++++ .../tests/LoopbackHelper.cs | 21 +- .../LoopbackServer/Http2LoopbackStream.cs | 100 +++++++ .../LoopbackServer/LoopbackWebSocketServer.cs | 148 +++++++++++ .../WebSocketHandshakeHelper.cs | 134 ++++++++++ .../LoopbackServer/WebSocketRequestData.cs | 20 ++ .../System.Net.WebSockets.Client.Tests.csproj | 5 + 10 files changed, 902 insertions(+), 34 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index e9cce1c24d34d..14c9685a4f511 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1446,6 +1446,14 @@ private int WriteHeaderCollection(HttpRequestMessage request, HttpHeaders header continue; } + // Extended connect requests will use the response content stream for bidirectional communication. + // We will ignore any content set for such requests in Http2Stream.SendRequestBodyAsync, as it has no defined semantics. + // Drop the Content-Length header as well in the unlikely case it was set. + if (knownHeader == KnownHeaders.ContentLength && request.IsExtendedConnectRequest) + { + continue; + } + // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); string? separator = null; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index de66b7cfa103d..d834679274b4a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -105,7 +105,9 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection) _headerBudgetRemaining = connection._pool.Settings.MaxResponseHeadersByteLength; - if (_request.Content == null) + // Extended connect requests will use the response content stream for bidirectional communication. + // We will ignore any content set for such requests in SendRequestBodyAsync, as it has no defined semantics. + if (_request.Content == null || _request.IsExtendedConnectRequest) { _requestCompletionState = StreamCompletionState.Completed; if (_request.IsExtendedConnectRequest) @@ -173,7 +175,9 @@ public HttpResponseMessage GetAndClearResponse() public async Task SendRequestBodyAsync(CancellationToken cancellationToken) { - if (_request.Content == null) + // Extended connect requests will use the response content stream for bidirectional communication. + // Ignore any content set for such requests, as it has no defined semantics. + if (_request.Content == null || _request.IsExtendedConnectRequest) { Debug.Assert(_requestCompletionState == StreamCompletionState.Completed); return; @@ -250,6 +254,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken) // and we also don't want to propagate any error to the caller, in particular for non-duplex scenarios. Debug.Assert(_responseCompletionState == StreamCompletionState.Completed); _requestCompletionState = StreamCompletionState.Completed; + Debug.Assert(!ConnectProtocolEstablished); Complete(); return; } @@ -261,6 +266,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken) _requestCompletionState = StreamCompletionState.Failed; SendReset(); + Debug.Assert(!ConnectProtocolEstablished); Complete(); } @@ -313,6 +319,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken) if (complete) { + Debug.Assert(!ConnectProtocolEstablished); Complete(); } } @@ -420,7 +427,17 @@ private void Cancel() if (sendReset) { SendReset(); - Complete(); + + // Extended CONNECT notes: + // + // To prevent from calling it *twice*, Extended CONNECT stream's Complete() is only + // called from CloseResponseBody(), as CloseResponseBody() is *always* called + // from Extended CONNECT stream's Dispose(). + + if (!ConnectProtocolEstablished) + { + Complete(); + } } } @@ -810,7 +827,20 @@ public void OnHeadersComplete(bool endStream) Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}"); _responseCompletionState = StreamCompletionState.Completed; - if (_requestCompletionState == StreamCompletionState.Completed) + + // Extended CONNECT notes: + // + // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only + // called from CloseResponseBody(), as CloseResponseBody() is *only* called + // from Extended CONNECT stream's Dispose(). + // + // Due to bidirectional streaming nature of the Extended CONNECT request, + // the *write side* of the stream can only be completed by calling Dispose(). + // + // The streaming in both ways happens over the single "response" stream instance, which makes + // _requestCompletionState *not indicative* of the actual state of the write side of the stream. + + if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished) { Complete(); } @@ -871,7 +901,20 @@ public void OnResponseData(ReadOnlySpan buffer, bool endStream) Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}"); _responseCompletionState = StreamCompletionState.Completed; - if (_requestCompletionState == StreamCompletionState.Completed) + + // Extended CONNECT notes: + // + // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only + // called from CloseResponseBody(), as CloseResponseBody() is *only* called + // from Extended CONNECT stream's Dispose(). + // + // Due to bidirectional streaming nature of the Extended CONNECT request, + // the *write side* of the stream can only be completed by calling Dispose(). + // + // The streaming in both ways happens over the single "response" stream instance, which makes + // _requestCompletionState *not indicative* of the actual state of the write side of the stream. + + if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished) { Complete(); } @@ -1036,17 +1079,17 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken) Debug.Assert(_response != null && _response.Content != null); // Start to process the response body. var responseContent = (HttpConnectionResponseContent)_response.Content; - if (emptyResponse) + if (ConnectProtocolEstablished) + { + responseContent.SetStream(new Http2ReadWriteStream(this, closeResponseBodyOnDispose: true)); + } + else if (emptyResponse) { // If there are any trailers, copy them over to the response. Normally this would be handled by // the response stream hitting EOF, but if there is no response body, we do it here. MoveTrailersToResponseMessage(_response); responseContent.SetStream(EmptyReadStream.Instance); } - else if (ConnectProtocolEstablished) - { - responseContent.SetStream(new Http2ReadWriteStream(this)); - } else { responseContent.SetStream(new Http2ReadStream(this)); @@ -1309,8 +1352,25 @@ private async ValueTask SendDataAsync(ReadOnlyMemory buffer, CancellationT } } + // This method should only be called from Http2ReadWriteStream.Dispose() private void CloseResponseBody() { + // Extended CONNECT notes: + // + // Due to bidirectional streaming nature of the Extended CONNECT request, + // the *write side* of the stream can only be completed by calling Dispose() + // (which, for Extended CONNECT case, will in turn call CloseResponseBody()) + // + // Similarly to QuicStream, disposal *gracefully* closes the write side of the stream + // (unless we've received RST_STREAM before) and *abortively* closes the read side + // of the stream (unless we've received EOS before). + + if (ConnectProtocolEstablished && _resetException is null) + { + // Gracefully close the write side of the Extended CONNECT stream + _connection.LogExceptions(_connection.SendEndStreamAsync(StreamId)); + } + // Check if the response body has been fully consumed. bool fullyConsumed = false; Debug.Assert(!Monitor.IsEntered(SyncObject)); @@ -1323,6 +1383,7 @@ private void CloseResponseBody() } // If the response body isn't completed, cancel it now. + // This includes aborting the read side of the Extended CONNECT stream. if (!fullyConsumed) { Cancel(); @@ -1337,6 +1398,12 @@ private void CloseResponseBody() lock (SyncObject) { + if (ConnectProtocolEstablished) + { + // This should be the only place where Extended Connect stream is completed + Complete(); + } + _responseBuffer.Dispose(); } } @@ -1430,10 +1497,7 @@ private enum StreamCompletionState : byte private sealed class Http2ReadStream : Http2ReadWriteStream { - public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream) - { - base.CloseResponseBodyOnDispose = true; - } + public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream, closeResponseBodyOnDispose: true) { } public override bool CanWrite => false; @@ -1482,12 +1546,13 @@ public class Http2ReadWriteStream : HttpBaseStream private Http2Stream? _http2Stream; private readonly HttpResponseMessage _responseMessage; - public Http2ReadWriteStream(Http2Stream http2Stream) + public Http2ReadWriteStream(Http2Stream http2Stream, bool closeResponseBodyOnDispose = false) { Debug.Assert(http2Stream != null); Debug.Assert(http2Stream._response != null); _http2Stream = http2Stream; _responseMessage = _http2Stream._response; + CloseResponseBodyOnDispose = closeResponseBodyOnDispose; } ~Http2ReadWriteStream() @@ -1503,7 +1568,7 @@ public Http2ReadWriteStream(Http2Stream http2Stream) } } - protected bool CloseResponseBodyOnDispose { get; set; } + protected bool CloseResponseBodyOnDispose { get; private init; } protected override void Dispose(bool disposing) { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs index 0ea6ae9e13f60..cb1a15df14e77 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Net.Test.Common; +using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -31,6 +33,7 @@ public static IEnumerable UseSsl_MemberData() [MemberData(nameof(UseSsl_MemberData))] public async Task Connect_ReadWriteResponseStream(bool useSsl) { + const int MessageCount = 3; byte[] clientMessage = new byte[] { 1, 2, 3 }; byte[] serverMessage = new byte[] { 4, 5, 6, 7 }; @@ -43,19 +46,39 @@ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true); request.Headers.Protocol = "foo"; + bool readFromContentStream = false; + + // We won't send the content bytes, but we will send content headers. + // Since we're dropping the content, we'll also drop the Content-Length header. + request.Content = new StreamContent(new DelegateStream( + readAsyncFunc: (_, _, _, _) => + { + readFromContentStream = true; + throw new UnreachableException(); + })); + + request.Headers.Add("User-Agent", "foo"); + request.Content.Headers.Add("Content-Language", "bar"); + request.Content.Headers.ContentLength = 42; + using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); using Stream responseStream = await response.Content.ReadAsStreamAsync(); - await responseStream.WriteAsync(clientMessage); - await responseStream.FlushAsync(); + for (int i = 0; i < MessageCount; i++) + { + await responseStream.WriteAsync(clientMessage); + await responseStream.FlushAsync(); - byte[] readBuffer = new byte[serverMessage.Length]; - await responseStream.ReadExactlyAsync(readBuffer); - Assert.Equal(serverMessage, readBuffer); + byte[] readBuffer = new byte[serverMessage.Length]; + await responseStream.ReadExactlyAsync(readBuffer); + Assert.Equal(serverMessage, readBuffer); + } // Receive server's EOS - Assert.Equal(0, await responseStream.ReadAsync(readBuffer)); + Assert.Equal(0, await responseStream.ReadAsync(new byte[1])); + + Assert.False(readFromContentStream); clientCompleted.SetResult(); }, @@ -63,14 +86,21 @@ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri { await using Http2LoopbackConnection connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); - (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + (int streamId, HttpRequestData request) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + + Assert.Equal("foo", request.GetSingleHeaderValue("User-Agent")); + Assert.Equal("bar", request.GetSingleHeaderValue("Content-Language")); + Assert.Equal(0, request.GetHeaderValueCount("Content-Length")); await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false); - DataFrame dataFrame = await connection.ReadDataFrameAsync(); - Assert.Equal(clientMessage, dataFrame.Data.ToArray()); + for (int i = 0; i < MessageCount; i++) + { + DataFrame dataFrame = await connection.ReadDataFrameAsync(); + Assert.Equal(clientMessage, dataFrame.Data.ToArray()); - await connection.SendResponseDataAsync(streamId, serverMessage, endStream: true); + await connection.SendResponseDataAsync(streamId, serverMessage, endStream: i == MessageCount - 1); + } await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout); }, options: new GenericLoopbackOptions { UseSsl = useSsl }); @@ -163,5 +193,112 @@ await server.AcceptConnectionAsync(async connection => await new[] { serverTask, clientTask }.WhenAllOrAnyFailed().WaitAsync(TestHelper.PassingTestTimeout); } + + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public async Task Connect_ServerSideEOS_ReceivedByClient(bool useSsl) + { + var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout); + var serverReceivedEOS = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync( + clientFunc: async uri => + { + var client = CreateHttpClient(); + var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true); + request.Headers.Protocol = "foo"; + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token); + var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token); + + // receive server's EOS + Assert.Equal(0, await responseStream.ReadAsync(new byte[1], timeoutTcs.Token)); + + // send client's EOS + responseStream.Dispose(); + + // wait for "ack" from server + await serverReceivedEOS.Task.WaitAsync(timeoutTcs.Token); + + // can dispose handler now + client.Dispose(); + }, + serverFunc: async server => + { + await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync( + new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + await connection.SendResponseHeadersAsync(streamId, endStream: false); + + // send server's EOS + await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true); + + // receive client's EOS "in response" to server's EOS + var eosFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token)); + Assert.Equal(streamId, eosFrame.StreamId); + Assert.Equal(0, eosFrame.Data.Length); + Assert.True(eosFrame.EndStreamFlag); + + serverReceivedEOS.SetResult(); + + // on handler dispose, client should shutdown the connection without sending additional frames + await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token); + }, + options: new GenericLoopbackOptions { UseSsl = useSsl }); + } + + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public async Task Connect_ClientSideEOS_ReceivedByServer(bool useSsl) + { + var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout); + var serverReceivedRst = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync( + clientFunc: async uri => + { + var client = CreateHttpClient(); + var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true); + request.Headers.Protocol = "foo"; + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token); + var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token); + + // send client's EOS + // this will also send RST_STREAM as we didn't receive server's EOS before + responseStream.Dispose(); + + // wait for "ack" from server + await serverReceivedRst.Task.WaitAsync(timeoutTcs.Token); + + // can dispose handler now + client.Dispose(); + }, + serverFunc: async server => + { + await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync( + new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + await connection.SendResponseHeadersAsync(streamId, endStream: false); + + // receive client's EOS + var eosFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token)); + Assert.Equal(streamId, eosFrame.StreamId); + Assert.Equal(0, eosFrame.Data.Length); + Assert.True(eosFrame.EndStreamFlag); + + // receive client's RST_STREAM as we didn't send server's EOS before + var rstFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token)); + Assert.Equal(streamId, rstFrame.StreamId); + + serverReceivedRst.SetResult(); + + // on handler dispose, client should shutdown the connection without sending additional frames + await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token); + }, + options: new GenericLoopbackOptions { UseSsl = useSsl }); + } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs new file mode 100644 index 0000000000000..0aa83697a9de7 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs @@ -0,0 +1,246 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + [ConditionalClass(typeof(ClientWebSocketTestBase), nameof(WebSocketsSupported))] + [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets are not supported on browser")] + public abstract class AbortTest_Loopback : ClientWebSocketTestBase + { + public AbortTest_Loopback(ITestOutputHelper output) : base(output) { } + + protected virtual Version HttpVersion => Net.HttpVersion.Version11; + + [Theory] + [MemberData(nameof(AbortClient_MemberData))] + public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive) + { + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + return LoopbackWebSocketServer.RunAsync( + async (clientWebSocket, token) => + { + if (verifySendReceive) + { + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); + } + + switch (abortType) + { + case AbortType.Abort: + clientWebSocket.Abort(); + break; + case AbortType.Dispose: + clientWebSocket.Dispose(); + break; + } + }, + async (serverWebSocket, token) => + { + if (verifySendReceive) + { + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + } + + var readBuffer = new byte[1]; + var exception = await Assert.ThrowsAsync(async () => + await serverWebSocket.ReceiveAsync(readBuffer, token)); + + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode); + Assert.Equal(WebSocketState.Aborted, serverWebSocket.State); + }, + new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()), + timeoutCts.Token); + } + + [Theory] + [MemberData(nameof(ServerPrematureEos_MemberData))] + public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl) + { + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null) + { + DisposeServerWebSocket = false, + ManualServerHandshakeResponse = true + }; + + var serverReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + return LoopbackWebSocketServer.RunAsync( + async uri => + { + var token = timeoutCts.Token; + var clientOptions = globalOptions with { HttpInvoker = GetInvoker() }; + var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false); + + if (serverEosType == ServerEosType.AfterSomeData) + { + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); + } + + // only one side of the stream was closed. the other should work + await clientWebSocket.SendAsync(clientMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false); + + var exception = await Assert.ThrowsAsync(() => clientWebSocket.ReceiveAsync(new byte[1], token)); + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode); + + clientReceivedEosTcs.SetResult(); + clientWebSocket.Dispose(); + }, + async (requestData, token) => + { + WebSocket serverWebSocket = null!; + await SendServerResponseAndEosAsync( + requestData, + serverEosType, + (wsData, ct) => + { + var wsOptions = new WebSocketCreationOptions { IsServer = true }; + serverWebSocket = WebSocket.CreateFromStream(wsData.WebSocketStream, wsOptions); + + return serverEosType == ServerEosType.AfterSomeData + ? VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, ct) + : Task.CompletedTask; + }, + token); + + Assert.NotNull(serverWebSocket); + + // only one side of the stream was closed. the other should work + var readBuffer = new byte[clientMsg.Length]; + var result = await serverWebSocket.ReceiveAsync(readBuffer, token); + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + Assert.Equal(clientMsg.Length, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal(clientMsg, readBuffer); + + await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false); + + var exception = await Assert.ThrowsAsync(() => serverWebSocket.ReceiveAsync(readBuffer, token)); + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode); + + serverWebSocket.Dispose(); + }, + globalOptions, + timeoutCts.Token); + } + + protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func serverFunc, CancellationToken cancellationToken) + => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2 + + private static readonly bool[] Bool_Values = new[] { false, true }; + private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false }; + + public static IEnumerable AbortClient_MemberData() + { + foreach (var abortType in Enum.GetValues()) + { + foreach (var useSsl in UseSsl_Values) + { + foreach (var verifySendReceive in Bool_Values) + { + yield return new object[] { abortType, useSsl, verifySendReceive }; + } + } + } + } + + public static IEnumerable ServerPrematureEos_MemberData() + { + foreach (var serverEosType in Enum.GetValues()) + { + foreach (var useSsl in UseSsl_Values) + { + yield return new object[] { serverEosType, useSsl }; + } + } + } + + public enum AbortType + { + Abort, + Dispose + } + + public enum ServerEosType + { + WithHeaders, + RightAfterHeaders, + AfterSomeData + } + + private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, + TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken) + { + var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken); + + var recvBuf = new byte[remoteMsg.Length * 2]; + var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false); + + Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType); + Assert.Equal(remoteMsg.Length, recvResult.Count); + Assert.True(recvResult.EndOfMessage); + Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]); + + localAckTcs.SetResult(); + + await sendTask.ConfigureAwait(false); + await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false); + } + } + + // --- HTTP/1.1 WebSocket loopback tests --- + + public class AbortTest_Invoker_Loopback : AbortTest_Loopback + { + public AbortTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseCustomInvoker => true; + } + + public class AbortTest_HttpClient_Loopback : AbortTest_Loopback + { + public AbortTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseHttpClient => true; + } + + public class AbortTest_SharedHandler_Loopback : AbortTest_Loopback + { + public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { } + } + + // --- HTTP/2 WebSocket loopback tests --- + + public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback + { + public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) + => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); + } + + public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback + { + public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) + => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs index 48d167b072f78..cee509ee06846 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs @@ -28,14 +28,7 @@ public static async Task> WebSocketHandshakeAsync(Loo if (headerName == "Sec-WebSocket-Key") { string headerValue = tokens[1].Trim(); - string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue); - serverResponse = - "HTTP/1.1 101 Switching Protocols\r\n" + - "Content-Length: 0\r\n" + - "Upgrade: websocket\r\n" + - "Connection: Upgrade\r\n" + - (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + - "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; + serverResponse = GetServerResponseString(headerValue, extensions); } } } @@ -50,6 +43,18 @@ public static async Task> WebSocketHandshakeAsync(Loo return null; } + public static string GetServerResponseString(string secWebSocketKey, string? extensions = null) + { + var responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(secWebSocketKey); + return + "HTTP/1.1 101 Switching Protocols\r\n" + + "Content-Length: 0\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + + "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; + } + private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey) { // GUID specified by RFC 6455. diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs new file mode 100644 index 0000000000000..1b3b51840ec99 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Test.Common +{ + public class Http2LoopbackStream : Stream + { + private readonly Http2LoopbackConnection _connection; + private readonly int _streamId; + private bool _readEnded; + private ReadOnlyMemory _leftoverReadData; + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public Http2LoopbackConnection Connection => _connection; + public int StreamId => _streamId; + + public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId) + { + _connection = connection; + _streamId = streamId; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (!_leftoverReadData.IsEmpty) + { + int read = Math.Min(buffer.Length, _leftoverReadData.Length); + _leftoverReadData.Span.Slice(0, read).CopyTo(buffer.Span); + _leftoverReadData = _leftoverReadData.Slice(read); + return read; + } + + if (_readEnded) + { + return 0; + } + + DataFrame dataFrame = (DataFrame)await _connection.ReadFrameAsync(cancellationToken); + Assert.Equal(_streamId, dataFrame.StreamId); + _leftoverReadData = dataFrame.Data; + _readEnded = dataFrame.EndStreamFlag; + + return await ReadAsync(buffer, cancellationToken); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + await _connection.SendResponseDataAsync(_streamId, buffer, endStream: false); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + protected override void Dispose(bool disposing) => DisposeAsync().GetAwaiter().GetResult(); + + public override async ValueTask DisposeAsync() + { + try + { + await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); + + if (!_readEnded) + { + var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId); + await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false); + } + } + catch (IOException) + { + // Ignore connection errors + } + catch (SocketException) + { + // Ignore connection errors + } + } + + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); + public override void SetLength(long value) => throw new NotImplementedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException(); + public override long Length => throw new NotImplementedException(); + public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs new file mode 100644 index 0000000000000..b24e2e20d40df --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Http; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Client.Tests +{ + public static class LoopbackWebSocketServer + { + public static Task RunAsync( + Func clientWebSocketFunc, + Func serverWebSocketFunc, + Options options, + CancellationToken cancellationToken) + { + Assert.False(options.ManualServerHandshakeResponse, "Not supported in this overload"); + + return RunAsyncPrivate( + uri => RunClientAsync(uri, clientWebSocketFunc, options, cancellationToken), + (requestData, token) => RunServerAsync(requestData, serverWebSocketFunc, options, token), + options, + cancellationToken); + } + + public static Task RunAsync( + Func loopbackClientFunc, + Func loopbackServerFunc, + Options options, + CancellationToken cancellationToken) + { + Assert.False(options.DisposeClientWebSocket, "Not supported in this overload"); + Assert.False(options.DisposeServerWebSocket, "Not supported in this overload"); + Assert.False(options.DisposeHttpInvoker, "Not supported in this overload"); + Assert.Null(options.HttpInvoker); // Not supported in this overload + + return RunAsyncPrivate(loopbackClientFunc, loopbackServerFunc, options, cancellationToken); + } + + private static Task RunAsyncPrivate( + Func loopbackClientFunc, + Func loopbackServerFunc, + Options options, + CancellationToken cancellationToken) + { + bool sendDefaultServerHandshakeResponse = !options.ManualServerHandshakeResponse; + if (options.HttpVersion == HttpVersion.Version11) + { + return LoopbackServer.CreateClientAndServerAsync( + loopbackClientFunc, + async server => + { + await server.AcceptConnectionAsync(async connection => + { + var requestData = await WebSocketHandshakeHelper.ProcessHttp11RequestAsync(connection, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false); + await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); + }); + }, + new LoopbackServer.Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); + } + else if (options.HttpVersion == HttpVersion.Version20) + { + return Http2LoopbackServer.CreateClientAndServerAsync( + loopbackClientFunc, + async server => + { + var requestData = await WebSocketHandshakeHelper.ProcessHttp2RequestAsync(server, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false); + var http2Connection = requestData.Http2Connection!; + var http2StreamId = requestData.Http2StreamId.Value; + + await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); + + await http2Connection.DisposeAsync().ConfigureAwait(false); + }, + new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); + } + else + { + throw new ArgumentException(nameof(options.HttpVersion)); + } + } + + private static async Task RunServerAsync( + WebSocketRequestData requestData, + Func serverWebSocketFunc, + Options options, + CancellationToken cancellationToken) + { + var wsOptions = new WebSocketCreationOptions { IsServer = true }; + var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions); + + await serverWebSocketFunc(serverWebSocket, cancellationToken).ConfigureAwait(false); + + if (options.DisposeServerWebSocket) + { + serverWebSocket.Dispose(); + } + } + + private static async Task RunClientAsync( + Uri uri, + Func clientWebSocketFunc, + Options options, + CancellationToken cancellationToken) + { + var clientWebSocket = await GetConnectedClientAsync(uri, options, cancellationToken).ConfigureAwait(false); + + await clientWebSocketFunc(clientWebSocket, cancellationToken).ConfigureAwait(false); + + if (options.DisposeClientWebSocket) + { + clientWebSocket.Dispose(); + } + + if (options.DisposeHttpInvoker) + { + options.HttpInvoker?.Dispose(); + } + } + + public static async Task GetConnectedClientAsync(Uri uri, Options options, CancellationToken cancellationToken) + { + var clientWebSocket = new ClientWebSocket(); + clientWebSocket.Options.HttpVersion = options.HttpVersion; + clientWebSocket.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + if (options.UseSsl && options.HttpInvoker is null) + { + clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; }; + } + + await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false); + + return clientWebSocket; + } + + public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker? HttpInvoker) + { + public bool DisposeServerWebSocket { get; set; } = true; + public bool DisposeClientWebSocket { get; set; } + public bool DisposeHttpInvoker { get; set; } + public bool ManualServerHandshakeResponse { get; set; } + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs new file mode 100644 index 0000000000000..f4d2f42f5edbb --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Net.Sockets; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Client.Tests +{ + public static class WebSocketHandshakeHelper + { + public static async Task ProcessHttp11RequestAsync(LoopbackServer.Connection connection, bool sendServerResponse = true, CancellationToken cancellationToken = default) + { + List headers = await connection.ReadRequestHeaderAsync().WaitAsync(cancellationToken).ConfigureAwait(false); + + var data = new WebSocketRequestData() + { + HttpVersion = HttpVersion.Version11, + Http11Connection = connection + }; + + foreach (string header in headers.Skip(1)) + { + string[] tokens = header.Split(new char[] { ':' }, StringSplitOptions.RemoveEmptyEntries); + if (tokens.Length is 1 or 2) + { + data.Headers.Add( + tokens[0].Trim(), + tokens.Length == 2 ? tokens[1].Trim() : null); + } + } + + var isValidOpeningHandshake = data.Headers.TryGetValue("Sec-WebSocket-Key", out var secWebSocketKey); + Assert.True(isValidOpeningHandshake); + + if (sendServerResponse) + { + await SendHttp11ServerResponseAsync(connection, secWebSocketKey, cancellationToken).ConfigureAwait(false); + } + + data.WebSocketStream = connection.Stream; + return data; + } + + private static async Task SendHttp11ServerResponseAsync(LoopbackServer.Connection connection, string secWebSocketKey, CancellationToken cancellationToken) + { + var serverResponse = LoopbackHelper.GetServerResponseString(secWebSocketKey); + await connection.WriteStringAsync(serverResponse).WaitAsync(cancellationToken).ConfigureAwait(false); + } + + public static async Task ProcessHttp2RequestAsync(Http2LoopbackServer server, bool sendServerResponse = true, CancellationToken cancellationToken = default) + { + var connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }) + .WaitAsync(cancellationToken).ConfigureAwait(false); + + (int streamId, var httpRequestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false) + .WaitAsync(cancellationToken).ConfigureAwait(false); + + var data = new WebSocketRequestData + { + HttpVersion = HttpVersion.Version20, + Http2Connection = connection, + Http2StreamId = streamId + }; + + foreach (var header in httpRequestData.Headers) + { + Assert.NotNull(header.Name); + data.Headers.Add(header.Name, header.Value); + } + + var isValidOpeningHandshake = httpRequestData.Method == HttpMethod.Connect.ToString() && data.Headers.ContainsKey(":protocol"); + Assert.True(isValidOpeningHandshake); + + if (sendServerResponse) + { + await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + data.WebSocketStream = new Http2LoopbackStream(connection, streamId); + return data; + } + + private static async Task SendHttp2ServerResponseAsync(Http2LoopbackConnection connection, int streamId, bool endStream = false, CancellationToken cancellationToken = default) + { + // send status 200 OK to establish websocket + // we don't need to send anything additional as Sec-WebSocket-Key is not used for HTTP/2 + // note: endStream=true is abnormal and used for testing premature EOS scenarios only + await connection.SendResponseHeadersAsync(streamId, endStream: endStream).WaitAsync(cancellationToken).ConfigureAwait(false); + } + + public static async Task SendHttp11ServerResponseAndEosAsync(WebSocketRequestData requestData, Func? requestDataCallback, CancellationToken cancellationToken) + { + Assert.Equal(HttpVersion.Version11, requestData.HttpVersion); + + // sending default handshake response + await SendHttp11ServerResponseAsync(requestData.Http11Connection!, requestData.Headers["Sec-WebSocket-Key"], cancellationToken).ConfigureAwait(false); + + if (requestDataCallback is not null) + { + await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false); + } + + // send server EOS (half-closing from server side) + requestData.Http11Connection!.Socket.Shutdown(SocketShutdown.Send); + } + + public static async Task SendHttp2ServerResponseAndEosAsync(WebSocketRequestData requestData, bool eosInHeadersFrame, Func? requestDataCallback, CancellationToken cancellationToken) + { + Assert.Equal(HttpVersion.Version20, requestData.HttpVersion); + + var connection = requestData.Http2Connection!; + var streamId = requestData.Http2StreamId!.Value; + + await SendHttp2ServerResponseAsync(connection, streamId, endStream: eosInHeadersFrame, cancellationToken).ConfigureAwait(false); + + if (requestDataCallback is not null) + { + await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false); + } + + if (!eosInHeadersFrame) + { + // send server EOS (half-closing from server side) + await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true).ConfigureAwait(false); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs new file mode 100644 index 0000000000000..799157a370f07 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Net.Test.Common; + +namespace System.Net.WebSockets.Client.Tests +{ + public class WebSocketRequestData + { + public Dictionary Headers { get; set; } = new Dictionary(); + public Stream? WebSocketStream { get; set; } + + public Version HttpVersion { get; set; } + public LoopbackServer.Connection? Http11Connection { get; set; } + public Http2LoopbackConnection? Http2Connection { get; set; } + public int? Http2StreamId { get; set; } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 8f23e7925a451..a4f20d03002e1 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -55,6 +55,7 @@ + @@ -64,6 +65,10 @@ + + + +