Skip to content

Commit

Permalink
umf providers
Browse files Browse the repository at this point in the history
  • Loading branch information
bratpiorka committed Aug 17, 2023
1 parent 33f3eef commit 6536c12
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ umfMemoryProvidersRegistryGet(struct umf_memory_provider_ops_t *providers,
const struct umf_memory_provider_ops_t *
umfMemoryProvidersRegistryGetOps(char *name);

bool umfMemoryProviderSupportsDevice(
const struct umf_memory_provider_ops_t *ops, const void *descr, size_t len);

///
/// \brief Allocates size bytes of uninitialized storage from memory provider
/// with specified alignment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ struct umf_memory_provider_ops_t {
/// Should be initialized using UMF_VERSION_CURRENT
uint32_t version;

void *priv;

///
/// \brief Initializes memory provider.
/// \param params provider-specific params
Expand All @@ -51,7 +53,7 @@ struct umf_memory_provider_ops_t {
enum umf_result_t (*purge_lazy)(void *provider, void *ptr, size_t size);
enum umf_result_t (*purge_force)(void *provider, void *ptr, size_t size);
const char *(*get_name)(void *provider);
bool (*supports_device)(const char *name);
bool (*supports_device)(const void* descr, size_t len);
};

#ifdef __cplusplus
Expand Down
25 changes: 19 additions & 6 deletions source/common/unified_malloc_framework/src/memory_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

Expand All @@ -25,7 +26,9 @@ struct umf_memory_provider_t {
void *provider_priv;
};

std::vector<struct umf_memory_provider_ops_t> globalProviders;
// TODO here I use the ptr to vector because the system calls
// globalProviders dectructor twice - why?
std::vector<struct umf_memory_provider_ops_t> *globalProviders;

enum umf_result_t
umfMemoryProviderCreate(const struct umf_memory_provider_ops_t *ops,
Expand Down Expand Up @@ -56,8 +59,13 @@ umfMemoryProviderCreate(const struct umf_memory_provider_ops_t *ops,

enum umf_result_t umfMemoryProviderRegister(umf_memory_provider_ops_t *ops) {

if (globalProviders == NULL) {
// TODO this is never freed
globalProviders = new std::vector<struct umf_memory_provider_ops_t>;
}

// TODO check if this provider isn't already registered
globalProviders.push_back(*ops);
globalProviders->push_back(*ops);

return UMF_RESULT_SUCCESS;
}
Expand All @@ -67,9 +75,9 @@ umfMemoryProvidersRegistryGet(umf_memory_provider_ops_t *providers,
size_t *numProviders) {

if (providers == NULL) {
*numProviders = globalProviders.size();
*numProviders = globalProviders->size();
} else {
memcpy(providers, globalProviders.data(),
memcpy(providers, globalProviders->data(),
sizeof(umf_memory_provider_ops_t) * *numProviders);
}

Expand All @@ -79,17 +87,22 @@ umfMemoryProvidersRegistryGet(umf_memory_provider_ops_t *providers,
// TODO rename ;)
const umf_memory_provider_ops_t *umfMemoryProvidersRegistryGetOps(char *name) {
auto it = std::find_if(
std::begin(globalProviders), std::end(globalProviders),
std::begin(*globalProviders), std::end(*globalProviders),
[&](auto &ops) { return std::strcmp(ops.get_name(NULL), name) == 0; });

if (it != globalProviders.end()) {
if (it != globalProviders->end()) {
return &(*it);
}

// else
return NULL;
}

bool umfMemoryProviderSupportsDevice(const umf_memory_provider_ops_t *ops,
const void *descr, size_t len) {
return ops->supports_device(descr, len);
}

void umfMemoryProviderDestroy(umf_memory_provider_handle_t hProvider) {
hProvider->ops.finalize(hProvider->provider_priv);
free(hProvider);
Expand Down
239 changes: 239 additions & 0 deletions source/common/ur_memory_provider.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
Copyright (c) 2023 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <umf.h>
#include <umf/memory_provider.h>

#include "ur_api.h"

struct ur_provider_config_t {
// NOTE: when the context is not NULL, create/get a provider instance
// based on it
ur_context_handle_t context;

// when the context is NULL, create/get a provider instance based
// on model, dev and PCI
char* model_name;
char* pci;

// type of USM allocations
ur_usm_type_t usm_type;
};

typedef struct ur_provider_priv_t
{
// TODO there will be always single dev per provider instance?
ur_context_handle_t context;
ur_device_handle_t device;
ur_usm_type_t usm_type;
} ur_provider_priv_t;

enum umf_result_t
ur_initialize(void *params, void **priv_ptr)
{
urInit(0, 0);

if (priv_ptr == NULL)
{
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

if (params == NULL)
{
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

*(struct ur_provider_priv_t **)priv_ptr = (struct ur_provider_priv_t *)malloc(
sizeof(struct ur_provider_priv_t));

ur_provider_priv_t *priv = (struct ur_provider_priv_t *)*priv_ptr;
if (priv == NULL)
{
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

priv->usm_type = ((struct ur_provider_config_t *)params)->usm_type;

if (((struct ur_provider_config_t *)params)->context)
{
priv->context = ((struct ur_provider_config_t *)params)->context;
// get the devices list from the context
size_t num_devices = 0;
urContextGetInfo(priv->context, UR_CONTEXT_INFO_NUM_DEVICES,
sizeof(size_t), &num_devices, NULL);
// assume there will be always single device per provider instance
assert(num_devices == 1); // TODO report error
urContextGetInfo(priv->context, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t), &priv->device, NULL);
assert(priv->device != NULL); // TODO report error

return UMF_RESULT_SUCCESS;
}
else if (((struct ur_provider_config_t *)params)->pci)
{
// find a device that matches this name

// NOTE: we browse all platforms here because right now
// there is a single UR provider for all platforms

uint32_t adapterCount = 0;
urAdapterGet(0, NULL, &adapterCount);
ur_adapter_handle_t *adapters = (ur_adapter_handle_t *)malloc(sizeof(ur_adapter_handle_t) * adapterCount);
urAdapterGet(adapterCount, adapters, NULL);

uint32_t platformCount = 0;
urPlatformGet(adapters, adapterCount, 1, NULL, &platformCount);
ur_platform_handle_t *platforms = (ur_platform_handle_t *)malloc(sizeof(ur_platform_handle_t) * platformCount);
urPlatformGet(adapters, adapterCount, platformCount, platforms, NULL);

for (size_t pid = 0; pid < platformCount; pid++)
{
uint32_t deviceCount = 0;
urDeviceGet(platforms[pid], UR_DEVICE_TYPE_GPU, 0, NULL, &deviceCount);
ur_device_handle_t *devices = (ur_device_handle_t *)malloc(sizeof(ur_device_handle_t) * deviceCount);
urDeviceGet(platforms[pid], UR_DEVICE_TYPE_GPU, deviceCount, devices,
NULL);

for (size_t did = 0; did < deviceCount; did++)
{
static const size_t DEVICE_INFO_MAX_LEN = 1024;

char *device_pci = (char *)malloc(DEVICE_INFO_MAX_LEN);
memset(device_pci, 0, DEVICE_INFO_MAX_LEN);
urDeviceGetInfo(devices[did], UR_DEVICE_INFO_PCI_ADDRESS, DEVICE_INFO_MAX_LEN - 1,
device_pci, NULL);

if (strcmp(((struct ur_provider_config_t *)params)->pci, device_pci) == 0)
{
ur_context_handle_t ctx = NULL;
urContextCreate(1, &devices[did], NULL, &ctx);
priv->context = ctx;
priv->device = devices[did];

free(device_pci);
return UMF_RESULT_SUCCESS;
}
}
}
}

// else
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
}

void ur_finalize(void *provider) { free(provider); }

static enum umf_result_t ur_alloc(void *provider, size_t size, size_t alignment,
void **resultPtr)
{
struct ur_provider_priv_t *config = (struct ur_provider_priv_t *)provider;
enum umf_result_t result = UMF_RESULT_SUCCESS;

// TODO check errors
assert(config->context);

switch (config->usm_type)
{
case UR_USM_TYPE_HOST:
// TODO ur_usm_desc_t
urUSMHostAlloc(config->context, NULL, NULL, size, resultPtr);
break;
case UR_USM_TYPE_DEVICE:
urUSMDeviceAlloc(config->context, config->device, NULL, NULL, size,
resultPtr);
break;
case UR_USM_TYPE_SHARED:
urUSMSharedAlloc(config->context, config->device, NULL, NULL, size,
resultPtr);
break;
default:
assert(0);
}

return result;
}

static enum umf_result_t ur_free(void *provider, void *ptr, size_t size)
{
struct ur_provider_priv_t *config = (struct ur_provider_priv_t *)provider;

// TODO check errors
urUSMFree(config->context, ptr);

// TODO - size?

return UMF_RESULT_SUCCESS;
}

void ur_get_last_native_error(void *provider, const char **ppMessage,
int32_t *pError)
{
// TODO
}

enum umf_result_t ur_get_min_page_size(void *provider, void *ptr,
size_t *pageSize)
{
*pageSize = 1024; // TODO call urVirtualMemGranularityGetInfo here
return UMF_RESULT_SUCCESS;
}

const char *ur_get_name(void *provider) { return "USM"; }

bool ur_supports_device(const void* descr, size_t len)
{
urInit(0, 0);

uint32_t adapterCount = 0;
urAdapterGet(0, NULL, &adapterCount);
ur_adapter_handle_t *adapters = (ur_adapter_handle_t *)malloc(sizeof(ur_adapter_handle_t) * adapterCount);
urAdapterGet(adapterCount, adapters, NULL);

uint32_t platformCount = 0;
urPlatformGet(adapters, adapterCount, 1, NULL, &platformCount);
ur_platform_handle_t *platforms = (ur_platform_handle_t *)malloc(sizeof(ur_platform_handle_t) * platformCount);
urPlatformGet(adapters, adapterCount, platformCount, platforms, NULL);

for (uint32_t pid = 0; pid < platformCount; pid++)
{
uint32_t deviceCount = 0;
urDeviceGet(platforms[pid], UR_DEVICE_TYPE_GPU, 0, NULL, &deviceCount);
ur_device_handle_t *devices = (ur_device_handle_t *)malloc(sizeof(ur_device_handle_t) * deviceCount);
urDeviceGet(platforms[pid], UR_DEVICE_TYPE_GPU, deviceCount, devices,
NULL);

for (uint32_t did = 0; did < deviceCount; did++)
{
static const size_t DEVICE_INFO_MAX_LEN = 1024;

char *device_name = (char *)malloc(DEVICE_INFO_MAX_LEN);
memset(device_name, 0, DEVICE_INFO_MAX_LEN);
urDeviceGetInfo(devices[did], UR_DEVICE_INFO_NAME, DEVICE_INFO_MAX_LEN - 1,
device_name, NULL);

if (strncmp((const char*)descr, device_name, len) == 0)
return true;
}
}

// no match
return false;
}
16 changes: 16 additions & 0 deletions source/loader/ur_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "logger/ur_logger.hpp"
#include "ur_loader.hpp"

#include "ur_memory_provider.hpp"

#include <cstring>

namespace ur_lib {
Expand Down Expand Up @@ -75,6 +77,20 @@ context_t::Init(ur_device_init_flags_t device_flags,
result = urInit();
}

umf_memory_provider_ops_t ur_memory_provider_ops = {
.version = UMF_VERSION_CURRENT,
.initialize = ur_initialize,
.finalize = ur_finalize,
.alloc = ur_alloc,
.free = ur_free,
.get_last_native_error = ur_get_last_native_error,
.get_min_page_size = ur_get_min_page_size,
.get_name = ur_get_name,
.supports_device = ur_supports_device,
};

umfMemoryProviderRegister(&ur_memory_provider_ops);

if (hLoaderConfig) {
enabledLayerNames.merge(hLoaderConfig->getEnabledLayerNames());
}
Expand Down
4 changes: 2 additions & 2 deletions source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2120,8 +2120,8 @@ ur_result_t UR_APICALL urUSMHostAlloc(
size_t
size, ///< [in] size in bytes of the USM memory object to be allocated
void **ppMem ///< [out] pointer to USM host memory object
) try {
auto pfnHostAlloc = ur_lib::context->urDdiTable.USM.pfnHostAlloc;
) try {
auto pfnHostAlloc = ur_lib::context->urDdiTable.USM.pfnHostAlloc;
if (nullptr == pfnHostAlloc) {
return UR_RESULT_ERROR_UNINITIALIZED;
}
Expand Down

0 comments on commit 6536c12

Please sign in to comment.