From c02e4a916ca091f3e5604f2b731b2c06f9a405a2 Mon Sep 17 00:00:00 2001
From: Willem Deconinck
Date: Mon, 14 Oct 2024 09:02:22 +0200
Subject: [PATCH] Fix device_strides implementation

---
Review notes (text between '---' and the first 'diff --git' is not part of
the commit message):

* The device_strides_ recurrence in both ArraySpec constructors multiplies by
  shape[j+1] (the extent of the next-inner dimension) instead of shape[j].
  For contiguous data, stride[j] = stride[j+1] * shape[j+1]; e.g. shape {3,4}
  with innermost stride 1 must give strides {4,1}, whereas multiplying by
  shape[j] gives {3,1}. Line counts are unchanged, so hunk headers and the
  diffstat below remain valid; the 'index' blob ids have not been recomputed.
* NOTE(review): device_strides_[rank_ - 1] is written unconditionally, which
  presumes rank_ >= 1 -- confirm rank-0 specs never reach these constructors.
* NOTE(review): angle-bracket text (template argument lists, author e-mail)
  appears to be missing from this copy of the patch (e.g. 'std::vector&');
  indentation of context lines was re-flowed. Confirm against the original
  commit before applying.

 src/atlas/array/Array.h                  |  4 +++-
 src/atlas/array/ArraySpec.cc             | 17 ++++++++++++++---
 src/atlas/array/ArraySpec.h              |  2 ++
 src/atlas/array/native/NativeMakeView.cc |  4 ++--
 4 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/atlas/array/Array.h b/src/atlas/array/Array.h
index 85d43d50c..99bff9fac 100644
--- a/src/atlas/array/Array.h
+++ b/src/atlas/array/Array.h
@@ -85,13 +85,15 @@ class Array : public util::Object {
 
     const ArrayStrides& strides() const { return spec_.strides(); }
 
+    const ArrayStrides& device_strides() const { return spec_.device_strides(); }
+
     const ArrayShape& shape() const { return spec_.shape(); }
 
     const std::vector& shapef() const { return spec_.shapef(); }
 
     const std::vector& stridesf() const { return spec_.stridesf(); }
 
-    const ArrayStrides& device_stridesf() const { return spec_.strides(); }
+    const std::vector& device_stridesf() const { return spec_.device_stridesf(); }
 
     bool contiguous() const { return spec_.contiguous(); }
 
diff --git a/src/atlas/array/ArraySpec.cc b/src/atlas/array/ArraySpec.cc
index 43af4c99f..243f04495 100644
--- a/src/atlas/array/ArraySpec.cc
+++ b/src/atlas/array/ArraySpec.cc
@@ -81,11 +81,17 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
     shape_.resize(rank_);
     strides_.resize(rank_);
     layout_.resize(rank_);
+    device_strides_.resize(rank_);
+    device_strides_[rank_ - 1] = strides[rank_ - 1];
     for (int j = rank_ - 1; j >= 0; --j) {
         shape_[j] = shape[j];
         strides_[j] = strides[j];
         layout_[j] = j;
         size_ *= size_t(shape_[j]);
+        if( j < rank_ - 1) {
+            // Assume contiguous device data: stride[j] = stride[j+1] * shape[j+1]
+            device_strides_[j] = device_strides_[j+1] * shape[j+1];
+        }
     }
     allocated_size_ = compute_aligned_size(size_t(shape_[0]) * size_t(strides_[0]), size_t(alignment));
     contiguous_ = (size_ == allocated_size_);
@@ -121,6 +127,8 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
     shape_.resize(rank_);
     strides_.resize(rank_);
     layout_.resize(rank_);
+    device_strides_.resize(rank_);
+    device_strides_[rank_ - 1] = strides[rank_ - 1];
     default_layout_ = true;
     for (int j = rank_ - 1; j >= 0; --j) {
         shape_[j] = shape[j];
@@ -130,6 +138,10 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
         if (layout_[j] != idx_t(j)) {
             default_layout_ = false;
         }
+        if( j < rank_ - 1) {
+            // Assume contiguous device data: stride[j] = stride[j+1] * shape[j+1]
+            device_strides_[j] = device_strides_[j+1] * shape[j+1];
+        }
     }
     allocated_size_ = compute_aligned_size(size_t(shape_[layout_[0]]) * size_t(strides_[layout_[0]]), size_t(alignment));
     contiguous_ = (size_ == allocated_size_);
@@ -160,11 +172,10 @@ void ArraySpec::allocate_fortran_specs() {
     shapef_.resize(rank_);
     stridesf_.resize(rank_);
     device_stridesf_.resize(rank_);
-    device_stridesf_[rank_ - 1] = stridesf_[rank_ - 1];
     for (idx_t j = 0; j < rank_; ++j) {
         shapef_[j] = shape_[rank_ - 1 - layout_[j]];
-        stridesf_[j] = strides_[rank_ -1 - layout_[j]];
-        device_stridesf_[rank_ - j - 1] = device_stridesf_[rank_ - j] * shapef_[rank_ - j - 1];
+        stridesf_[j] = strides_[rank_ - 1 - layout_[j]];
+        device_stridesf_[j] = device_strides_[rank_ - 1 - j];
     }
 }
 
diff --git a/src/atlas/array/ArraySpec.h b/src/atlas/array/ArraySpec.h
index 43d7a261f..74b423689 100644
--- a/src/atlas/array/ArraySpec.h
+++ b/src/atlas/array/ArraySpec.h
@@ -33,6 +33,7 @@ class ArraySpec {
     DataType datatype_;
     ArrayShape shape_;
     ArrayStrides strides_;
+    ArrayStrides device_strides_;
     ArrayLayout layout_;
     ArrayAlignment alignment_;
     std::vector shapef_;
@@ -62,6 +63,7 @@ class ArraySpec {
     const ArrayShape& shape() const { return shape_; }
     const ArrayAlignment& alignment() const { return alignment_; }
     const ArrayStrides& strides() const { return strides_; }
+    const ArrayStrides& device_strides() const { return device_strides_; }
     const ArrayLayout& layout() const { return layout_; }
     const std::vector& shapef() const;
     const std::vector& stridesf() const;
diff --git a/src/atlas/array/native/NativeMakeView.cc b/src/atlas/array/native/NativeMakeView.cc
index 5c24c0037..cdbcd27dc 100644
--- a/src/atlas/array/native/NativeMakeView.cc
+++ b/src/atlas/array/native/NativeMakeView.cc
@@ -51,7 +51,7 @@ template
 ArrayView make_device_view(Array& array) {
 #if ATLAS_HAVE_GPU
     ATLAS_ASSERT(array.deviceAllocated(),"make_device_view: Array not allocated on device");
-    return ArrayView((array.device_data()), array.shape(), array.strides());
+    return ArrayView((array.device_data()), array.shape(), array.device_strides());
 #else
     return make_host_view(array);
 #endif
@@ -61,7 +61,7 @@ template
 ArrayView make_device_view(const Array& array) {
 #if ATLAS_HAVE_GPU
     ATLAS_ASSERT(array.deviceAllocated(),"make_device_view: Array not allocated on device");
-    return ArrayView(array.device_data(), array.shape(), array.strides());
+    return ArrayView(array.device_data(), array.shape(), array.device_strides());
 #else
     return make_host_view(array);
 #endif