diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
index 4c7e48ccda820..28cc177fa036b 100644
--- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
+++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
@@ -44,6 +44,7 @@ public partial class QuicException : System.Exception
{
public QuicException(string? message) { }
public QuicException(string? message, System.Exception? innerException) { }
+ public QuicException(string? message, System.Exception? innerException, int result) { }
}
public static partial class QuicImplementationProviders
{
diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx
index aeecb31ba2976..38a30c86c019e 100644
--- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx
+++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx
@@ -159,6 +159,9 @@
The remote certificate is invalid because of errors in the certificate chain: {0}
+
+ Connection is not connected.
+
The application protocol list is invalid.
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs
index 5fb0a24c9a905..680069bf7b815 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/QuicExceptionHelpers.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Net.Sockets;
+
namespace System.Net.Quic.Implementations.MsQuic.Internal
{
internal static class QuicExceptionHelpers
@@ -15,7 +17,22 @@ internal static void ThrowIfFailed(uint status, string? message = null, Exceptio
internal static Exception CreateExceptionForHResult(uint status, string? message = null, Exception? innerException = null)
{
- return new QuicException($"{message} Error Code: {MsQuicStatusCodes.GetError(status)}", innerException);
+ return new QuicException($"{message} Error Code: {MsQuicStatusCodes.GetError(status)}", innerException, MapMsQuicStatusToHResult(status));
+ }
+
+ internal static int MapMsQuicStatusToHResult(uint status)
+ {
+ switch (status)
+ {
+ case MsQuicStatusCodes.ConnectionRefused:
+ return (int)SocketError.ConnectionRefused; // 0x8007274D - WSAECONNREFUSED
+ case MsQuicStatusCodes.ConnectionTimeout:
+ return (int)SocketError.TimedOut; // 0x8007274C - WSAETIMEDOUT
+ case MsQuicStatusCodes.HostUnreachable:
+ return (int)SocketError.HostUnreachable;
+ default:
+ return 0;
+ }
}
}
}
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
index b3cbb93ceda58..a02417ff710c3 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
@@ -171,7 +171,12 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Saf
// constructor for outbound connections
public MsQuicConnection(QuicClientConnectionOptions options)
{
- _remoteEndPoint = options.RemoteEndPoint!;
+ if (options.RemoteEndPoint == null)
+ {
+ throw new ArgumentNullException(nameof(options.RemoteEndPoint));
+ }
+
+ _remoteEndPoint = options.RemoteEndPoint;
_configuration = SafeMsQuicConfigurationHandle.Create(options);
_state.RemoteCertificateRequired = true;
if (options.ClientAuthenticationOptions != null)
@@ -523,6 +528,10 @@ internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(Cancellati
internal override QuicStreamProvider OpenUnidirectionalStream()
{
ThrowIfDisposed();
+ if (!Connected)
+ {
+ throw new InvalidOperationException(SR.net_quic_not_connected);
+ }
return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
}
@@ -530,6 +539,10 @@ internal override QuicStreamProvider OpenUnidirectionalStream()
internal override QuicStreamProvider OpenBidirectionalStream()
{
ThrowIfDisposed();
+ if (!Connected)
+ {
+ throw new InvalidOperationException(SR.net_quic_not_connected);
+ }
return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.NONE);
}
@@ -552,7 +565,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
if (_configuration is null)
{
- throw new Exception($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener.");
+ throw new InvalidOperationException($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener.");
}
QUIC_ADDRESS_FAMILY af = _remoteEndPoint.AddressFamily switch
@@ -560,7 +573,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
AddressFamily.Unspecified => QUIC_ADDRESS_FAMILY.UNSPEC,
AddressFamily.InterNetwork => QUIC_ADDRESS_FAMILY.INET,
AddressFamily.InterNetworkV6 => QUIC_ADDRESS_FAMILY.INET6,
- _ => throw new Exception(SR.Format(SR.net_quic_unsupported_address_family, _remoteEndPoint.AddressFamily))
+ _ => throw new ArgumentException(SR.Format(SR.net_quic_unsupported_address_family, _remoteEndPoint.AddressFamily))
};
Debug.Assert(_state.StateGCHandle.IsAllocated);
@@ -592,7 +605,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
}
else
{
- throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.");
+ throw new ArgumentException($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.");
}
// We store TCS to local variable to avoid NRE if callbacks finish fast and set _state.ConnectTcs to null.
@@ -759,7 +772,7 @@ private void Dispose(bool disposing)
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Connection disposing {disposing}");
// If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown.
- if (_state.Handle != null)
+ if (_state.Handle != null && !_state.Handle.IsInvalid && !_state.Handle.IsClosed)
{
// Handle can be null if outbound constructor failed and we are called from finalizer.
Debug.Assert(!Monitor.IsEntered(_state));
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
index b5daea6e0d0ed..f163c8d161b0b 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
@@ -12,7 +12,6 @@
using System.Threading.Channels;
using System.Threading.Tasks;
using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods;
-using System.Security.Authentication;
namespace System.Net.Quic.Implementations.MsQuic
{
@@ -219,15 +218,14 @@ private static unsafe uint NativeCallbackHandler(
IntPtr context,
ref ListenerEvent evt)
{
- if (evt.Type != QUIC_LISTENER_EVENT.NEW_CONNECTION)
- {
- return MsQuicStatusCodes.InternalError;
- }
-
GCHandle gcHandle = GCHandle.FromIntPtr(context);
Debug.Assert(gcHandle.IsAllocated);
Debug.Assert(gcHandle.Target is not null);
var state = (State)gcHandle.Target;
+ if (evt.Type != QUIC_LISTENER_EVENT.NEW_CONNECTION)
+ {
+ return MsQuicStatusCodes.InternalError;
+ }
SafeMsQuicConnectionHandle? connectionHandle = null;
MsQuicConnection? msQuicConnection = null;
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs
index 7336c90831840..4cd9745d82a8f 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicException.cs
@@ -13,5 +13,11 @@ public QuicException(string? message, Exception? innerException)
: base(message, innerException)
{
}
+
+ public QuicException(string? message, Exception? innerException, int result)
+ : base(message, innerException)
+ {
+ HResult = result;
+ }
}
}
diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
index 4132bcef76cb7..9e6b7b504fb16 100644
--- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
+++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
@@ -28,38 +28,31 @@ public MsQuicTests(ITestOutputHelper output) : base(output) { }
[Fact]
public async Task UnidirectionalAndBidirectionalStreamCountsWork()
{
- using QuicListener listener = CreateQuicListener();
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
Assert.Equal(100, serverConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(100, serverConnection.GetRemoteAvailableUnidirectionalStreamCount());
+ serverConnection.Dispose();
+ clientConnection.Dispose();
}
[Fact]
public async Task UnidirectionalAndBidirectionalChangeValues()
{
- using QuicListener listener = CreateQuicListener();
-
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
+ QuicClientConnectionOptions listenerOptions = new QuicClientConnectionOptions()
{
MaxBidirectionalStreams = 10,
MaxUnidirectionalStreams = 20,
- RemoteEndPoint = listener.ListenEndPoint,
ClientAuthenticationOptions = GetSslClientAuthenticationOptions()
};
- using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
-
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listenerOptions);
Assert.Equal(100, clientConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(100, clientConnection.GetRemoteAvailableUnidirectionalStreamCount());
Assert.Equal(10, serverConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(20, serverConnection.GetRemoteAvailableUnidirectionalStreamCount());
+ serverConnection.Dispose();
+ clientConnection.Dispose();
}
[Fact]
@@ -68,21 +61,14 @@ public async Task ConnectWithCertificateChain()
(X509Certificate2 certificate, X509Certificate2Collection chain) = System.Net.Security.Tests.TestHelper.GenerateCertificates("localhost", longChain: true);
X509Certificate2 rootCA = chain[chain.Count - 1];
- var quicOptions = new QuicListenerOptions();
- quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
- quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
- quicOptions.ServerAuthenticationOptions.ServerCertificateContext = SslStreamCertificateContext.Create(certificate, chain);
- quicOptions.ServerAuthenticationOptions.ServerCertificate = null;
-
- using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
-
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = listener.ListenEndPoint,
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ listenerOptions.ServerAuthenticationOptions.ServerCertificateContext = SslStreamCertificateContext.Create(certificate, chain);
+ listenerOptions.ServerAuthenticationOptions.ServerCertificate = null;
- options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+ clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
Assert.Equal(certificate.Subject, cert.Subject);
Assert.Equal(certificate.Issuer, cert.Issuer);
@@ -108,12 +94,11 @@ public async Task ConnectWithCertificateChain()
return ret;
};
- using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
Assert.Equal(certificate, clientConnection.RemoteCertificate);
Assert.Null(serverConnection.RemoteCertificate);
+ serverConnection.Dispose();
+ clientConnection.Dispose();
}
[Fact]
@@ -122,33 +107,27 @@ public async Task CertificateCallbackThrowPropagates()
using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
X509Certificate? receivedCertificate = null;
- var quicOptions = new QuicListenerOptions();
- quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
- quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
-
- using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
-
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = listener.ListenEndPoint,
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
- options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+ clientOptions.RemoteEndPoint = listener.ListenEndPoint;
+ clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
receivedCertificate = cert;
throw new ArithmeticException("foobar");
};
- options.ClientAuthenticationOptions.TargetHost = "foobar1";
-
- QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+ clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1";
+ QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
await Assert.ThrowsAsync(() => clientConnection.ConnectAsync(cts.Token).AsTask());
QuicConnection serverConnection = await serverTask;
- Assert.Equal(quicOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
+ Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
clientConnection.Dispose();
serverConnection.Dispose();
@@ -166,11 +145,11 @@ public async Task ConnectWithCertificateCallback()
string? receivedHostName = null;
X509Certificate? receivedCertificate = null;
- var quicOptions = new QuicListenerOptions();
- quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
- quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
- quicOptions.ServerAuthenticationOptions.ServerCertificate = null;
- quicOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) =>
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ listenerOptions.ServerAuthenticationOptions.ServerCertificate = null;
+ listenerOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) =>
{
receivedHostName = hostName;
if (hostName == "foobar1")
@@ -185,52 +164,36 @@ public async Task ConnectWithCertificateCallback()
return null;
};
- using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
-
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = listener.ListenEndPoint,
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
-
- options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+ clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1";
+ clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
receivedCertificate = cert;
return true;
};
- options.ClientAuthenticationOptions.TargetHost = "foobar1";
-
- QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
-
- Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
- await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
- QuicConnection serverConnection = serverTask.Result;
-
- Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);
+ Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName);
Assert.Equal(c1, receivedCertificate);
clientConnection.Dispose();
serverConnection.Dispose();
// This should fail when callback return null.
- options.ClientAuthenticationOptions.TargetHost = "foobar3";
- clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+ clientOptions.ClientAuthenticationOptions.TargetHost = "foobar3";
+ clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
Task clientTask = clientConnection.ConnectAsync(cts.Token).AsTask();
await Assert.ThrowsAsync(() => clientTask);
- Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+ Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName);
clientConnection.Dispose();
// Do this last to make sure Listener is still functional.
- options.ClientAuthenticationOptions.TargetHost = "foobar2";
+ clientOptions.ClientAuthenticationOptions.TargetHost = "foobar2";
expectedCertificate = c2;
- clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
- serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
- await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
- serverConnection = serverTask.Result;
-
- Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+ (clientConnection, serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);
+ Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName);
Assert.Equal(c2, receivedCertificate);
clientConnection.Dispose();
serverConnection.Dispose();
@@ -245,18 +208,13 @@ public async Task ConnectWithCertificateForDifferentName_Throws()
quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
quicOptions.ServerAuthenticationOptions.ServerCertificate = certificate;
-
using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = listener.ListenEndPoint,
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
-
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+ clientOptions.RemoteEndPoint = listener.ListenEndPoint;
// Use different target host on purpose to get RemoteCertificateNameMismatch ssl error.
- options.ClientAuthenticationOptions.TargetHost = "loopback";
- options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ clientOptions.ClientAuthenticationOptions.TargetHost = "loopback";
+ clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
Assert.Equal(certificate.Subject, cert.Subject);
Assert.Equal(certificate.Issuer, cert.Issuer);
@@ -264,7 +222,7 @@ public async Task ConnectWithCertificateForDifferentName_Throws()
return SslPolicyErrors.None == errors;
};
- using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+ using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
@@ -281,20 +239,13 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str
var ipAddress = IPAddress.Parse(ipString);
(X509Certificate2 certificate, _) = System.Net.Security.Tests.TestHelper.GenerateCertificates(expectsError ? "badhost" : "localhost");
- var quicOptions = new QuicListenerOptions();
- quicOptions.ListenEndPoint = new IPEndPoint(ipAddress, 0);
- quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
- quicOptions.ServerAuthenticationOptions.ServerCertificate = certificate;
-
- using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
-
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = new IPEndPoint(ipAddress, listener.ListenEndPoint.Port),
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(ipAddress, 0);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ listenerOptions.ServerAuthenticationOptions.ServerCertificate = certificate;
- options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+ clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
Assert.Equal(certificate.Subject, cert.Subject);
Assert.Equal(certificate.Issuer, cert.Issuer);
@@ -302,11 +253,7 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str
return true;
};
- using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
- ValueTask clientTask = clientConnection.ConnectAsync();
-
- using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
- await clientTask;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
}
[Fact]
@@ -315,11 +262,11 @@ public async Task ConnectWithClientCertificate()
{
bool clientCertificateOK = false;
- var serverOptions = new QuicListenerOptions();
- serverOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
- serverOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
- serverOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
- serverOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
+ listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
_output.WriteLine("client certificate {0}", cert);
Assert.NotNull(cert);
@@ -328,19 +275,11 @@ public async Task ConnectWithClientCertificate()
clientCertificateOK = true;
return true;
};
- using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, serverOptions);
- QuicClientConnectionOptions clientOptions = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = listener.ListenEndPoint,
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
+ using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
-
- using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);
// Verify functionality of the connections.
await PingPong(clientConnection, serverConnection);
@@ -349,17 +288,16 @@ public async Task ConnectWithClientCertificate()
Assert.Equal(ClientCertificate, serverConnection.RemoteCertificate);
await serverConnection.CloseAsync(0);
+ clientConnection.Dispose();
+ serverConnection.Dispose();
}
[Fact]
public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
{
- using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1);
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
- listener.Dispose();
+ QuicListenerOptions listenerOptions = CreateQuicListenerOptions();
+ listenerOptions.MaxUnidirectionalStreams = 1;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions);
// No stream opened yet, should return immediately.
Assert.True(clientConnection.WaitForAvailableUnidirectionalStreamsAsync().IsCompletedSuccessfully);
@@ -375,17 +313,16 @@ public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
newStream.Dispose();
await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
+ clientConnection.Dispose();
+ serverConnection.Dispose();
}
[Fact]
public async Task WaitForAvailableBidirectionStreamsAsyncWorks()
{
- using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1);
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
-
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ QuicListenerOptions listenerOptions = CreateQuicListenerOptions();
+ listenerOptions.MaxBidirectionalStreams = 1;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions);
// No stream opened yet, should return immediately.
Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully);
@@ -401,31 +338,23 @@ public async Task WaitForAvailableBidirectionStreamsAsyncWorks()
QuicStream newStream = await serverConnection.AcceptStreamAsync();
newStream.Dispose();
await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
+ clientConnection.Dispose();
+ serverConnection.Dispose();
}
[Fact]
[OuterLoop("May take several seconds")]
public async Task SetListenerTimeoutWorksWithSmallTimeout()
{
- var quicOptions = new QuicListenerOptions();
- quicOptions.IdleTimeout = TimeSpan.FromSeconds(1);
- quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
- quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
-
- using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
-
- QuicClientConnectionOptions options = new QuicClientConnectionOptions()
- {
- RemoteEndPoint = listener.ListenEndPoint,
- ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
- };
-
- using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.IdleTimeout = TimeSpan.FromSeconds(1);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions);
await Assert.ThrowsAsync(async () => await serverConnection.AcceptStreamAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(100)));
+ serverConnection.Dispose();
+ clientConnection.Dispose();
}
[Theory]
@@ -522,13 +451,7 @@ public enum WriteType
[Fact]
public async Task CallDifferentWriteMethodsWorks()
{
- using QuicListener listener = CreateQuicListener();
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
-
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
-
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
ReadOnlyMemory helloWorld = Encoding.ASCII.GetBytes("Hello world!");
ReadOnlySequence ros = CreateReadOnlySequenceFromBytes(helloWorld.ToArray());
@@ -548,6 +471,8 @@ public async Task CallDifferentWriteMethodsWorks()
res = await serverStream.ReadAsync(memory);
Assert.Equal(24, res);
+ clientConnection.Dispose();
+ serverConnection.Dispose();
}
[Fact]
@@ -704,17 +629,14 @@ public async Task ManagedAVE_MinimalFailingTest()
{
async Task GetStreamIdWithoutStartWorks()
{
- using QuicListener listener = CreateQuicListener();
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
-
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
// TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer
+ clientConnection.Dispose();
+ serverConnection.Dispose();
}
await GetStreamIdWithoutStartWorks().WaitAsync(TimeSpan.FromSeconds(15));
@@ -727,12 +649,7 @@ public async Task DisposingConnection_OK()
{
async Task GetStreamIdWithoutStartWorks()
{
- using QuicListener listener = CreateQuicListener();
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
-
- Task serverTask = listener.AcceptConnectionAsync().AsTask();
- await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
- using QuicConnection serverConnection = serverTask.Result;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
@@ -740,7 +657,6 @@ async Task GetStreamIdWithoutStartWorks()
// Dispose all connections before the streams;
clientConnection.Dispose();
serverConnection.Dispose();
- listener.Dispose();
}
await GetStreamIdWithoutStartWorks();
@@ -755,13 +671,7 @@ public async Task Read_ConnectionAbortedByPeer_Throws()
await Task.Run(async () =>
{
- using QuicListener listener = CreateQuicListener();
- ValueTask serverConnectionTask = listener.AcceptConnectionAsync();
-
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
- await clientConnection.ConnectAsync();
-
- using QuicConnection serverConnection = await serverConnectionTask;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
await using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
await clientStream.WriteAsync(new byte[1]);
@@ -782,13 +692,7 @@ public async Task Read_ConnectionAbortedByUser_Throws()
{
await Task.Run(async () =>
{
- using QuicListener listener = CreateQuicListener();
- ValueTask serverConnectionTask = listener.AcceptConnectionAsync();
-
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
- await clientConnection.ConnectAsync();
-
- using QuicConnection serverConnection = await serverConnectionTask;
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
await using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
await clientStream.WriteAsync(new byte[1]);
diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs
index 2a03bca118163..7bf38b0812980 100644
--- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs
+++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs
@@ -12,6 +12,7 @@
using Xunit;
using Xunit.Abstractions;
using System.Diagnostics.Tracing;
+using System.Net.Sockets;
namespace System.Net.Quic.Tests
{
@@ -66,11 +67,24 @@ public SslClientAuthenticationOptions GetSslClientAuthenticationOptions()
};
}
+ public QuicClientConnectionOptions CreateQuicClientOptions()
+ {
+ return new QuicClientConnectionOptions()
+ {
+ ClientAuthenticationOptions = GetSslClientAuthenticationOptions()
+ };
+ }
+
internal QuicConnection CreateQuicConnection(IPEndPoint endpoint)
{
return new QuicConnection(ImplementationProvider, endpoint, GetSslClientAuthenticationOptions());
}
+ internal QuicConnection CreateQuicConnection(QuicClientConnectionOptions clientOptions)
+ {
+ return new QuicConnection(ImplementationProvider, clientOptions);
+ }
+
internal QuicListenerOptions CreateQuicListenerOptions()
{
return new QuicListenerOptions()
@@ -99,14 +113,73 @@ internal QuicListener CreateQuicListener(IPEndPoint endpoint)
return CreateQuicListener(options);
}
- internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection()
+ private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options);
+
+ internal Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicListener listener) => CreateConnectedQuicConnection(null, listener);
+ internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions, QuicListenerOptions listenerOptions)
{
- using QuicListener listener = CreateQuicListener();
- QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
+ using (QuicListener listener = CreateQuicListener(listenerOptions))
+ {
+ clientOptions ??= new QuicClientConnectionOptions()
+ {
+ ClientAuthenticationOptions = GetSslClientAuthenticationOptions()
+ };
+ clientOptions.RemoteEndPoint = listener.ListenEndPoint;
+ return await CreateConnectedQuicConnection(clientOptions, listener);
+ }
+ }
+
+ internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions = null, QuicListener? listener = null)
+ {
+ int retry = 3;
+ int delay = 25;
+ bool disposeListener = false;
+
+ if (listener == null)
+ {
+ listener = CreateQuicListener();
+ disposeListener = true;
+ }
+
+ clientOptions ??= CreateQuicClientOptions();
+ if (clientOptions.RemoteEndPoint == null)
+ {
+ clientOptions.RemoteEndPoint = listener.ListenEndPoint;
+ }
- ValueTask clientTask = clientConnection.ConnectAsync();
+ QuicConnection clientConnection = null;
ValueTask serverTask = listener.AcceptConnectionAsync();
- await new Task[] { clientTask.AsTask(), serverTask.AsTask() }.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
+ while (retry > 0)
+ {
+ clientConnection = CreateQuicConnection(clientOptions);
+ retry--;
+ try
+ {
+ await clientConnection.ConnectAsync().ConfigureAwait(false);
+ break;
+ }
+ catch (QuicException ex) when (ex.HResult == (int)SocketError.ConnectionRefused)
+ {
+ _output.WriteLine($"ConnectAsync to {clientConnection.RemoteEndPoint} failed with {ex.Message}");
+ await Task.Delay(delay);
+ delay *= 2;
+
+ if (retry == 0)
+ {
+ Debug.Fail($"ConnectAsync to {clientConnection.RemoteEndPoint} failed with {ex.Message}");
+ }
+ }
+ }
+
+ QuicConnection serverConnection = await serverTask.ConfigureAwait(false);
+ if (disposeListener)
+ {
+ listener.Dispose();
+ }
+
+ Assert.True(serverConnection.Connected);
+ Assert.True(clientConnection.Connected);
+
return (clientConnection, serverTask.Result);
}
@@ -140,8 +213,6 @@ internal async Task PingPong(QuicConnection client, QuicConnection server)
await t;
}
- private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options);
-
internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds, QuicListenerOptions listenerOptions = null)
{
const long ClientCloseErrorCode = 11111;
@@ -154,37 +225,28 @@ internal async Task RunClientServer(Func clientFunction, F
for (int i = 0; i < iterations; ++i)
{
- await new[]
+ (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
+ using (clientConnection)
+ using (serverConnection)
{
- Task.Run(async () =>
- {
- using QuicConnection serverConnection = await listener.AcceptConnectionAsync().AsTask().WaitAsync(millisecondsTimeout);
- await serverFunction(serverConnection);
-
- serverFinished.Release();
- await clientFinished.WaitAsync();
- await serverConnection.CloseAsync(ServerCloseErrorCode);
- }),
- Task.Run(async () =>
+ await new[]
{
- using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
- try
+ Task.Run(async () =>
{
- await clientConnection.ConnectAsync();
- }
- catch (Exception ex)
+ await serverFunction(serverConnection);
+ serverFinished.Release();
+ await clientFinished.WaitAsync();
+ }),
+ Task.Run(async () =>
{
- _output?.WriteLine("Failed to connect {0} with {1}", listener.ListenEndPoint, ex.Message);
- throw;
- }
-
- await clientFunction(clientConnection);
-
- clientFinished.Release();
- await serverFinished.WaitAsync();
- await clientConnection.CloseAsync(ClientCloseErrorCode);
- })
- }.WhenAllOrAnyFailed(millisecondsTimeout);
+ await clientFunction(clientConnection);
+ clientFinished.Release();
+ await serverFinished.WaitAsync();
+ })
+ }.WhenAllOrAnyFailed(millisecondsTimeout);
+ await serverConnection.CloseAsync(ServerCloseErrorCode);
+ await clientConnection.CloseAsync(ClientCloseErrorCode);
+ }
}
}