Skip to content

Commit

Permalink
Merge branch 'main' into steffen/virtual_mem_adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
aarongreig committed Dec 18, 2023
2 parents 0563259 + 67e4d1b commit 1678894
Show file tree
Hide file tree
Showing 21 changed files with 83 additions and 96 deletions.
4 changes: 2 additions & 2 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
ur_context_handle_t hContext, ur_device_handle_t hDevice)
: Context(hContext),
Device(hDevice), CudaGraph{nullptr}, CudaGraphExec{nullptr}, RefCount{1} {
: Context(hContext), Device(hDevice), CudaGraph{nullptr},
CudaGraphExec{nullptr}, RefCount{1}, NextSyncPoint{0} {
urContextRetain(hContext);
urDeviceRetain(hDevice);
}
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/cuda/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ struct ur_exp_command_buffer_handle_t_ {

void RegisterSyncPoint(ur_exp_command_buffer_sync_point_t SyncPoint,
std::shared_ptr<CUgraphNode> CuNode) {
SyncPoints[SyncPoint] = CuNode;
SyncPoints[SyncPoint] = std::move(CuNode);
NextSyncPoint++;
}

Expand All @@ -193,12 +193,12 @@ struct ur_exp_command_buffer_handle_t_ {
}

// Helper to register next sync point
// @param CuNode Node to register as next sycn point
// @param CuNode Node to register as next sync point
// @return Pointer to the sync that registers the Node
ur_exp_command_buffer_sync_point_t
AddSyncPoint(std::shared_ptr<CUgraphNode> CuNode) {
ur_exp_command_buffer_sync_point_t SyncPoint = NextSyncPoint;
RegisterSyncPoint(SyncPoint, CuNode);
RegisterSyncPoint(SyncPoint, std::move(CuNode));
return SyncPoint;
}

Expand Down
15 changes: 8 additions & 7 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1144,17 +1144,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
if (Result != UR_RESULT_SUCCESS)
return Result;

ur_platform_handle_t *Plat = static_cast<ur_platform_handle_t *>(
malloc(NumPlatforms * sizeof(ur_platform_handle_t)));
Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, Plat, nullptr);
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);

Result =
urPlatformGet(&AdapterHandle, 1, NumPlatforms, Platforms.data(), nullptr);
if (Result != UR_RESULT_SUCCESS)
return Result;

// Iterate through platforms to find device that matches nativeHandle
for (uint32_t j = 0; j < NumPlatforms; ++j) {
auto SearchRes =
std::find_if(begin(Plat[j]->Devices), end(Plat[j]->Devices), IsDevice);
if (SearchRes != end(Plat[j]->Devices)) {
for (const auto Platform : Platforms) {
auto SearchRes = std::find_if(std::begin(Platform->Devices),
std::end(Platform->Devices), IsDevice);
if (SearchRes != end(Platform->Devices)) {
*phDevice = static_cast<ur_device_handle_t>((*SearchRes).get());
return UR_RESULT_SUCCESS;
}
Expand Down
29 changes: 9 additions & 20 deletions source/adapters/cuda/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
//===----------------------------------------------------------------------===//

#include "event.hpp"
#include "common.hpp"
#include "context.hpp"
#include "device.hpp"
#include "queue.hpp"
Expand All @@ -19,35 +18,25 @@

ur_event_handle_t_::ur_event_handle_t_(ur_command_t Type,
ur_context_handle_t Context,
ur_queue_handle_t Queue, CUstream Stream,
ur_queue_handle_t Queue,
native_type EvEnd, native_type EvQueued,
native_type EvStart, CUstream Stream,
uint32_t StreamToken)
: CommandType{Type}, RefCount{1}, HasOwnership{true},
HasBeenWaitedOn{false}, IsRecorded{false}, IsStarted{false},
StreamToken{StreamToken}, EvEnd{nullptr}, EvStart{nullptr},
EvQueued{nullptr}, Queue{Queue}, Stream{Stream}, Context{Context} {

bool ProfilingEnabled = Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE;

UR_CHECK_ERROR(cuEventCreate(
&EvEnd, ProfilingEnabled ? CU_EVENT_DEFAULT : CU_EVENT_DISABLE_TIMING));

if (ProfilingEnabled) {
UR_CHECK_ERROR(cuEventCreate(&EvQueued, CU_EVENT_DEFAULT));
UR_CHECK_ERROR(cuEventCreate(&EvStart, CU_EVENT_DEFAULT));
}

if (Queue != nullptr) {
urQueueRetain(Queue);
}
StreamToken{StreamToken}, EventID{0}, EvEnd{EvEnd}, EvStart{EvStart},
EvQueued{EvQueued}, Queue{Queue}, Stream{Stream}, Context{Context} {
urQueueRetain(Queue);
urContextRetain(Context);
}

ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t Context,
CUevent EventNative)
: CommandType{UR_COMMAND_EVENTS_WAIT}, RefCount{1}, HasOwnership{false},
HasBeenWaitedOn{false}, IsRecorded{false}, IsStarted{false},
StreamToken{std::numeric_limits<uint32_t>::max()}, EvEnd{EventNative},
EvStart{nullptr}, EvQueued{nullptr}, Queue{nullptr}, Context{Context} {
StreamToken{std::numeric_limits<uint32_t>::max()}, EventID{0},
EvEnd{EventNative}, EvStart{nullptr}, EvQueued{nullptr}, Queue{nullptr},
Stream{nullptr}, Context{Context} {
urContextRetain(Context);
}

Expand Down
18 changes: 15 additions & 3 deletions source/adapters/cuda/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <cuda.h>
#include <ur/ur.hpp>

#include "common.hpp"
#include "queue.hpp"

/// UR Event mapping to CUevent
Expand Down Expand Up @@ -82,8 +83,18 @@ struct ur_event_handle_t_ {
static ur_event_handle_t
makeNative(ur_command_t Type, ur_queue_handle_t Queue, CUstream Stream,
uint32_t StreamToken = std::numeric_limits<uint32_t>::max()) {
return new ur_event_handle_t_(Type, Queue->getContext(), Queue, Stream,
StreamToken);
const bool ProfilingEnabled =
Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE;
native_type EvEnd = nullptr, EvQueued = nullptr, EvStart = nullptr;
UR_CHECK_ERROR(cuEventCreate(
&EvEnd, ProfilingEnabled ? CU_EVENT_DEFAULT : CU_EVENT_DISABLE_TIMING));

if (ProfilingEnabled) {
UR_CHECK_ERROR(cuEventCreate(&EvQueued, CU_EVENT_DEFAULT));
UR_CHECK_ERROR(cuEventCreate(&EvStart, CU_EVENT_DEFAULT));
}
return new ur_event_handle_t_(Type, Queue->getContext(), Queue, EvEnd,
EvQueued, EvStart, Stream, StreamToken);
}

static ur_event_handle_t makeWithNative(ur_context_handle_t context,
Expand All @@ -99,7 +110,8 @@ struct ur_event_handle_t_ {
// This constructor is private to force programmers to use the makeNative /
// make_user static members in order to create a pi_event for CUDA.
ur_event_handle_t_(ur_command_t Type, ur_context_handle_t Context,
ur_queue_handle_t Queue, CUstream Stream,
ur_queue_handle_t Queue, native_type EvEnd,
native_type EvQueued, native_type EvStart, CUstream Stream,
uint32_t StreamToken);

// This constructor is private to force programmers to use the
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/cuda/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ cudaToUrImageChannelFormat(CUarray_format cuda_format,

ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
const ur_image_desc_t *pImageDesc,
CUDA_RESOURCE_DESC ResourceDesc,
const CUDA_RESOURCE_DESC &ResourceDesc,
ur_exp_image_handle_t *phRetImage) {

try {
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/cuda/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ struct ur_mem_handle_t_ {
/// Constructs the UR allocation for an unsampled image object
ur_mem_handle_t_(ur_context_handle_t Context, CUarray Array,
CUsurfObject Surf, ur_mem_type_t ImageType)
: Context{Context}, RefCount{1}, MemType{Type::Surface},
: Context{Context}, RefCount{1}, MemType{Type::Surface}, MemFlags{0},
Mem{ImageMem{Array, (void *)Surf, ImageType, nullptr}} {
urContextRetain(Context);
}

/// Constructs the UR allocation for a sampled image object
ur_mem_handle_t_(ur_context_handle_t Context, CUarray Array, CUtexObject Tex,
ur_sampler_handle_t Sampler, ur_mem_type_t ImageType)
: Context{Context}, RefCount{1}, MemType{Type::Texture},
: Context{Context}, RefCount{1}, MemType{Type::Texture}, MemFlags{0},
Mem{ImageMem{Array, (void *)Tex, ImageType, Sampler}} {
urContextRetain(Context);
}
Expand Down
3 changes: 2 additions & 1 deletion source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ ur_result_t ur_program_handle_t_::buildProgram(const char *BuildOptions) {

if (!this->BuildOptions.empty()) {
unsigned int MaxRegs;
bool Valid = getMaxRegistersJitOptionValue(BuildOptions, MaxRegs);
const bool Valid =
getMaxRegistersJitOptionValue(this->BuildOptions, MaxRegs);
if (Valid) {
Options.push_back(CU_JIT_MAX_REGISTERS);
OptionVals.push_back(reinterpret_cast<void *>(MaxRegs));
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/cuda/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ urSamplerCreate(ur_context_handle_t hContext, const ur_sampler_desc_t *pDesc,
std::unique_ptr<ur_sampler_handle_t_> Sampler{
new ur_sampler_handle_t_(hContext)};

if (pDesc && pDesc->stype == UR_STRUCTURE_TYPE_SAMPLER_DESC) {
if (pDesc->stype == UR_STRUCTURE_TYPE_SAMPLER_DESC) {
Sampler->Props |= pDesc->normalizedCoords;
Sampler->Props |= pDesc->filterMode << 1;
Sampler->Props |= pDesc->addressingMode << 2;
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct ur_device_handle_t_ {
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context), DeviceIndex(DeviceIndex) {}

~ur_device_handle_t_() {
~ur_device_handle_t_() noexcept(false) {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
}

Expand Down
31 changes: 8 additions & 23 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,8 +987,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);

ur_result_t Result = UR_RESULT_SUCCESS;

ur_lock MemoryMigrationLock{hImage->MemoryMigrationMutex};
auto Device = hQueue->getDevice();
hipStream_t HIPStream = hQueue->getNextTransferStream();
Expand Down Expand Up @@ -1039,13 +1037,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
UR_CHECK_ERROR(RetImplEvent->start());
}

Result = commonEnqueueMemImageNDCopy(HIPStream, ImgType, AdjustedRegion,
Array, hipMemoryTypeArray, SrcOffset,
pDst, hipMemoryTypeHost, nullptr);

if (Result != UR_RESULT_SUCCESS) {
return Result;
}
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
HIPStream, ImgType, AdjustedRegion, Array, hipMemoryTypeArray,
SrcOffset, pDst, hipMemoryTypeHost, nullptr));

if (phEvent) {
UR_CHECK_ERROR(RetImplEvent->record());
Expand All @@ -1061,7 +1055,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
return UR_RESULT_ERROR_UNKNOWN;
}
return UR_RESULT_SUCCESS;
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
Expand All @@ -1071,15 +1064,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);

ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));
}

hipArray *Array =
Expand Down Expand Up @@ -1107,13 +1098,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
UR_CHECK_ERROR(RetImplEvent->start());
}

Result = commonEnqueueMemImageNDCopy(HIPStream, ImgType, AdjustedRegion,
pSrc, hipMemoryTypeHost, nullptr,
Array, hipMemoryTypeArray, DstOffset);

if (Result != UR_RESULT_SUCCESS) {
return Result;
}
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
HIPStream, ImgType, AdjustedRegion, pSrc, hipMemoryTypeHost, nullptr,
Array, hipMemoryTypeArray, DstOffset));

if (phEvent) {
UR_CHECK_ERROR(RetImplEvent->record());
Expand All @@ -1126,8 +1113,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
}

return UR_RESULT_SUCCESS;

return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
Expand Down
22 changes: 13 additions & 9 deletions source/adapters/hip/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ur_event_handle_t_::ur_event_handle_t_(ur_command_t Type,
hipStream_t Stream, uint32_t StreamToken)
: CommandType{Type}, RefCount{1}, HasOwnership{true},
HasBeenWaitedOn{false}, IsRecorded{false}, IsStarted{false},
StreamToken{StreamToken}, EvEnd{nullptr}, EvStart{nullptr},
StreamToken{StreamToken}, EventId{0}, EvEnd{nullptr}, EvStart{nullptr},
EvQueued{nullptr}, Queue{Queue}, Stream{Stream}, Context{Context} {

bool ProfilingEnabled = Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE;
Expand All @@ -32,18 +32,17 @@ ur_event_handle_t_::ur_event_handle_t_(ur_command_t Type,
UR_CHECK_ERROR(hipEventCreateWithFlags(&EvStart, hipEventDefault));
}

if (Queue != nullptr) {
urQueueRetain(Queue);
}
urQueueRetain(Queue);
urContextRetain(Context);
}

ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t Context,
hipEvent_t EventNative)
: CommandType{UR_COMMAND_EVENTS_WAIT}, RefCount{1}, HasOwnership{false},
HasBeenWaitedOn{false}, IsRecorded{false}, IsStarted{false},
StreamToken{std::numeric_limits<uint32_t>::max()}, EvEnd{EventNative},
EvStart{nullptr}, EvQueued{nullptr}, Queue{nullptr}, Context{Context} {
StreamToken{std::numeric_limits<uint32_t>::max()}, EventId{0},
EvEnd{EventNative}, EvStart{nullptr}, EvQueued{nullptr}, Queue{nullptr},
Stream{nullptr}, Context{Context} {
urContextRetain(Context);
}

Expand Down Expand Up @@ -72,7 +71,7 @@ ur_result_t ur_event_handle_t_::start() {
return Result;
}

bool ur_event_handle_t_::isCompleted() const noexcept {
bool ur_event_handle_t_::isCompleted() const {
if (!IsRecorded) {
return false;
}
Expand Down Expand Up @@ -225,8 +224,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
return ReturnValue(hEvent->getCommandType());
case UR_EVENT_INFO_REFERENCE_COUNT:
return ReturnValue(hEvent->getReferenceCount());
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS:
return ReturnValue(hEvent->getExecutionStatus());
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
try {
return ReturnValue(hEvent->getExecutionStatus());
} catch (ur_result_t Error) {
return Error;
}
}
case UR_EVENT_INFO_CONTEXT:
return ReturnValue(hEvent->getContext());
default:
Expand Down
5 changes: 2 additions & 3 deletions source/adapters/hip/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ struct ur_event_handle_t_ {

bool isStarted() const noexcept { return IsStarted; }

bool isCompleted() const noexcept;

uint32_t getExecutionStatus() const noexcept {
bool isCompleted() const;

uint32_t getExecutionStatus() const {
if (!isRecorded()) {
return UR_EVENT_STATUS_SUBMITTED;
}
Expand Down
14 changes: 5 additions & 9 deletions source/adapters/hip/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,16 @@ struct SurfaceMem {
void *HostPtr)
: Arrays(Context->Devices.size(), nullptr),
SurfObjs(Context->Devices.size(), nullptr),
OuterMemStruct{OuterMemStruct},
ImageFormat{ImageFormat}, ImageDesc{ImageDesc}, HostPtr{HostPtr} {
OuterMemStruct{OuterMemStruct}, ImageFormat{ImageFormat},
ImageDesc{ImageDesc}, ArrayDesc{}, HostPtr{HostPtr} {
// We have to use hipArray3DCreate, which has some caveats. The height and
// depth parameters must be set to 0 produce 1D or 2D arrays. image_desc
// gives a minimum value of 1, so we need to convert the answer.
ArrayDesc.NumChannels = 4; // Only support 4 channel image
ArrayDesc.Flags = 0; // No flags required
ArrayDesc.Width = ImageDesc.width;
if (ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
ArrayDesc.Height = 0;
ArrayDesc.Depth = 0;
} else if (ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
if (ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
ArrayDesc.Height = ImageDesc.height;
ArrayDesc.Depth = 0;
} else if (ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
ArrayDesc.Height = ImageDesc.height;
ArrayDesc.Depth = ImageDesc.depth;
Expand Down Expand Up @@ -456,7 +452,7 @@ struct ur_mem_handle_t_ {
urContextRetain(Context);
}

~ur_mem_handle_t_() {
~ur_mem_handle_t_() noexcept(false) {
if (isBuffer() && isSubBuffer()) {
urMemRelease(std::get<BufferMem>(Mem).Parent);
return;
Expand All @@ -468,7 +464,7 @@ struct ur_mem_handle_t_ {
return std::holds_alternative<BufferMem>(Mem);
}

bool isSubBuffer() const noexcept {
bool isSubBuffer() const {
return (isBuffer() && (std::get<BufferMem>(Mem).Parent != nullptr));
}

Expand Down
Loading

0 comments on commit 1678894

Please sign in to comment.