Skip to content

Commit

Permalink
add supports_device() to ops, use provider name
Browse files Browse the repository at this point in the history
  • Loading branch information
bratpiorka committed Jul 19, 2023
1 parent 72ec0fe commit 3a4d29f
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ typedef struct umf_memory_provider_t *umf_memory_provider_handle_t;
/// \return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
///
enum umf_result_t
umfMemoryProviderCreate(struct umf_memory_provider_ops_t *ops, void *params,
umf_memory_provider_handle_t *hProvider);
umfMemoryProviderCreate(const struct umf_memory_provider_ops_t *ops,
void *params, umf_memory_provider_handle_t *hProvider);

///
/// \brief Destroys memory provider.
Expand All @@ -39,11 +39,11 @@ void umfMemoryProviderDestroy(umf_memory_provider_handle_t hProvider);

// TODO comment
enum umf_result_t
umfMemoryProviderRegister(struct umf_memory_provider_ops_t *ops, char *name);
umfMemoryProviderRegister(struct umf_memory_provider_ops_t *ops);
enum umf_result_t umfMemoryProvidersRegisterGetNames(char *providers,
size_t *numProviders);
umf_memory_provider_type_t umfMemoryProvidersRegisterGetType(char *name);
struct umf_memory_provider_ops_t umfMemoryProvidersRegisterGetOps(char *name);
const struct umf_memory_provider_ops_t *
umfMemoryProvidersRegisterGetOps(char *name);

///
/// \brief Allocates size bytes of uninitialized storage from memory provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#ifndef UMF_MEMORY_PROVIDER_OPS_H
#define UMF_MEMORY_PROVIDER_OPS_H 1

#include <stdbool.h>
#include <umf/base.h>

#ifdef __cplusplus
Expand Down Expand Up @@ -63,6 +64,7 @@ struct umf_memory_provider_ops_t {
enum umf_result_t (*purge_lazy)(void *provider, void *ptr, size_t size);
enum umf_result_t (*purge_force)(void *provider, void *ptr, size_t size);
const char *(*get_name)(void *provider);
bool (*supports_device)(const char *name);
};

#ifdef __cplusplus
Expand Down
56 changes: 32 additions & 24 deletions source/common/unified_malloc_framework/src/memory_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
#include <assert.h>
#include <stdlib.h>

#include <algorithm>
#include <cstring>
#include <map>
#include <list>
#include <string>

#include <umf/memory_provider.h>
#include "umf/memory_provider.h"

#include "memory_provider_internal.h"
#include "os_memory_provider.h"
Expand All @@ -25,11 +26,11 @@ struct umf_memory_provider_t {
void *provider_priv;
};

std::map<std::string, umf_memory_provider_ops_t> globalProviders;
std::list<struct umf_memory_provider_ops_t> globalProviders;

enum umf_result_t
umfMemoryProviderCreate(struct umf_memory_provider_ops_t *ops, void *params,
umf_memory_provider_handle_t *hProvider) {
umfMemoryProviderCreate(const struct umf_memory_provider_ops_t *ops,
void *params, umf_memory_provider_handle_t *hProvider) {
umf_memory_provider_handle_t provider =
(umf_memory_provider_t *)malloc(sizeof(struct umf_memory_provider_t));
if (!provider) {
Expand All @@ -54,27 +55,27 @@ umfMemoryProviderCreate(struct umf_memory_provider_ops_t *ops, void *params,
return UMF_RESULT_SUCCESS;
}

enum umf_result_t umfMemoryProviderRegister(umf_memory_provider_ops_t *ops,
char *name) {
enum umf_result_t umfMemoryProviderRegister(umf_memory_provider_ops_t *ops) {

// TODO improve - use the ops->get_name()
globalProviders[name] = *ops;
// TODO check if this provider isn't already registered
globalProviders.push_back(*ops);

return UMF_RESULT_SUCCESS;
}

enum umf_result_t umfMemoryProvidersRegisterGetNames(char *providers,
size_t *numProviders) {
// TODO improve
if (globalProviders.count("OS") == 0) {

umfMemoryProviderRegister(&OS_MEMORY_PROVIDER_OPS,
std::string("OS").data());

// TODO IMPORTANT
// as the NUMA (OS) memory provider is the default provider here in the UMF,
// it should be available (predefined) somehow so a user could use it
// without any umalloc or UR libs etc
auto it = std::find_if(
std::begin(globalProviders), std::end(globalProviders),
[&](auto &ops) { return std::strcmp(ops.get_name(NULL), "OS") == 0; });

// TODO IMPORTANT
// as the NUMA (OS) memory provider is the default provider here in the UMF,
// it should be available (predefined) somehow so a user could use it
// without any umalloc or UR libs etc
if (it == globalProviders.end()) {
umfMemoryProviderRegister(&OS_MEMORY_PROVIDER_OPS);
}

if (providers == NULL) {
Expand All @@ -86,21 +87,28 @@ enum umf_result_t umfMemoryProvidersRegisterGetNames(char *providers,
// : *numProviders;

for (auto p : globalProviders) {
std::strcat(providers, p.first.c_str());
std::strcat(providers, p.get_name(NULL));
std::strcat(providers, ";");
}
// remove last ';'
providers[std::strlen(providers) - 1] = '\0';
}

return UMF_RESULT_SUCCESS;
}

// TODO rename ;)
umf_memory_provider_ops_t umfMemoryProvidersRegisterGetOps(char *name) {
return globalProviders[name];
}
const umf_memory_provider_ops_t *umfMemoryProvidersRegisterGetOps(char *name) {
auto it = std::find_if(
std::begin(globalProviders), std::end(globalProviders),
[&](auto &ops) { return std::strcmp(ops.get_name(NULL), name) == 0; });

if (it != globalProviders.end()) {
return &(*it);
}

umf_memory_provider_type_t umfMemoryProvidersRegisterGetType(char *name) {
return globalProviders[name].type;
// else
return NULL;
}

void umfMemoryProviderDestroy(umf_memory_provider_handle_t hProvider) {
Expand Down
12 changes: 11 additions & 1 deletion source/common/unified_malloc_framework/src/os_memory_provider.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,21 @@ static enum umf_result_t os_free(void *provider, void *ptr, size_t bytes) {
void os_get_last_native_error(void *provider, const char **ppMessage,
int32_t *pError) {}

const char *os_get_name(void *provider) { return "OS"; }

bool os_supports_device(const char *device) {
// TODO
return (bool)(strcmp(device, "NUMA") == 0);
}

struct umf_memory_provider_ops_t OS_MEMORY_PROVIDER_OPS = {
.version = UMF_VERSION_CURRENT,
.type = UMF_MEMORY_PROVIDER_TYPE_NUMA,
.initialize = os_initialize,
.finalize = os_finalize,
.alloc = os_alloc,
.free = os_free,
.get_last_native_error = os_get_last_native_error};
.get_last_native_error = os_get_last_native_error,
.get_name = os_get_name,
.supports_device = os_supports_device,
};
3 changes: 1 addition & 2 deletions source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ ur_result_t UR_APICALL urInit(

std::cout << "platform_name: " << platform_name.data() << "\n";

umfMemoryProviderRegister(&UR_MEMORY_PROVIDER_OPS,
platform_name.data());
umfMemoryProviderRegister(&UR_MEMORY_PROVIDER_OPS);
}
});

Expand Down
66 changes: 62 additions & 4 deletions source/loader/ur_memory_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <string.h>
#include <unistd.h>

#include <vector>

#include <ur_api.h>

#include <umf.h>
Expand All @@ -37,7 +39,8 @@
#include <ur_memory_provider.hpp>

enum umf_result_t ur_initialize(void *params, void **pool) {

urInit(0);

if (pool == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}
Expand All @@ -49,16 +52,62 @@ enum umf_result_t ur_initialize(void *params, void **pool) {
struct ur_provider_config_t *config = (struct ur_provider_config_t *)malloc(
sizeof(struct ur_provider_config_t));

if (config) {
if (config == NULL) {
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

if (((ur_provider_config_t *)params)->context) {
// TODO get list of devices from the context
config->device = ((ur_provider_config_t *)params)->device;
config->context = ((ur_provider_config_t *)params)->context;
*pool = config;

return UMF_RESULT_SUCCESS;
} else if (((ur_provider_config_t *)params)->name) {

// find a device that matches this name

// NOTE: we browse all platforms here because right now
// there is a single UR provider for all platforms

uint32_t platformCount = 0;
std::vector<ur_platform_handle_t> platforms;
urPlatformGet(1, nullptr, &platformCount);
platforms.resize(platformCount);
urPlatformGet(platformCount, platforms.data(), nullptr);

for (auto p : platforms) {
uint32_t deviceCount = 0;
urDeviceGet(p, UR_DEVICE_TYPE_GPU, 0, nullptr, &deviceCount);
std::vector<ur_device_handle_t> devices(deviceCount);
urDeviceGet(p, UR_DEVICE_TYPE_GPU, deviceCount, devices.data(),
nullptr);

for (auto d : devices) {
ur_device_type_t device_type = UR_DEVICE_TYPE_ALL;
urDeviceGetInfo(d, UR_DEVICE_INFO_TYPE,
sizeof(ur_device_type_t),
static_cast<void *>(&device_type), nullptr);
static const size_t DEVICE_NAME_MAX_LEN = 1024;
char device_name[DEVICE_NAME_MAX_LEN] = {0};
urDeviceGetInfo(d, UR_DEVICE_INFO_NAME, DEVICE_NAME_MAX_LEN - 1,
static_cast<void *>(&device_name), nullptr);
if (strcmp(((ur_provider_config_t *)params)->name,
device_name) == 0) {

ur_context_handle_t ctx = NULL;
urContextCreate(1, &d, NULL, &ctx);
config->device = d;
config->context = ctx;
*pool = config;

return UMF_RESULT_SUCCESS;
}
}
}
}

// else
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
}

void ur_finalize(void *provider) { free(provider); }
Expand Down Expand Up @@ -100,6 +149,13 @@ enum umf_result_t ur_get_min_page_size(void *provider, void *ptr,
return UMF_RESULT_SUCCESS;
}

const char *ur_get_name(void *provider) { return "USM"; }

bool ur_supports_device(const char *device) {
// TODO
return (strcmp(device, "Intel(R) Arc(TM) A750 Graphics") == 0);
}

struct umf_memory_provider_ops_t UR_MEMORY_PROVIDER_OPS = {
.version = UMF_VERSION_CURRENT,
.type = UMF_MEMORY_PROVIDER_TYPE_USM,
Expand All @@ -109,4 +165,6 @@ struct umf_memory_provider_ops_t UR_MEMORY_PROVIDER_OPS = {
.free = ur_free,
.get_last_native_error = ur_get_last_native_error,
.get_min_page_size = ur_get_min_page_size,
.get_name = ur_get_name,
.supports_device = ur_supports_device,
};
1 change: 1 addition & 0 deletions source/loader/ur_memory_provider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ extern "C" {
#include "umf/memory_provider.h"

struct ur_provider_config_t {
char *name;
ur_device_handle_t device;
ur_context_handle_t context;
};
Expand Down

0 comments on commit 3a4d29f

Please sign in to comment.