Skip to content

Commit

Permalink
Improve reliability of some SafeHandle creations (#72341)
Browse files Browse the repository at this point in the history
* Improve reliability of some SafeHandle creations

Use the new Marshal.InitHandle to support creating the SafeHandle instance before the native call and then storing the handle after, as is done implicitly as part of interop calls that return SafeHandles.  Also replace some existing such SetHandle methods with Marshal.InitHandle.

* Update WaitSubsystem.Unix.cs

Add InteropServices namespace

Co-authored-by: Aaron Robinson <arobins@microsoft.com>
  • Loading branch information
stephentoub and AaronRobinsonMSFT authored Jul 18, 2022
1 parent 1edd890 commit aafa910
Show file tree
Hide file tree
Showing 18 changed files with 95 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ internal static SafeDsaHandle DuplicateHandle(IntPtr handle)
{
Debug.Assert(handle != IntPtr.Zero);

return new SafeDsaHandle(Interop.JObjectLifetime.NewGlobalReference(handle));
var duplicate = new SafeDsaHandle();
duplicate.SetHandle(Interop.JObjectLifetime.NewGlobalReference(handle));
return duplicate;
}

internal override SafeDsaHandle DuplicateHandle() => DuplicateHandle(DangerousGetHandle());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,26 @@ internal static SafeX509Handle[] X509DecodeCollection(ReadOnlySpan<byte> data)
throw new CryptographicException();

IntPtr[] ptrs = new IntPtr[size];
SafeX509Handle[] handles = new SafeX509Handle[ptrs.Length];
for (var i = 0; i < handles.Length; i++)
{
handles[i] = new SafeX509Handle();
}

ret = X509DecodeCollection(ref buf, data.Length, ptrs, ref size);
if (ret != SUCCESS)
{
foreach (SafeX509Handle handle in handles)
{
handle.Dispose();
}

throw new CryptographicException();
}

SafeX509Handle[] handles = new SafeX509Handle[ptrs.Length];
for (var i = 0; i < handles.Length; i++)
{
handles[i] = new SafeX509Handle(ptrs[i]);
Marshal.InitHandle(handles[i], ptrs[i]);
}

return handles;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,25 @@ internal static partial class Crypto
internal static partial void BigNumDestroy(IntPtr a);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BigNumFromBinary")]
private static unsafe partial IntPtr BigNumFromBinary(byte* s, int len);
private static unsafe partial SafeBignumHandle BigNumFromBinary(ReadOnlySpan<byte> bigEndianValue, int len);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BigNumToBinary")]
private static unsafe partial int BigNumToBinary(SafeBignumHandle a, byte* to);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_GetBigNumBytes")]
private static partial int GetBigNumBytes(SafeBignumHandle a);

private static unsafe IntPtr CreateBignumPtr(ReadOnlySpan<byte> bigEndianValue)
internal static SafeBignumHandle CreateBignum(ReadOnlySpan<byte> bigEndianValue)
{
fixed (byte* pBigEndianValue = bigEndianValue)
SafeBignumHandle ret = BigNumFromBinary(bigEndianValue, bigEndianValue.Length);
if (ret.IsInvalid)
{
IntPtr ret = BigNumFromBinary(pBigEndianValue, bigEndianValue.Length);

if (ret == IntPtr.Zero)
{
throw CreateOpenSslCryptographicException();
}

return ret;
Exception e = CreateOpenSslCryptographicException();
ret.Dispose();
throw e;
}
}

internal static SafeBignumHandle CreateBignum(ReadOnlySpan<byte> bigEndianValue)
{
IntPtr handle = CreateBignumPtr(bigEndianValue);
return new SafeBignumHandle(handle, true);
return ret;
}

internal static byte[]? ExtractBignum(IntPtr bignum, int targetSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ internal static partial class Advapi32
[LibraryImport(Libraries.Advapi32, EntryPoint = "RegConnectRegistryW", StringMarshalling = StringMarshalling.Utf16)]
internal static partial int RegConnectRegistry(
string machineName,
SafeRegistryHandle key,
IntPtr key,
out SafeRegistryHandle result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ public SafeCertContextHandle(SafeCertContextHandle parent)
SetHandle(_parent.handle);
}

internal new void SetHandle(IntPtr handle) => base.SetHandle(handle);

protected override bool ReleaseHandle()
{
if (_parent != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ private static RegistryKey OpenRemoteBaseKeyCore(RegistryHive hKey, string machi
}

// connect to the specified remote registry
int ret = Interop.Advapi32.RegConnectRegistry(machineName, new SafeRegistryHandle(new IntPtr((int)hKey), false), out SafeRegistryHandle foreignHKey);

int ret = Interop.Advapi32.RegConnectRegistry(machineName, new IntPtr((int)hKey), out SafeRegistryHandle foreignHKey);
if (ret == 0 && !foreignHKey.IsInvalid)
{
RegistryKey key = new RegistryKey(foreignHKey, true, false, true, ((IntPtr)hKey) == HKEY_PERFORMANCE_DATA, view);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ private PerformanceDataRegistryKey(SafeRegistryHandle hkey)
public static PerformanceDataRegistryKey OpenRemoteBaseKey(string machineName)
{
// connect to the specified remote registry
int ret = Interop.Advapi32.RegConnectRegistry(machineName, new SafeRegistryHandle(new IntPtr(PerformanceData), ownsHandle: false), out SafeRegistryHandle foreignHKey);

int ret = Interop.Advapi32.RegConnectRegistry(machineName, new IntPtr(PerformanceData), out SafeRegistryHandle foreignHKey);
if (ret == 0 && !foreignHKey.IsInvalid)
{
return new PerformanceDataRegistryKey(foreignHKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,20 @@ private static unsafe void CreateSocket()
{
Debug.Assert(Monitor.IsEntered(s_gate));
Debug.Assert(Socket == null, "Socket is not null, must close existing socket before opening another.");

var sh = new SafeSocketHandle();

IntPtr newSocket;
Interop.Error result = Interop.Sys.CreateNetworkChangeListenerSocket(&newSocket);
if (result != Interop.Error.SUCCESS)
{
string message = Interop.Sys.GetLastErrorInfo().GetErrorMessage();
sh.Dispose();
throw new NetworkInformationException(message);
}

Socket = new Socket(new SafeSocketHandle(newSocket, ownsHandle: true));
Marshal.InitHandle(sh, newSocket);
Socket = new Socket(sh);

// Don't capture ExecutionContext.
ThreadPool.UnsafeQueueUserWorkItem(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ private static unsafe IPPacketInformation GetIPPacketInformation(Interop.Sys.Mes

public static unsafe SocketError CreateSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, out SafeSocketHandle socket)
{
socket = new SafeSocketHandle();

IntPtr fd;
SocketError errorCode;
Interop.Error error = Interop.Sys.Socket(addressFamily, socketType, protocolType, &fd);
Expand Down Expand Up @@ -86,7 +88,8 @@ public static unsafe SocketError CreateSocket(AddressFamily addressFamily, Socke
errorCode = GetSocketErrorForErrorCode(error);
}

socket = new SafeSocketHandle(fd, ownsHandle: true);
Marshal.InitHandle(socket, fd);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, socket);
if (socket.IsInvalid)
{
Expand Down Expand Up @@ -1093,6 +1096,8 @@ public static SocketError Listen(SafeSocketHandle handle, int backlog)

public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAddress, ref int socketAddressLen, out SafeSocketHandle socket)
{
socket = new SafeSocketHandle();

IntPtr acceptedFd;
SocketError errorCode;
if (!listenSocket.IsNonBlocking)
Expand All @@ -1101,14 +1106,14 @@ public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAdd
}
else
{
bool completed = TryCompleteAccept(listenSocket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode);
if (!completed)
if (!TryCompleteAccept(listenSocket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode))
{
errorCode = SocketError.WouldBlock;
}
}

socket = new SafeSocketHandle(acceptedFd, ownsHandle: true);
Marshal.InitHandle(socket, acceptedFd);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, socket);

return errorCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ public static SocketError CreateSocket(AddressFamily addressFamily, SocketType s
{
Interop.Winsock.EnsureInitialized();

IntPtr handle = Interop.Winsock.WSASocketW(addressFamily, socketType, protocolType, IntPtr.Zero, 0, Interop.Winsock.SocketConstructorFlags.WSA_FLAG_OVERLAPPED |
Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);
socket = new SafeSocketHandle();
Marshal.InitHandle(socket, Interop.Winsock.WSASocketW(addressFamily, socketType, protocolType, IntPtr.Zero, 0, Interop.Winsock.SocketConstructorFlags.WSA_FLAG_OVERLAPPED |
Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT));

socket = new SafeSocketHandle(handle, ownsHandle: true);
if (socket.IsInvalid)
{
SocketError error = GetLastSocketError();
Expand Down Expand Up @@ -72,22 +72,22 @@ public static unsafe SocketError CreateSocket(

fixed (byte* protocolInfoBytes = socketInformation.ProtocolInformation)
{
socket = new SafeSocketHandle();

// Sockets are non-inheritable in .NET Core.
// Handle properties like HANDLE_FLAG_INHERIT are not cloned with socket duplication, therefore
// we need to disable handle inheritance when constructing the new socket handle from Protocol Info.
// Additionally, it looks like WSA_FLAG_NO_HANDLE_INHERIT has no effect when being used with the Protocol Info
// variant of WSASocketW, so it is being passed to that call only for consistency.
// Inheritance is being disabled with SetHandleInformation(...) after the WSASocketW call.
IntPtr handle = Interop.Winsock.WSASocketW(
Marshal.InitHandle(socket, Interop.Winsock.WSASocketW(
(AddressFamily)(-1),
(SocketType)(-1),
(ProtocolType)(-1),
(IntPtr)protocolInfoBytes,
0,
Interop.Winsock.SocketConstructorFlags.WSA_FLAG_OVERLAPPED |
Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);

socket = new SafeSocketHandle(handle, ownsHandle: true);
Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT));

if (socket.IsInvalid)
{
Expand Down Expand Up @@ -178,9 +178,9 @@ public static SocketError Listen(SafeSocketHandle handle, int backlog)

public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAddress, ref int socketAddressSize, out SafeSocketHandle socket)
{
IntPtr handle = Interop.Winsock.accept(listenSocket, socketAddress, ref socketAddressSize);
socket = new SafeSocketHandle();
Marshal.InitHandle(socket, Interop.Winsock.accept(listenSocket, socketAddress, ref socketAddressSize));

socket = new SafeSocketHandle(handle, ownsHandle: true);
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, socket);

return socket.IsInvalid ? GetLastSocketError() : SocketError.Success;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;

namespace System.Threading
Expand Down Expand Up @@ -145,6 +146,8 @@ public void Dispose()

private static SafeWaitHandle NewHandle(WaitableObject waitableObject)
{
var safeWaitHandle = new SafeWaitHandle();

IntPtr handle = IntPtr.Zero;
try
{
Expand All @@ -158,19 +161,8 @@ private static SafeWaitHandle NewHandle(WaitableObject waitableObject)
}
}

SafeWaitHandle? safeWaitHandle = null;
try
{
safeWaitHandle = new SafeWaitHandle(handle, ownsHandle: true);
return safeWaitHandle;
}
finally
{
if (safeWaitHandle == null)
{
HandleManager.DeleteHandle(handle);
}
}
Marshal.InitHandle(safeWaitHandle, handle);
return safeWaitHandle;
}

public static SafeWaitHandle NewEvent(bool initiallySignaled, EventResetMode resetMode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;
using Xunit;

Expand All @@ -12,11 +13,14 @@ public static void SafeHandleMinusOneIsInvalidTest()
{
var sh = new TestSafeHandleMinusOneIsInvalid();
Assert.True(sh.IsInvalid);
sh.SetHandle(new IntPtr(-2));

Marshal.InitHandle(sh, -2);
Assert.False(sh.IsInvalid);
sh.SetHandle(new IntPtr(-1));

Marshal.InitHandle(sh, -1);
Assert.True(sh.IsInvalid);
sh.SetHandle(IntPtr.Zero);

Marshal.InitHandle(sh, 0);
Assert.False(sh.IsInvalid);
}

Expand All @@ -25,13 +29,17 @@ public static void SafeHandleZeroOrMinusOneIsInvalidTest()
{
var sh = new TestSafeHandleZeroOrMinusOneIsInvalid();
Assert.True(sh.IsInvalid);
sh.SetHandle(new IntPtr(-2));

Marshal.InitHandle(sh, -2);
Assert.False(sh.IsInvalid);
sh.SetHandle(new IntPtr(-1));

Marshal.InitHandle(sh, -1);
Assert.True(sh.IsInvalid);
sh.SetHandle(IntPtr.Zero);

Marshal.InitHandle(sh, 0);
Assert.True(sh.IsInvalid);
sh.SetHandle(new IntPtr(1));

Marshal.InitHandle(sh, 1);
Assert.False(sh.IsInvalid);
}

Expand All @@ -42,7 +50,6 @@ public TestSafeHandleMinusOneIsInvalid() : base(true)
}

protected override bool ReleaseHandle() => true;
public new void SetHandle(IntPtr handle) => base.SetHandle(handle);
}

private class TestSafeHandleZeroOrMinusOneIsInvalid : SafeHandleZeroOrMinusOneIsInvalid
Expand All @@ -52,6 +59,5 @@ public TestSafeHandleZeroOrMinusOneIsInvalid() : base(true)
}

protected override bool ReleaseHandle() => true;
public new void SetHandle(IntPtr handle) => base.SetHandle(handle);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -383,15 +383,6 @@ internal SafeNCryptProviderHandle Duplicate()
return Duplicate<SafeNCryptProviderHandle>();
}

internal void SetHandleValue(IntPtr newHandleValue)
{
Debug.Assert(newHandleValue != IntPtr.Zero);
Debug.Assert(!IsClosed);
Debug.Assert(handle == IntPtr.Zero);

SetHandle(newHandleValue);
}

protected override bool ReleaseNativeHandle()
{
return ReleaseNativeWithNCryptFreeObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using Microsoft.Win32.SafeHandles;

Expand Down Expand Up @@ -34,7 +35,7 @@ internal static CngKey OpenNoDuplicate(SafeNCryptKeyHandle keyHandle, CngKeyHand
// Get a handle to the key's provider.
providerHandle = new SafeNCryptProviderHandle();
IntPtr rawProviderHandle = keyHandle.GetPropertyAsIntPtr(KeyPropertyName.ProviderHandle, CngPropertyOptions.None);
providerHandle.SetHandleValue(rawProviderHandle);
Marshal.InitHandle(providerHandle, rawProviderHandle);

// If we're wrapping a handle to an ephemeral key, we need to make sure that IsEphemeral is
// set up to return true. In the case that the handle is for an ephemeral key that was created
Expand Down
Loading

0 comments on commit aafa910

Please sign in to comment.