Skip to content

Commit

Permalink
Merge pull request #939 from steffenlarsen/steffen/virtual_mem_adapters
Browse files Browse the repository at this point in the history
[UR][CUDA][L0][HIP] Add virtual memory adapter implementations
  • Loading branch information
aarongreig committed Dec 18, 2023
2 parents 67e4d1b + 1678894 commit 8d1486a
Show file tree
Hide file tree
Showing 39 changed files with 859 additions and 279 deletions.
3 changes: 3 additions & 0 deletions source/adapters/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ add_ur_adapter(${TARGET_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.hpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/physical_mem.hpp
${CMAKE_CURRENT_SOURCE_DIR}/physical_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/platform.hpp
${CMAKE_CURRENT_SOURCE_DIR}/platform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/program.hpp
Expand All @@ -38,6 +40,7 @@ add_ur_adapter(${TARGET_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/tracing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp
${CMAKE_CURRENT_SOURCE_DIR}/virtual_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.hpp
)
Expand Down
3 changes: 2 additions & 1 deletion source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
// TODO: Investigate if this information is available on CUDA.
case UR_DEVICE_INFO_HOST_PIPE_READ_WRITE_SUPPORTED:
return ReturnValue(false);
case UR_DEVICE_INFO_VIRTUAL_MEMORY_SUPPORT:
return ReturnValue(true);
case UR_DEVICE_INFO_ESIMD_SUPPORT:
return ReturnValue(false);
case UR_DEVICE_INFO_MAX_READ_WRITE_IMAGE_ARGS:
Expand All @@ -1026,7 +1028,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
case UR_DEVICE_INFO_GPU_SUBSLICES_PER_SLICE:
case UR_DEVICE_INFO_GPU_EU_COUNT_PER_SUBSLICE:
case UR_DEVICE_INFO_GPU_HW_THREADS_PER_EU:
case UR_DEVICE_INFO_VIRTUAL_MEMORY_SUPPORT:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;

default:
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#include <ur/ur.hpp>

#include "common.hpp"

struct ur_device_handle_t_ {
private:
using native_type = CUdevice;
Expand Down
62 changes: 62 additions & 0 deletions source/adapters/cuda/physical_mem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===--------- physical_mem.cpp - CUDA Adapter ----------------------------===//
//
// Copyright (C) 2023 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "physical_mem.hpp"
#include "common.hpp"
#include "context.hpp"
#include "event.hpp"

#include <cassert>
#include <cuda.h>

UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemCreate(
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
[[maybe_unused]] const ur_physical_mem_properties_t *pProperties,
ur_physical_mem_handle_t *phPhysicalMem) {
CUmemAllocationProp AllocProps = {};
AllocProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
AllocProps.type = CU_MEM_ALLOCATION_TYPE_PINNED;
UR_CHECK_ERROR(GetDeviceOrdinal(hDevice, AllocProps.location.id));

CUmemGenericAllocationHandle ResHandle;
switch (auto Result = cuMemCreate(&ResHandle, size, &AllocProps, 0)) {
case CUDA_ERROR_INVALID_VALUE:
return UR_RESULT_ERROR_INVALID_SIZE;
default:
UR_CHECK_ERROR(Result);
}
*phPhysicalMem = new ur_physical_mem_handle_t_(ResHandle, hContext);

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) {
hPhysicalMem->incrementReferenceCount();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) {
if (hPhysicalMem->decrementReferenceCount() > 0)
return UR_RESULT_SUCCESS;

try {
std::unique_ptr<ur_physical_mem_handle_t_> PhysicalMemGuard(hPhysicalMem);

ScopedContext Active(hPhysicalMem->getContext());
UR_CHECK_ERROR(cuMemRelease(hPhysicalMem->get()));
return UR_RESULT_SUCCESS;
} catch (ur_result_t err) {
return err;
} catch (...) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
}
}
66 changes: 66 additions & 0 deletions source/adapters/cuda/physical_mem.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===---------- physical_mem.hpp - CUDA Adapter ---------------------------===//
//
// Copyright (C) 2023 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#include <ur/ur.hpp>

#include <cuda.h>

#include "adapter.hpp"
#include "device.hpp"
#include "platform.hpp"

/// UR queue mapping on physical memory allocations used in virtual memory
/// management.
///
struct ur_physical_mem_handle_t_ {
using native_type = CUmemGenericAllocationHandle;

std::atomic_uint32_t RefCount;
native_type PhysicalMem;
ur_context_handle_t_ *Context;

ur_physical_mem_handle_t_(native_type PhysMem, ur_context_handle_t_ *Ctx)
: RefCount(1), PhysicalMem(PhysMem), Context(Ctx) {
urContextRetain(Context);
}

~ur_physical_mem_handle_t_() { urContextRelease(Context); }

native_type get() const noexcept { return PhysicalMem; }

ur_context_handle_t_ *getContext() const noexcept { return Context; }

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }
};

// Find a device ordinal of a device.
inline ur_result_t GetDeviceOrdinal(ur_device_handle_t Device, int &Ordinal) {
ur_adapter_handle_t AdapterHandle = &adapter;
// Get list of platforms
uint32_t NumPlatforms;
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms));
UR_ASSERT(NumPlatforms, UR_RESULT_ERROR_UNKNOWN);

std::vector<ur_platform_handle_t> Platforms{NumPlatforms};
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, NumPlatforms,
Platforms.data(), nullptr));

// Ordinal corresponds to the platform ID as each device has its own platform.
CUdevice NativeDevice = Device->get();
for (Ordinal = 0; size_t(Ordinal) < Platforms.size(); ++Ordinal)
if (Platforms[Ordinal]->Devices[0]->get() == NativeDevice)
return UR_RESULT_SUCCESS;
return UR_RESULT_ERROR_INVALID_DEVICE;
}
20 changes: 10 additions & 10 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,13 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetVirtualMemProcAddrTable(
return retVal;
}

pDdiTable->pfnFree = nullptr;
pDdiTable->pfnGetInfo = nullptr;
pDdiTable->pfnGranularityGetInfo = nullptr;
pDdiTable->pfnMap = nullptr;
pDdiTable->pfnReserve = nullptr;
pDdiTable->pfnSetAccess = nullptr;
pDdiTable->pfnUnmap = nullptr;
pDdiTable->pfnFree = urVirtualMemFree;
pDdiTable->pfnGetInfo = urVirtualMemGetInfo;
pDdiTable->pfnGranularityGetInfo = urVirtualMemGranularityGetInfo;
pDdiTable->pfnMap = urVirtualMemMap;
pDdiTable->pfnReserve = urVirtualMemReserve;
pDdiTable->pfnSetAccess = urVirtualMemSetAccess;
pDdiTable->pfnUnmap = urVirtualMemUnmap;

return retVal;
}
Expand All @@ -381,9 +381,9 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetPhysicalMemProcAddrTable(
return retVal;
}

pDdiTable->pfnCreate = nullptr;
pDdiTable->pfnRelease = nullptr;
pDdiTable->pfnRetain = nullptr;
pDdiTable->pfnCreate = urPhysicalMemCreate;
pDdiTable->pfnRelease = urPhysicalMemRelease;
pDdiTable->pfnRetain = urPhysicalMemRetain;

return retVal;
}
Expand Down
135 changes: 135 additions & 0 deletions source/adapters/cuda/virtual_mem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
//===--------- virtual_mem.cpp - CUDA Adapter -----------------------------===//
//
// Copyright (C) 2023 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "common.hpp"
#include "context.hpp"
#include "event.hpp"
#include "physical_mem.hpp"

#include <cassert>
#include <cuda.h>

UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemGranularityGetInfo(
ur_context_handle_t hContext, ur_device_handle_t hDevice,
ur_virtual_mem_granularity_info_t propName, size_t propSize,
void *pPropValue, size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

ScopedContext Active(hContext);
switch (propName) {
case UR_VIRTUAL_MEM_GRANULARITY_INFO_MINIMUM:
case UR_VIRTUAL_MEM_GRANULARITY_INFO_RECOMMENDED: {
CUmemAllocationGranularity_flags Flags =
propName == UR_VIRTUAL_MEM_GRANULARITY_INFO_MINIMUM
? CU_MEM_ALLOC_GRANULARITY_MINIMUM
: CU_MEM_ALLOC_GRANULARITY_RECOMMENDED;
CUmemAllocationProp AllocProps = {};
AllocProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
AllocProps.type = CU_MEM_ALLOCATION_TYPE_PINNED;
UR_CHECK_ERROR(GetDeviceOrdinal(hDevice, AllocProps.location.id));

size_t Granularity;
UR_CHECK_ERROR(
cuMemGetAllocationGranularity(&Granularity, &AllocProps, Flags));
return ReturnValue(Granularity);
}
default:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
}

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urVirtualMemReserve(ur_context_handle_t hContext, const void *pStart,
size_t size, void **ppStart) {
ScopedContext Active(hContext);
UR_CHECK_ERROR(cuMemAddressReserve((CUdeviceptr *)ppStart, size, 0,
(CUdeviceptr)pStart, 0));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemFree(
ur_context_handle_t hContext, const void *pStart, size_t size) {
ScopedContext Active(hContext);
UR_CHECK_ERROR(cuMemAddressFree((CUdeviceptr)pStart, size));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urVirtualMemSetAccess(ur_context_handle_t hContext, const void *pStart,
size_t size, ur_virtual_mem_access_flags_t flags) {
CUmemAccessDesc AccessDesc = {};
if (flags & UR_VIRTUAL_MEM_ACCESS_FLAG_READ_WRITE)
AccessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
else if (flags & UR_VIRTUAL_MEM_ACCESS_FLAG_READ_ONLY)
AccessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READ;
else
AccessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_NONE;
AccessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
// TODO: When contexts support multiple devices, we should create a descriptor
// for each. We may also introduce a variant of this function with a
// specific device.
UR_CHECK_ERROR(
GetDeviceOrdinal(hContext->getDevice(), AccessDesc.location.id));

ScopedContext Active(hContext);
UR_CHECK_ERROR(cuMemSetAccess((CUdeviceptr)pStart, size, &AccessDesc, 1));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urVirtualMemMap(ur_context_handle_t hContext, const void *pStart, size_t size,
ur_physical_mem_handle_t hPhysicalMem, size_t offset,
ur_virtual_mem_access_flags_t flags) {
ScopedContext Active(hContext);
UR_CHECK_ERROR(
cuMemMap((CUdeviceptr)pStart, size, offset, hPhysicalMem->get(), 0));
if (flags)
UR_CHECK_ERROR(urVirtualMemSetAccess(hContext, pStart, size, flags));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemUnmap(
ur_context_handle_t hContext, const void *pStart, size_t size) {
ScopedContext Active(hContext);
UR_CHECK_ERROR(cuMemUnmap((CUdeviceptr)pStart, size));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemGetInfo(
ur_context_handle_t hContext, const void *pStart,
[[maybe_unused]] size_t size, ur_virtual_mem_info_t propName,
size_t propSize, void *pPropValue, size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

ScopedContext Active(hContext);
switch (propName) {
case UR_VIRTUAL_MEM_INFO_ACCESS_MODE: {
CUmemLocation MemLocation = {};
MemLocation.type = CU_MEM_LOCATION_TYPE_DEVICE;
UR_CHECK_ERROR(GetDeviceOrdinal(hContext->getDevice(), MemLocation.id));

unsigned long long CuAccessFlags;
UR_CHECK_ERROR(
cuMemGetAccess(&CuAccessFlags, &MemLocation, (CUdeviceptr)pStart));

ur_virtual_mem_access_flags_t UrAccessFlags = 0;
if (CuAccessFlags == CU_MEM_ACCESS_FLAGS_PROT_READWRITE)
UrAccessFlags = UR_VIRTUAL_MEM_ACCESS_FLAG_READ_WRITE;
else if (CuAccessFlags == CU_MEM_ACCESS_FLAGS_PROT_READ)
UrAccessFlags = UR_VIRTUAL_MEM_ACCESS_FLAG_READ_ONLY;
return ReturnValue(UrAccessFlags);
}
default:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
}
return UR_RESULT_SUCCESS;
}
3 changes: 3 additions & 0 deletions source/adapters/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ add_ur_adapter(${TARGET_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.hpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/physical_mem.hpp
${CMAKE_CURRENT_SOURCE_DIR}/physical_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/platform.hpp
${CMAKE_CURRENT_SOURCE_DIR}/platform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/program.hpp
Expand All @@ -71,6 +73,7 @@ add_ur_adapter(${TARGET_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/sampler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp
${CMAKE_CURRENT_SOURCE_DIR}/virtual_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.hpp
)
Expand Down
3 changes: 2 additions & 1 deletion source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
}
case UR_DEVICE_INFO_HOST_PIPE_READ_WRITE_SUPPORTED:
return ReturnValue(false);
case UR_DEVICE_INFO_VIRTUAL_MEMORY_SUPPORT:
return ReturnValue(false);
case UR_DEVICE_INFO_ESIMD_SUPPORT:
return ReturnValue(false);

Expand All @@ -833,7 +835,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
case UR_DEVICE_INFO_BFLOAT16:
case UR_DEVICE_INFO_IL_VERSION:
case UR_DEVICE_INFO_ASYNC_BARRIER:
case UR_DEVICE_INFO_VIRTUAL_MEMORY_SUPPORT:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;

default:
Expand Down
Loading

0 comments on commit 8d1486a

Please sign in to comment.