Skip to content

Commit

Permalink
make atlas_test_array_kernel run correctly with the Native-CUDA backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrdar committed Mar 6, 2024
1 parent 3541973 commit bb8c508
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
16 changes: 12 additions & 4 deletions src/atlas/array/native/NativeDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,27 @@ class DataStore : public ArrayDataStore {
DataStore(size_t size): size_(size) {
alloc_aligned(data_store_, size_);
initialise(data_store_, size_);
device_allocated_ = false;
setDeviceNeedsUpdate(true);
}

~DataStore() override {
free_aligned(data_store_);
deallocateDevice();
}

void updateDevice() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
cudaMemcpy(data_store_dev_, data_store_, sizeof(data_store_), cudaMemcpyHostToDevice);
if (not device_allocated_) {
allocateDevice();
}
cudaMemcpy(data_store_dev_, data_store_, size_*sizeof(Value), cudaMemcpyHostToDevice);
#endif
}

void updateHost() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
cudaMemcpy(data_store_, data_store_dev_, sizeof(data_store_dev_), cudaMemcpyDeviceToHost);
cudaMemcpy(data_store_, data_store_dev_, size_*sizeof(Value), cudaMemcpyDeviceToHost);
#endif
}

Expand All @@ -124,17 +129,19 @@ class DataStore : public ArrayDataStore {
if (device_updated_) updateHost();
}

bool deviceAllocated() const override { return false; }
bool deviceAllocated() const override { return device_allocated_; }

void allocateDevice() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
cudaMalloc((void**)&data_store_dev_, sizeof(Value)*size_);
device_allocated_ = true;
#endif
}

void deallocateDevice() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
cudaFree(data_store_dev_);
device_allocated_ = false;
#endif
}

Expand Down Expand Up @@ -187,10 +194,11 @@ class DataStore : public ArrayDataStore {
size_t footprint() const { return sizeof(Value) * size_; }

Value* data_store_;
Value* data_store_dev_;
size_t size_;
Value* data_store_dev_;
mutable bool host_updated_;
mutable bool device_updated_;
mutable bool device_allocated_;
};

//------------------------------------------------------------------------------
Expand Down
22 changes: 20 additions & 2 deletions src/atlas/array/native/NativeMakeView.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,40 @@ inline static void check_metadata(const Array& array) {

template <typename Value, int Rank>
ArrayView<Value, Rank> make_host_view(Array& array) {
return ArrayView<Value, Rank>((Value*)(array.storage()), array.shape(), array.strides());
return ArrayView<Value, Rank>(array.host_data<Value>(), array.shape(), array.strides());
}

template <typename Value, int Rank>
ArrayView<const Value, Rank> make_host_view(const Array& array) {
return ArrayView<const Value, Rank>((const Value*)(array.storage()), array.shape(), array.strides());
return ArrayView<const Value, Rank>(array.host_data<const Value>(), array.shape(), array.strides());
}

template <typename Value, int Rank>
ArrayView<Value, Rank> make_device_view(Array& array) {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
if (not array.deviceAllocated()) {
std::ostringstream ss;
ss << "make_device_view: Array not allocated on device" << std::endl;
throw_Exception(ss.str(), Here());
}
return ArrayView<Value, Rank>((array.device_data<Value>()), array.shape(), array.strides());
#else
return make_host_view<Value, Rank>(array);
#endif
}

template <typename Value, int Rank>
ArrayView<const Value, Rank> make_device_view(const Array& array) {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
if (not array.deviceAllocated()) {
std::ostringstream ss;
ss << "make_device_view: Array not allocated on device" << std::endl;
throw_Exception(ss.str(), Here());
}
return ArrayView<const Value, Rank>(array.device_data<const Value>(), array.shape(), array.strides());
#else
return make_host_view<Value, Rank>(array);
#endif
}

template <typename Value, int Rank>
Expand Down

0 comments on commit bb8c508

Please sign in to comment.