Skip to content

Commit

Permalink
fix: read data from buffer if data is not present
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Aug 23, 2024
1 parent 413d754 commit 81c1a6d
Show file tree
Hide file tree
Showing 85 changed files with 2,835 additions and 1,629 deletions.
15 changes: 14 additions & 1 deletion cmake/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ iree::runtime::Device::~Device() {
}
}

iree::runtime::IREETensor::IREETensor(iree_hal_buffer_view_t *buffer_view, iree_hal_element_type_t type) : buffer_view(buffer_view), type(type) {
iree::runtime::IREETensor::IREETensor(iree_hal_buffer_view_t *buffer_view, iree_hal_element_type_t type, iree_hal_device_t *device) : buffer_view(buffer_view), type(type), device(device) {
size = iree_hal_buffer_view_byte_length(buffer_view);
// TODO: fill in dim metadata
}
Expand Down Expand Up @@ -81,6 +81,19 @@ std::vector<char> *iree::runtime::IREETensor::serialize() {
size_t size_size = sizeof(size);
buffer->insert(buffer->end(), reinterpret_cast<const char *>(&size), reinterpret_cast<const char *>(&size) + size_size);

if (data == nullptr) {
data = std::malloc(size);

if (data == nullptr) {
return nullptr;
}

auto status = read_buffer(device, buffer_view, data, size);

if (!iree_status_is_ok(status)) {
return nullptr;
}
}
// Serialize 'data'
buffer->insert(buffer->end(), reinterpret_cast<const char *>(data), reinterpret_cast<const char *>(data) + size);

Expand Down
3 changes: 2 additions & 1 deletion cmake/src/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class IREETensor {
std::vector<iree_hal_dim_t> dims;
iree_hal_element_type_t type;
iree_hal_buffer_view_t* buffer_view;
iree_hal_device_t* device;

IREETensor(char* serialized_data);
IREETensor(iree_hal_buffer_view_t* buffer_view, iree_hal_element_type_t type);
IREETensor(iree_hal_buffer_view_t* buffer_view, iree_hal_element_type_t type, iree_hal_device_t* device);
IREETensor(void* data, size_t size, std::vector<int64_t> in_dims, iree_hal_element_type_t type);

// Destructor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@
0A6CE5F5167F6157FF061569 /* DisconnectedView.swift in Sources */ = {isa = PBXBuildFile; fileRef = D488732A3104CEE004CB7B42 /* DisconnectedView.swift */; };
125165EBA8D40490F0AABE88 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 51F3B59B97B70D891071B6C4 /* Assets.xcassets */; };
137E19B83B668B9C3130A3D7 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = D7955FA5760A1DB1E827049B /* ContentView.swift */; };
373237312C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 373237212C7861AC009D1652 /* libnx_iree_runtime.so */; };
373237322C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 373237232C7861AC009D1652 /* libnx_iree_runtime.so */; };
373237332C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 373237252C7861AC009D1652 /* libnx_iree_runtime.so */; };
373237342C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 373237272C7861AC009D1652 /* libnx_iree_runtime.so */; };
373237352C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 373237292C7861AC009D1652 /* libnx_iree_runtime.so */; };
373237362C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 3732372B2C7861AC009D1652 /* libnx_iree_runtime.so */; };
373237372C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */ = {isa = PBXBuildFile; fileRef = 3732372D2C7861AC009D1652 /* libnx_iree_runtime.so */; };
373496882C7707DE00FD2E6F /* NxAddon.swift in Sources */ = {isa = PBXBuildFile; fileRef = 373496872C7707DE00FD2E6F /* NxAddon.swift */; };
3734968B2C77081100FD2E6F /* NxFunctionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3734968A2C77081100FD2E6F /* NxFunctionView.swift */; };
379E276B2C76D40900624AA7 /* nx_iree.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 379E276A2C76D40900624AA7 /* nx_iree.cpp */; };
Expand Down Expand Up @@ -395,6 +388,13 @@
373237292C7861AC009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
3732372B2C7861AC009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
3732372D2C7861AC009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
373237392C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
3732373B2C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
3732373D2C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
3732373F2C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
373237412C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
373237432C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
373237452C786921009D1652 /* libnx_iree_runtime.so */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; path = libnx_iree_runtime.so; sourceTree = "<group>"; };
373496872C7707DE00FD2E6F /* NxAddon.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NxAddon.swift; sourceTree = "<group>"; };
3734968A2C77081100FD2E6F /* NxFunctionView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NxFunctionView.swift; sourceTree = "<group>"; };
379E27672C76D25900624AA7 /* LiveNxIREE-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "LiveNxIREE-Bridging-Header.h"; sourceTree = "<group>"; };
Expand All @@ -415,13 +415,6 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
373237372C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
373237312C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
373237362C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
373237322C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
373237332C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
373237352C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
373237342C7861AC009D1652 /* libnx_iree_runtime.so in Frameworks */,
459F92DE90F9275614EDC9F7 /* LiveViewNative in Frameworks */,
E4C8EE2081B99D0A468A7913 /* LiveViewNativeLiveForm in Frameworks */,
);
Expand Down Expand Up @@ -1413,6 +1406,83 @@
path = nx_iree;
sourceTree = "<group>";
};
373237382C786921009D1652 /* Frameworks */ = {
isa = PBXGroup;
children = (
3732373C2C786921009D1652 /* host */,
373237422C786921009D1652 /* ios */,
3732373A2C786921009D1652 /* ios_simulator */,
373237462C786921009D1652 /* tvos */,
3732373E2C786921009D1652 /* tvos_simulator */,
373237442C786921009D1652 /* visionos */,
373237402C786921009D1652 /* visionos_simulator */,
);
name = Frameworks;
sourceTree = "<group>";
};
3732373A2C786921009D1652 /* ios_simulator */ = {
isa = PBXGroup;
children = (
373237392C786921009D1652 /* libnx_iree_runtime.so */,
);
name = ios_simulator;
path = LiveNxIREE/nx_iree/lib/ios_simulator;
sourceTree = "<group>";
};
3732373C2C786921009D1652 /* host */ = {
isa = PBXGroup;
children = (
3732373B2C786921009D1652 /* libnx_iree_runtime.so */,
);
name = host;
path = LiveNxIREE/nx_iree/lib/host;
sourceTree = "<group>";
};
3732373E2C786921009D1652 /* tvos_simulator */ = {
isa = PBXGroup;
children = (
3732373D2C786921009D1652 /* libnx_iree_runtime.so */,
);
name = tvos_simulator;
path = LiveNxIREE/nx_iree/lib/tvos_simulator;
sourceTree = "<group>";
};
373237402C786921009D1652 /* visionos_simulator */ = {
isa = PBXGroup;
children = (
3732373F2C786921009D1652 /* libnx_iree_runtime.so */,
);
name = visionos_simulator;
path = LiveNxIREE/nx_iree/lib/visionos_simulator;
sourceTree = "<group>";
};
373237422C786921009D1652 /* ios */ = {
isa = PBXGroup;
children = (
373237412C786921009D1652 /* libnx_iree_runtime.so */,
);
name = ios;
path = LiveNxIREE/nx_iree/lib/ios;
sourceTree = "<group>";
};
373237442C786921009D1652 /* visionos */ = {
isa = PBXGroup;
children = (
373237432C786921009D1652 /* libnx_iree_runtime.so */,
);
name = visionos;
path = LiveNxIREE/nx_iree/lib/visionos;
sourceTree = "<group>";
};
373237462C786921009D1652 /* tvos */ = {
isa = PBXGroup;
children = (
373237452C786921009D1652 /* libnx_iree_runtime.so */,
);
name = tvos;
path = LiveNxIREE/nx_iree/lib/tvos;
sourceTree = "<group>";
};
373496892C7707F700FD2E6F /* NxAddon */ = {
isa = PBXGroup;
children = (
Expand Down Expand Up @@ -1443,6 +1513,7 @@
children = (
2F2E8465248E555C831C9B55 /* LiveNxIREE */,
6CF8FEA143D447C31948C8CB /* Products */,
373237382C786921009D1652 /* Frameworks */,
);
sourceTree = "<group>";
};
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <string>
#include <vector>
#include <iostream>

int nx_iree_initialize(iree_vm_instance_t* vm_instance, iree_hal_driver_registry_t* driver_registry, char* error_message) {
vm_instance = create_instance();
Expand Down Expand Up @@ -49,17 +50,22 @@ char** nx_iree_call(iree_vm_instance_t* vm_instance, iree_hal_device_t* device,
// driver name is hardcoded because there is only a check for CUDA
auto [status, optional_result] = call(vm_instance, device, "not_cuda", bytecode, static_cast<size_t>(bytecode_size), inputs);

if (!is_ok(status)) {
if (!is_ok(status) || !optional_result.has_value()) {
std::string msg = get_status_message(status);
strncpy(error_message, msg.c_str(), msg.length());
return nullptr;
}

auto serialized_outputs = new char*[num_outputs];

auto result = optional_result.value();
std::vector<iree::runtime::IREETensor *> result = optional_result.value();
for (size_t i = 0; i < num_outputs; i++) {
std::vector<char> *serialized = result[i]->serialize();
iree::runtime::IREETensor *tensor = result[i];
std::cout << "Tensor: " << tensor;
if (tensor == nullptr) {
std::cout << "Failed output allocation at index " << i << "\n";
}
std::vector<char> *serialized = tensor->serialize();
serialized_outputs[i] = new char[serialized->size()];
memcpy(serialized_outputs[i], serialized->data(), serialized->size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ void iree_thread_request_affinity(iree_thread_t* thread,
// This has no effect if the thread is not suspended.
void iree_thread_resume(iree_thread_t* thread);

// Blocks the current thread until |thread| has finished its execution.
void iree_thread_join(iree_thread_t* thread);

void iree_thread_yield(void);

#ifdef __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,11 @@ TEST_F(LoopTest, WaitOneBlocking) {
// Spin up the thread to signal the event after a short delay.
// We need to do this before we issue the wait so that loops which perform the
// wait inline can still make forward progress even if they block.
//
// Note: there is a race here that can cause flakes. If this new thread
// starts running after the iree_loop_wait_* timeout below, the status check
// will fail. We make the timeout there sufficiently long to give the OS
// enough time to switch threads a few times.
std::thread thread([&]() {
IREE_TRACE_SCOPE();
std::this_thread::sleep_for(std::chrono::milliseconds(50));
Expand All @@ -606,7 +611,7 @@ TEST_F(LoopTest, WaitOneBlocking) {
bool did_wait_callback = false;
} user_data;
IREE_ASSERT_OK(iree_loop_wait_one(
loop, wait_source, iree_make_timeout_ms(200),
loop, wait_source, iree_make_timeout_ms(2000),
+[](void* user_data_ptr, iree_loop_t loop, iree_status_t status) {
IREE_TRACE_SCOPE();
IREE_EXPECT_OK(status);
Expand Down Expand Up @@ -768,6 +773,11 @@ TEST_F(LoopTest, WaitAnyBlocking) {
// Spin up the thread to signal the event after a short delay.
// We need to do this before we issue the wait so that loops which perform the
// wait inline can still make forward progress even if they block.
//
// Note: there is a race here that can cause flakes. If this new thread
// starts running after the iree_loop_wait_* timeout below, the status check
// will fail. We make the timeout there sufficiently long to give the OS
// enough time to switch threads a few times.
std::thread thread([&]() {
IREE_TRACE_SCOPE();
std::this_thread::sleep_for(std::chrono::milliseconds(50));
Expand All @@ -783,7 +793,7 @@ TEST_F(LoopTest, WaitAnyBlocking) {
} user_data;
IREE_ASSERT_OK(iree_loop_wait_any(
loop, IREE_ARRAYSIZE(wait_sources), wait_sources,
iree_make_timeout_ms(200),
iree_make_timeout_ms(2000),
+[](void* user_data_ptr, iree_loop_t loop, iree_status_t status) {
IREE_TRACE_SCOPE();
IREE_EXPECT_OK(status);
Expand Down Expand Up @@ -941,6 +951,11 @@ TEST_F(LoopTest, WaitAllBlocking) {
// Spin up the thread to signal the event after a short delay.
// We need to do this before we issue the wait so that loops which perform the
// wait inline can still make forward progress even if they block.
//
// Note: there is a race here that can cause flakes. If this new thread
// starts running after the iree_loop_wait_* timeout below, the status check
// will fail. We make the timeout there sufficiently long to give the OS
// enough time to switch threads a few times.
std::thread thread([&]() {
IREE_TRACE_SCOPE();
std::this_thread::sleep_for(std::chrono::milliseconds(50));
Expand All @@ -956,7 +971,7 @@ TEST_F(LoopTest, WaitAllBlocking) {
} user_data;
IREE_ASSERT_OK(iree_loop_wait_all(
loop, IREE_ARRAYSIZE(wait_sources), wait_sources,
iree_make_timeout_ms(200),
iree_make_timeout_ms(2000),
+[](void* user_data_ptr, iree_loop_t loop, iree_status_t status) {
IREE_TRACE_SCOPE();
IREE_EXPECT_OK(status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
// query_tile_sizes
//===----------------------------------------------------------------------===//

// OPERAND_ROLE describes the role that a tensor plays in an
// OPERAND_INDEX describes the index that a tensor plays in an
// operation, e.g. "left-hand-size operand" (e.g. in a matmul).
#define IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_MASK 0xFF
#define IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_NONE 0x00
Expand Down
Loading

0 comments on commit 81c1a6d

Please sign in to comment.