Skip to content

Commit

Permalink
Fix device_strides implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Oct 14, 2024
1 parent c82b4d3 commit c02e4a9
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/atlas/array/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,15 @@ class Array : public util::Object {

/// Returns the strides held by spec_ (forwards to spec_.strides()).
const ArrayStrides& strides() const { return spec_.strides(); }

/// Returns the device strides held by spec_ (forwards to spec_.device_strides()).
/// These are kept separately from strides(), so the two need not be equal.
const ArrayStrides& device_strides() const { return spec_.device_strides(); }

/// Returns the shape held by spec_ (forwards to spec_.shape()).
const ArrayShape& shape() const { return spec_.shape(); }

/// Returns the `shapef` vector of spec_ (forwards to spec_.shapef()).
/// Presumably the Fortran-ordered counterpart of shape() — confirm against ArraySpec.
const std::vector<int>& shapef() const { return spec_.shapef(); }

/// Returns the `stridesf` vector of spec_ (forwards to spec_.stridesf()).
/// Presumably the Fortran-ordered counterpart of strides() — confirm against ArraySpec.
const std::vector<int>& stridesf() const { return spec_.stridesf(); }

/// NOTE(review): the two lines below both define `device_stridesf() const` and differ only in
/// return type, which cannot coexist in one class; they read as the before/after pair of this
/// change. The first forwards to spec_.strides(), the second to spec_.device_stridesf().
const ArrayStrides& device_stridesf() const { return spec_.strides(); }
const std::vector<int>& device_stridesf() const { return spec_.device_stridesf(); }

/// Returns the contiguity flag of spec_ (forwards to spec_.contiguous()).
bool contiguous() const { return spec_.contiguous(); }

Expand Down
17 changes: 14 additions & 3 deletions src/atlas/array/ArraySpec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,17 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
shape_.resize(rank_);
strides_.resize(rank_);
layout_.resize(rank_);
device_strides_.resize(rank_);
device_strides_[rank_ - 1] = strides[rank_ - 1];
for (int j = rank_ - 1; j >= 0; --j) {
shape_[j] = shape[j];
strides_[j] = strides[j];
layout_[j] = j;
size_ *= size_t(shape_[j]);
if( j < rank_ - 1) {
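// Assume contiguous device data: each outer device stride is derived from the
// next-inner device stride instead of being copied from the host strides.
// NOTE(review): for a contiguous layout the outer stride is normally
// device_strides_[j+1] * shape[j+1]; confirm that multiplying by shape[j] is intended.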
device_strides_[j] = device_strides_[j+1] * shape[j];
}
}
allocated_size_ = compute_aligned_size(size_t(shape_[0]) * size_t(strides_[0]), size_t(alignment));
contiguous_ = (size_ == allocated_size_);
Expand Down Expand Up @@ -121,6 +127,8 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
shape_.resize(rank_);
strides_.resize(rank_);
layout_.resize(rank_);
device_strides_.resize(rank_);
device_strides_[rank_ - 1] = strides[rank_ - 1];
default_layout_ = true;
for (int j = rank_ - 1; j >= 0; --j) {
shape_[j] = shape[j];
Expand All @@ -130,6 +138,10 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
if (layout_[j] != idx_t(j)) {
default_layout_ = false;
}
if( j < rank_ - 1) {
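// Assume contiguous device data: each outer device stride is derived from the
// next-inner device stride instead of being copied from the host strides.
// NOTE(review): for a contiguous layout the outer stride is normally
// device_strides_[j+1] * shape[j+1]; confirm that multiplying by shape[j] is intended.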
device_strides_[j] = device_strides_[j+1] * shape[j];
}
}
allocated_size_ = compute_aligned_size(size_t(shape_[layout_[0]]) * size_t(strides_[layout_[0]]), size_t(alignment));
contiguous_ = (size_ == allocated_size_);
Expand Down Expand Up @@ -160,11 +172,10 @@ void ArraySpec::allocate_fortran_specs() {
shapef_.resize(rank_);
stridesf_.resize(rank_);
device_stridesf_.resize(rank_);
device_stridesf_[rank_ - 1] = stridesf_[rank_ - 1];
for (idx_t j = 0; j < rank_; ++j) {
shapef_[j] = shape_[rank_ - 1 - layout_[j]];
stridesf_[j] = strides_[rank_ -1 - layout_[j]];
device_stridesf_[rank_ - j - 1] = device_stridesf_[rank_ - j] * shapef_[rank_ - j - 1];
stridesf_[j] = strides_[rank_ - 1 - layout_[j]];
device_stridesf_[j] = device_strides_[rank_ - 1 - j];
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/atlas/array/ArraySpec.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ArraySpec {
DataType datatype_;
ArrayShape shape_;
ArrayStrides strides_;
ArrayStrides device_strides_;
ArrayLayout layout_;
ArrayAlignment alignment_;
std::vector<int> shapef_;
Expand Down Expand Up @@ -62,6 +63,7 @@ class ArraySpec {
/// Returns a const reference to the stored shape (shape_).
const ArrayShape& shape() const { return shape_; }
/// Returns a const reference to the stored alignment (alignment_).
const ArrayAlignment& alignment() const { return alignment_; }
/// Returns a const reference to the stored strides (strides_).
const ArrayStrides& strides() const { return strides_; }
/// Returns a const reference to the stored device strides (device_strides_),
/// which are kept as a member separate from strides_.
const ArrayStrides& device_strides() const { return device_strides_; }
/// Returns a const reference to the stored layout (layout_).
const ArrayLayout& layout() const { return layout_; }
const std::vector<int>& shapef() const;
const std::vector<int>& stridesf() const;
Expand Down
4 changes: 2 additions & 2 deletions src/atlas/array/native/NativeMakeView.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ template <typename Value, int Rank>
ArrayView<Value, Rank> make_device_view(Array& array) {
#if ATLAS_HAVE_GPU
ATLAS_ASSERT(array.deviceAllocated(),"make_device_view: Array not allocated on device");
return ArrayView<Value, Rank>((array.device_data<Value>()), array.shape(), array.strides());
return ArrayView<Value, Rank>((array.device_data<Value>()), array.shape(), array.device_strides());
#else
return make_host_view<Value, Rank>(array);
#endif
Expand All @@ -61,7 +61,7 @@ template <typename Value, int Rank>
ArrayView<const Value, Rank> make_device_view(const Array& array) {
#if ATLAS_HAVE_GPU
ATLAS_ASSERT(array.deviceAllocated(),"make_device_view: Array not allocated on device");
return ArrayView<const Value, Rank>(array.device_data<const Value>(), array.shape(), array.strides());
return ArrayView<const Value, Rank>(array.device_data<const Value>(), array.shape(), array.device_strides());
#else
return make_host_view<const Value, Rank>(array);
#endif
Expand Down

0 comments on commit c02e4a9

Please sign in to comment.