diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 4c7e48ccda820..28cc177fa036b 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -44,6 +44,7 @@ public partial class QuicException : System.Exception { public QuicException(string? message) { } public QuicException(string? message, System.Exception? innerException) { } + public QuicException(string? message, System.Exception? innerException, int result) { } } public static partial class QuicImplementationProviders { diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx index aeecb31ba2976..38a30c86c019e 100644 --- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx @@ -159,6 +159,9 @@ The remote certificate is invalid because of errors in the certificate chain: {0} + + Connection is not connected. + The application protocol list is invalid. diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs index 5fb0a24c9a905..680069bf7b815 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Net.Sockets; + namespace System.Net.Quic.Implementations.MsQuic.Internal { internal static class QuicExceptionHelpers @@ -15,7 +17,22 @@ internal static void ThrowIfFailed(uint status, string? message = null, Exceptio internal static Exception CreateExceptionForHResult(uint status, string? message = null, Exception? innerException = null) { - return new QuicException($"{message} Error Code: {MsQuicStatusCodes.GetError(status)}", innerException); + return new QuicException($"{message} Error Code: {MsQuicStatusCodes.GetError(status)}", innerException, MapMsQuicStatusToHResult(status)); + } + + internal static int MapMsQuicStatusToHResult(uint status) + { + switch (status) + { + case MsQuicStatusCodes.ConnectionRefused: + return (int)SocketError.ConnectionRefused; // 0x8007274D - WSAECONNREFUSED + case MsQuicStatusCodes.ConnectionTimeout: + return (int)SocketError.TimedOut; // 0x8007274C - WSAETIMEDOUT + case MsQuicStatusCodes.HostUnreachable: + return (int)SocketError.HostUnreachable; + default: + return 0; + } } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index b3cbb93ceda58..a02417ff710c3 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -171,7 +171,12 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Saf // constructor for outbound connections public MsQuicConnection(QuicClientConnectionOptions options) { - _remoteEndPoint = options.RemoteEndPoint!; + if (options.RemoteEndPoint == null) + { + throw new ArgumentNullException(nameof(options.RemoteEndPoint)); + } + + _remoteEndPoint = options.RemoteEndPoint; _configuration = SafeMsQuicConfigurationHandle.Create(options); _state.RemoteCertificateRequired = true; if (options.ClientAuthenticationOptions != null) @@ -523,6 +528,10 @@ internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(Cancellati internal override QuicStreamProvider OpenUnidirectionalStream() { ThrowIfDisposed(); + if (!Connected) + { + throw new InvalidOperationException(SR.net_quic_not_connected); + } return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL); } @@ -530,6 +539,10 @@ internal override QuicStreamProvider OpenUnidirectionalStream() internal override QuicStreamProvider OpenBidirectionalStream() { ThrowIfDisposed(); + if (!Connected) + { + throw new InvalidOperationException(SR.net_quic_not_connected); + } return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.NONE); } @@ -552,7 +565,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d if (_configuration is null) { - throw new Exception($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener."); + throw new InvalidOperationException($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener."); } QUIC_ADDRESS_FAMILY af = _remoteEndPoint.AddressFamily switch @@ -560,7 +573,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d AddressFamily.Unspecified => QUIC_ADDRESS_FAMILY.UNSPEC, AddressFamily.InterNetwork => QUIC_ADDRESS_FAMILY.INET, AddressFamily.InterNetworkV6 => QUIC_ADDRESS_FAMILY.INET6, - _ => throw new Exception(SR.Format(SR.net_quic_unsupported_address_family, _remoteEndPoint.AddressFamily)) + _ => throw new ArgumentException(SR.Format(SR.net_quic_unsupported_address_family, _remoteEndPoint.AddressFamily)) }; Debug.Assert(_state.StateGCHandle.IsAllocated); @@ -592,7 +605,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d } else { - throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'."); + throw new ArgumentException($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'."); } // We store TCS to local variable to avoid NRE if callbacks finish fast and set _state.ConnectTcs to null. @@ -759,7 +772,7 @@ private void Dispose(bool disposing) if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Connection disposing {disposing}"); // If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown. - if (_state.Handle != null) + if (_state.Handle != null && !_state.Handle.IsInvalid && !_state.Handle.IsClosed) { // Handle can be null if outbound constructor failed and we are called from finalizer. Debug.Assert(!Monitor.IsEntered(_state)); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs index b5daea6e0d0ed..f163c8d161b0b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs @@ -12,7 +12,6 @@ using System.Threading.Channels; using System.Threading.Tasks; using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; -using System.Security.Authentication; namespace System.Net.Quic.Implementations.MsQuic { @@ -219,15 +218,14 @@ private static unsafe uint NativeCallbackHandler( IntPtr context, ref ListenerEvent evt) { - if (evt.Type != QUIC_LISTENER_EVENT.NEW_CONNECTION) - { - return MsQuicStatusCodes.InternalError; - } - GCHandle gcHandle = GCHandle.FromIntPtr(context); Debug.Assert(gcHandle.IsAllocated); Debug.Assert(gcHandle.Target is not null); var state = (State)gcHandle.Target; + if (evt.Type != QUIC_LISTENER_EVENT.NEW_CONNECTION) + { + return MsQuicStatusCodes.InternalError; + } SafeMsQuicConnectionHandle? connectionHandle = null; MsQuicConnection? msQuicConnection = null; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs index 7336c90831840..4cd9745d82a8f 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs @@ -13,5 +13,11 @@ public QuicException(string? message, Exception? innerException) : base(message, innerException) { } + + public QuicException(string? message, Exception? innerException, int result) + : base(message, innerException) + { + HResult = result; + } } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 4132bcef76cb7..9e6b7b504fb16 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -28,38 +28,31 @@ public MsQuicTests(ITestOutputHelper output) : base(output) { } [Fact] public async Task UnidirectionalAndBidirectionalStreamCountsWork() { - using QuicListener listener = CreateQuicListener(); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); Assert.Equal(100, serverConnection.GetRemoteAvailableBidirectionalStreamCount()); Assert.Equal(100, serverConnection.GetRemoteAvailableUnidirectionalStreamCount()); + serverConnection.Dispose(); + clientConnection.Dispose(); } [Fact] public async Task UnidirectionalAndBidirectionalChangeValues() { - using QuicListener listener = CreateQuicListener(); - - QuicClientConnectionOptions options = new QuicClientConnectionOptions() + QuicClientConnectionOptions listenerOptions = new QuicClientConnectionOptions() { MaxBidirectionalStreams = 10, MaxUnidirectionalStreams = 20, - RemoteEndPoint = listener.ListenEndPoint, ClientAuthenticationOptions = GetSslClientAuthenticationOptions() }; - using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; - + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listenerOptions); Assert.Equal(100, clientConnection.GetRemoteAvailableBidirectionalStreamCount()); Assert.Equal(100, clientConnection.GetRemoteAvailableUnidirectionalStreamCount()); Assert.Equal(10, serverConnection.GetRemoteAvailableBidirectionalStreamCount()); Assert.Equal(20, serverConnection.GetRemoteAvailableUnidirectionalStreamCount()); + serverConnection.Dispose(); + clientConnection.Dispose(); } [Fact] @@ -68,21 +61,14 @@ public async Task ConnectWithCertificateChain() (X509Certificate2 certificate, X509Certificate2Collection chain) = System.Net.Security.Tests.TestHelper.GenerateCertificates("localhost", longChain: true); X509Certificate2 rootCA = chain[chain.Count - 1]; - var quicOptions = new QuicListenerOptions(); - quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); - quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); - quicOptions.ServerAuthenticationOptions.ServerCertificateContext = SslStreamCertificateContext.Create(certificate, chain); - quicOptions.ServerAuthenticationOptions.ServerCertificate = null; - - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); - - QuicClientConnectionOptions options = new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.ListenEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; + var listenerOptions = new QuicListenerOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + listenerOptions.ServerAuthenticationOptions.ServerCertificateContext = SslStreamCertificateContext.Create(certificate, chain); + listenerOptions.ServerAuthenticationOptions.ServerCertificate = null; - options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); + clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { Assert.Equal(certificate.Subject, cert.Subject); Assert.Equal(certificate.Issuer, cert.Issuer); @@ -108,12 +94,11 @@ public async Task ConnectWithCertificateChain() return ret; }; - using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions); Assert.Equal(certificate, clientConnection.RemoteCertificate); Assert.Null(serverConnection.RemoteCertificate); + serverConnection.Dispose(); + clientConnection.Dispose(); } [Fact] @@ -122,33 +107,27 @@ public async Task CertificateCallbackThrowPropagates() using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout); X509Certificate? receivedCertificate = null; - var quicOptions = new QuicListenerOptions(); - quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); - quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); - - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); - - QuicClientConnectionOptions options = new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.ListenEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; + var listenerOptions = new QuicListenerOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions); - options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); + clientOptions.RemoteEndPoint = listener.ListenEndPoint; + clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { receivedCertificate = cert; throw new ArithmeticException("foobar"); }; - options.ClientAuthenticationOptions.TargetHost = "foobar1"; - - QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1"; + QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions); Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); await Assert.ThrowsAsync(() => clientConnection.ConnectAsync(cts.Token).AsTask()); QuicConnection serverConnection = await serverTask; - Assert.Equal(quicOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate); + Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate); clientConnection.Dispose(); serverConnection.Dispose(); @@ -166,11 +145,11 @@ public async Task ConnectWithCertificateCallback() string? receivedHostName = null; X509Certificate? receivedCertificate = null; - var quicOptions = new QuicListenerOptions(); - quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); - quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); - quicOptions.ServerAuthenticationOptions.ServerCertificate = null; - quicOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) => + var listenerOptions = new QuicListenerOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + listenerOptions.ServerAuthenticationOptions.ServerCertificate = null; + listenerOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) => { receivedHostName = hostName; if (hostName == "foobar1") @@ -185,52 +164,36 @@ public async Task ConnectWithCertificateCallback() return null; }; - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); - - QuicClientConnectionOptions options = new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.ListenEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; - - options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions); + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); + clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1"; + clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { receivedCertificate = cert; return true; }; - options.ClientAuthenticationOptions.TargetHost = "foobar1"; - - QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); - - Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); - await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); - QuicConnection serverConnection = serverTask.Result; - - Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName); + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); + Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName); Assert.Equal(c1, receivedCertificate); clientConnection.Dispose(); serverConnection.Dispose(); // This should fail when callback return null. - options.ClientAuthenticationOptions.TargetHost = "foobar3"; - clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + clientOptions.ClientAuthenticationOptions.TargetHost = "foobar3"; + clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions); Task clientTask = clientConnection.ConnectAsync(cts.Token).AsTask(); await Assert.ThrowsAsync(() => clientTask); - Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName); + Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName); clientConnection.Dispose(); // Do this last to make sure Listener is still functional. - options.ClientAuthenticationOptions.TargetHost = "foobar2"; + clientOptions.ClientAuthenticationOptions.TargetHost = "foobar2"; expectedCertificate = c2; - clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); - serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); - await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); - serverConnection = serverTask.Result; - - Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName); + (clientConnection, serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); + Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName); Assert.Equal(c2, receivedCertificate); clientConnection.Dispose(); serverConnection.Dispose(); @@ -245,18 +208,13 @@ public async Task ConnectWithCertificateForDifferentName_Throws() quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); quicOptions.ServerAuthenticationOptions.ServerCertificate = certificate; - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); - QuicClientConnectionOptions options = new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.ListenEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; - + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); + clientOptions.RemoteEndPoint = listener.ListenEndPoint; // Use different target host on purpose to get RemoteCertificateNameMismatch ssl error. - options.ClientAuthenticationOptions.TargetHost = "loopback"; - options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + clientOptions.ClientAuthenticationOptions.TargetHost = "loopback"; + clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { Assert.Equal(certificate.Subject, cert.Subject); Assert.Equal(certificate.Issuer, cert.Issuer); @@ -264,7 +222,7 @@ public async Task ConnectWithCertificateForDifferentName_Throws() return SslPolicyErrors.None == errors; }; - using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions); ValueTask clientTask = clientConnection.ConnectAsync(); using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); @@ -281,20 +239,13 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str var ipAddress = IPAddress.Parse(ipString); (X509Certificate2 certificate, _) = System.Net.Security.Tests.TestHelper.GenerateCertificates(expectsError ? "badhost" : "localhost"); - var quicOptions = new QuicListenerOptions(); - quicOptions.ListenEndPoint = new IPEndPoint(ipAddress, 0); - quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); - quicOptions.ServerAuthenticationOptions.ServerCertificate = certificate; - - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); - - QuicClientConnectionOptions options = new QuicClientConnectionOptions() - { - RemoteEndPoint = new IPEndPoint(ipAddress, listener.ListenEndPoint.Port), - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; + var listenerOptions = new QuicListenerOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(ipAddress, 0); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + listenerOptions.ServerAuthenticationOptions.ServerCertificate = certificate; - options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); + clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { Assert.Equal(certificate.Subject, cert.Subject); Assert.Equal(certificate.Issuer, cert.Issuer); @@ -302,11 +253,7 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str return true; }; - using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); - ValueTask clientTask = clientConnection.ConnectAsync(); - - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - await clientTask; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions); } [Fact] @@ -315,11 +262,11 @@ public async Task ConnectWithClientCertificate() { bool clientCertificateOK = false; - var serverOptions = new QuicListenerOptions(); - serverOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); - serverOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); - serverOptions.ServerAuthenticationOptions.ClientCertificateRequired = true; - serverOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + var listenerOptions = new QuicListenerOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true; + listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { _output.WriteLine("client certificate {0}", cert); Assert.NotNull(cert); @@ -328,19 +275,11 @@ public async Task ConnectWithClientCertificate() clientCertificateOK = true; return true; }; - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, serverOptions); - QuicClientConnectionOptions clientOptions = new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.ListenEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions); + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate }; - - using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions); - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); // Verify functionality of the connections. await PingPong(clientConnection, serverConnection); @@ -349,17 +288,16 @@ public async Task ConnectWithClientCertificate() Assert.Equal(ClientCertificate, serverConnection.RemoteCertificate); await serverConnection.CloseAsync(0); + clientConnection.Dispose(); + serverConnection.Dispose(); } [Fact] public async Task WaitForAvailableUnidirectionStreamsAsyncWorks() { - using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; - listener.Dispose(); + QuicListenerOptions listenerOptions = CreateQuicListenerOptions(); + listenerOptions.MaxUnidirectionalStreams = 1; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions); // No stream opened yet, should return immediately. Assert.True(clientConnection.WaitForAvailableUnidirectionalStreamsAsync().IsCompletedSuccessfully); @@ -375,17 +313,16 @@ public async Task WaitForAvailableUnidirectionStreamsAsyncWorks() newStream.Dispose(); await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + clientConnection.Dispose(); + serverConnection.Dispose(); } [Fact] public async Task WaitForAvailableBidirectionStreamsAsyncWorks() { - using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + QuicListenerOptions listenerOptions = CreateQuicListenerOptions(); + listenerOptions.MaxBidirectionalStreams = 1; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions); // No stream opened yet, should return immediately. Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully); @@ -401,31 +338,23 @@ public async Task WaitForAvailableBidirectionStreamsAsyncWorks() QuicStream newStream = await serverConnection.AcceptStreamAsync(); newStream.Dispose(); await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + clientConnection.Dispose(); + serverConnection.Dispose(); } [Fact] [OuterLoop("May take several seconds")] public async Task SetListenerTimeoutWorksWithSmallTimeout() { - var quicOptions = new QuicListenerOptions(); - quicOptions.IdleTimeout = TimeSpan.FromSeconds(1); - quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); - quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); - - using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); - - QuicClientConnectionOptions options = new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.ListenEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), - }; - - using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + var listenerOptions = new QuicListenerOptions(); + listenerOptions.IdleTimeout = TimeSpan.FromSeconds(1); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions); await Assert.ThrowsAsync(async () => await serverConnection.AcceptStreamAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(100))); + serverConnection.Dispose(); + clientConnection.Dispose(); } [Theory] @@ -522,13 +451,7 @@ public enum WriteType [Fact] public async Task CallDifferentWriteMethodsWorks() { - using QuicListener listener = CreateQuicListener(); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; - + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); ReadOnlyMemory helloWorld = Encoding.ASCII.GetBytes("Hello world!"); ReadOnlySequence ros = CreateReadOnlySequenceFromBytes(helloWorld.ToArray()); @@ -548,6 +471,8 @@ public async Task CallDifferentWriteMethodsWorks() res = await serverStream.ReadAsync(memory); Assert.Equal(24, res); + clientConnection.Dispose(); + serverConnection.Dispose(); } [Fact] @@ -704,17 +629,14 @@ public async Task ManagedAVE_MinimalFailingTest() { async Task GetStreamIdWithoutStartWorks() { - using QuicListener listener = CreateQuicListener(); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); Assert.Equal(0, clientStream.StreamId); // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer + clientConnection.Dispose(); + serverConnection.Dispose(); } await GetStreamIdWithoutStartWorks().WaitAsync(TimeSpan.FromSeconds(15)); @@ -727,12 +649,7 @@ public async Task DisposingConnection_OK() { async Task GetStreamIdWithoutStartWorks() { - using QuicListener listener = CreateQuicListener(); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - - Task serverTask = listener.AcceptConnectionAsync().AsTask(); - await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds); - using QuicConnection serverConnection = serverTask.Result; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); Assert.Equal(0, clientStream.StreamId); @@ -740,7 +657,6 @@ async Task GetStreamIdWithoutStartWorks() // Dispose all connections before the streams; clientConnection.Dispose(); serverConnection.Dispose(); - listener.Dispose(); } await GetStreamIdWithoutStartWorks(); @@ -755,13 +671,7 @@ public async Task Read_ConnectionAbortedByPeer_Throws() await Task.Run(async () => { - using QuicListener listener = CreateQuicListener(); - ValueTask serverConnectionTask = listener.AcceptConnectionAsync(); - - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); - - using QuicConnection serverConnection = await serverConnectionTask; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); await clientStream.WriteAsync(new byte[1]); @@ -782,13 +692,7 @@ public async Task Read_ConnectionAbortedByUser_Throws() { await Task.Run(async () => { - using QuicListener listener = CreateQuicListener(); - ValueTask serverConnectionTask = listener.AcceptConnectionAsync(); - - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); - - using QuicConnection serverConnection = await serverConnectionTask; + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); await clientStream.WriteAsync(new byte[1]); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 2a03bca118163..7bf38b0812980 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -12,6 +12,7 @@ using Xunit; using Xunit.Abstractions; using System.Diagnostics.Tracing; +using System.Net.Sockets; namespace System.Net.Quic.Tests { @@ -66,11 +67,24 @@ public SslClientAuthenticationOptions GetSslClientAuthenticationOptions() }; } + public QuicClientConnectionOptions CreateQuicClientOptions() + { + return new QuicClientConnectionOptions() + { + ClientAuthenticationOptions = GetSslClientAuthenticationOptions() + }; + } + internal QuicConnection CreateQuicConnection(IPEndPoint endpoint) { return new QuicConnection(ImplementationProvider, endpoint, GetSslClientAuthenticationOptions()); } + internal QuicConnection CreateQuicConnection(QuicClientConnectionOptions clientOptions) + { + return new QuicConnection(ImplementationProvider, clientOptions); + } + internal QuicListenerOptions CreateQuicListenerOptions() { return new QuicListenerOptions() @@ -99,14 +113,73 @@ internal QuicListener CreateQuicListener(IPEndPoint endpoint) return CreateQuicListener(options); } - internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection() + private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options); + + internal Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicListener listener) => CreateConnectedQuicConnection(null, listener); + internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions, QuicListenerOptions listenerOptions) { - using QuicListener listener = CreateQuicListener(); - QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); + using (QuicListener listener = CreateQuicListener(listenerOptions)) + { + clientOptions ??= new QuicClientConnectionOptions() + { + ClientAuthenticationOptions = GetSslClientAuthenticationOptions() + }; + clientOptions.RemoteEndPoint = listener.ListenEndPoint; + return await CreateConnectedQuicConnection(clientOptions, listener); + } + } + + internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions = null, QuicListener? listener = null) + { + int retry = 3; + int delay = 25; + bool disposeListener = false; + + if (listener == null) + { + listener = CreateQuicListener(); + disposeListener = true; + } + + clientOptions ??= CreateQuicClientOptions(); + if (clientOptions.RemoteEndPoint == null) + { + clientOptions.RemoteEndPoint = listener.ListenEndPoint; + } - ValueTask clientTask = clientConnection.ConnectAsync(); + QuicConnection clientConnection = null; ValueTask serverTask = listener.AcceptConnectionAsync(); - await new Task[] { clientTask.AsTask(), serverTask.AsTask() }.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); + while (retry > 0) + { + clientConnection = CreateQuicConnection(clientOptions); + retry--; + try + { + await clientConnection.ConnectAsync().ConfigureAwait(false); + break; + } + catch (QuicException ex) when (ex.HResult == (int)SocketError.ConnectionRefused) + { + _output.WriteLine($"ConnectAsync to {clientConnection.RemoteEndPoint} failed with {ex.Message}"); + await Task.Delay(delay); + delay *= 2; + + if (retry == 0) + { + Debug.Fail($"ConnectAsync to {clientConnection.RemoteEndPoint} failed with {ex.Message}"); + } + } + } + + QuicConnection serverConnection = await serverTask.ConfigureAwait(false); + if (disposeListener) + { + listener.Dispose(); + } + + Assert.True(serverConnection.Connected); + Assert.True(clientConnection.Connected); + return (clientConnection, serverTask.Result); } @@ -140,8 +213,6 @@ internal async Task PingPong(QuicConnection client, QuicConnection server) await t; } - private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options); - internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds, QuicListenerOptions listenerOptions = null) { const long ClientCloseErrorCode = 11111; @@ -154,37 +225,28 @@ internal async Task RunClientServer(Func clientFunction, F for (int i = 0; i < iterations; ++i) { - await new[] + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener); + using (clientConnection) + using (serverConnection) { - Task.Run(async () => - { - using QuicConnection serverConnection = await listener.AcceptConnectionAsync().AsTask().WaitAsync(millisecondsTimeout); - await serverFunction(serverConnection); - - serverFinished.Release(); - await clientFinished.WaitAsync(); - await serverConnection.CloseAsync(ServerCloseErrorCode); - }), - Task.Run(async () => + await new[] { - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - try + Task.Run(async () => { - await clientConnection.ConnectAsync(); - } - catch (Exception ex) + await serverFunction(serverConnection); + serverFinished.Release(); + await clientFinished.WaitAsync(); + }), + Task.Run(async () => { - _output?.WriteLine("Failed to connect {0} with {1}", listener.ListenEndPoint, ex.Message); - throw; - } - - await clientFunction(clientConnection); - - clientFinished.Release(); - await serverFinished.WaitAsync(); - await clientConnection.CloseAsync(ClientCloseErrorCode); - }) - }.WhenAllOrAnyFailed(millisecondsTimeout); + await clientFunction(clientConnection); + clientFinished.Release(); + await serverFinished.WaitAsync(); + }) + }.WhenAllOrAnyFailed(millisecondsTimeout); + await serverConnection.CloseAsync(ServerCloseErrorCode); + await clientConnection.CloseAsync(ClientCloseErrorCode); + } } }