Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Fortran access to device data through Field and FieldSet #232

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/atlas/array/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,16 @@ class Array : public util::Object {

const ArrayStrides& strides() const { return spec_.strides(); }

const ArrayStrides& device_strides() const { return spec_.device_strides(); }

const ArrayShape& shape() const { return spec_.shape(); }

const std::vector<int>& shapef() const { return spec_.shapef(); }

const std::vector<int>& stridesf() const { return spec_.stridesf(); }

const std::vector<int>& device_stridesf() const { return spec_.device_stridesf(); }

bool contiguous() const { return spec_.contiguous(); }

bool hasDefaultLayout() const { return spec_.hasDefaultLayout(); }
Expand Down
26 changes: 25 additions & 1 deletion src/atlas/array/ArraySpec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,18 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayAlignment& alignment):
shape_.resize(rank_);
strides_.resize(rank_);
layout_.resize(rank_);
device_strides_.resize(rank_);
device_strides_[rank_ - 1] = 1;
for (int j = rank_ - 1; j >= 0; --j) {
shape_[j] = shape[j];
strides_[j] = allocated_size_;
layout_[j] = j;
size_ *= size_t(shape_[j]);
allocated_size_ *= size_t(aligned_shape[j]);
if( j < rank_ - 1) {
// Assume contiguous device data!
device_strides_[j] = strides_[j+1] * shape[j+1];
}
}
ATLAS_ASSERT(allocated_size_ == compute_aligned_size(size_t(shape_[0]) * size_t(strides_[0]), size_t(alignment)));
contiguous_ = (size_ == allocated_size_);
Expand Down Expand Up @@ -81,11 +87,17 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
shape_.resize(rank_);
strides_.resize(rank_);
layout_.resize(rank_);
device_strides_.resize(rank_);
device_strides_[rank_ - 1] = strides[rank_ - 1];
for (int j = rank_ - 1; j >= 0; --j) {
shape_[j] = shape[j];
strides_[j] = strides[j];
layout_[j] = j;
size_ *= size_t(shape_[j]);
if( j < rank_ - 1) {
// Assume contiguous device data!
device_strides_[j] = device_strides_[j+1] * shape[j+1];
}
}
allocated_size_ = compute_aligned_size(size_t(shape_[0]) * size_t(strides_[0]), size_t(alignment));
contiguous_ = (size_ == allocated_size_);
Expand Down Expand Up @@ -121,6 +133,8 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
shape_.resize(rank_);
strides_.resize(rank_);
layout_.resize(rank_);
device_strides_.resize(rank_);
device_strides_[rank_ - 1] = strides[rank_ - 1];
default_layout_ = true;
for (int j = rank_ - 1; j >= 0; --j) {
shape_[j] = shape[j];
Expand All @@ -130,6 +144,10 @@ ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayStrides& strides, const
if (layout_[j] != idx_t(j)) {
default_layout_ = false;
}
if( j < rank_ - 1) {
// Assume contiguous device data!
device_strides_[j] = device_strides_[j+1] * shape[j+1];
}
}
allocated_size_ = compute_aligned_size(size_t(shape_[layout_[0]]) * size_t(strides_[layout_[0]]), size_t(alignment));
contiguous_ = (size_ == allocated_size_);
Expand All @@ -152,12 +170,18 @@ const std::vector<int>& ArraySpec::stridesf() const {
return stridesf_;
}

const std::vector<int>& ArraySpec::device_stridesf() const {
return device_stridesf_;
}

void ArraySpec::allocate_fortran_specs() {
shapef_.resize(rank_);
stridesf_.resize(rank_);
device_stridesf_.resize(rank_);
for (idx_t j = 0; j < rank_; ++j) {
shapef_[j] = shape_[rank_ - 1 - layout_[j]];
stridesf_[j] = strides_[rank_ -1 - layout_[j]];
stridesf_[j] = strides_[rank_ - 1 - layout_[j]];
device_stridesf_[j] = device_strides_[rank_ - 1 - j];
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/atlas/array/ArraySpec.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ class ArraySpec {
DataType datatype_;
ArrayShape shape_;
ArrayStrides strides_;
ArrayStrides device_strides_;
ArrayLayout layout_;
ArrayAlignment alignment_;
std::vector<int> shapef_;
std::vector<int> stridesf_;
std::vector<int> device_stridesf_;
bool contiguous_;
bool default_layout_;

Expand All @@ -61,9 +63,11 @@ class ArraySpec {
const ArrayShape& shape() const { return shape_; }
const ArrayAlignment& alignment() const { return alignment_; }
const ArrayStrides& strides() const { return strides_; }
const ArrayStrides& device_strides() const { return device_strides_; }
const ArrayLayout& layout() const { return layout_; }
const std::vector<int>& shapef() const;
const std::vector<int>& stridesf() const;
const std::vector<int>& device_stridesf() const;
bool contiguous() const { return contiguous_; }
bool hasDefaultLayout() const { return default_layout_; }

Expand Down
4 changes: 2 additions & 2 deletions src/atlas/array/native/NativeMakeView.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ template <typename Value, int Rank>
ArrayView<Value, Rank> make_device_view(Array& array) {
#if ATLAS_HAVE_GPU
ATLAS_ASSERT(array.deviceAllocated(),"make_device_view: Array not allocated on device");
return ArrayView<Value, Rank>((array.device_data<Value>()), array.shape(), array.strides());
return ArrayView<Value, Rank>((array.device_data<Value>()), array.shape(), array.device_strides());
#else
return make_host_view<Value, Rank>(array);
#endif
Expand All @@ -61,7 +61,7 @@ template <typename Value, int Rank>
ArrayView<const Value, Rank> make_device_view(const Array& array) {
#if ATLAS_HAVE_GPU
ATLAS_ASSERT(array.deviceAllocated(),"make_device_view: Array not allocated on device");
return ArrayView<const Value, Rank>(array.device_data<const Value>(), array.shape(), array.strides());
return ArrayView<const Value, Rank>(array.device_data<const Value>(), array.shape(), array.device_strides());
#else
return make_host_view<const Value, Rank>(array);
#endif
Expand Down
3 changes: 3 additions & 0 deletions src/atlas/field/detail/FieldImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class FieldImpl : public util::Object {
/// @brief Strides of this field in Fortran style (reverse order of C style)
const std::vector<int>& stridesf() const { return array_->stridesf(); }

/// @brief Strides of this field on the device in Fortran style (reverse order of C style)
const std::vector<int>& device_stridesf() const { return array_->device_stridesf(); }

/// @brief Shape of this field (reverse order of Fortran style)
const array::ArrayShape& shape() const { return array_->shape(); }

Expand Down
28 changes: 28 additions & 0 deletions src/atlas/field/detail/FieldInterface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ void atlas__Field__data_specf(FieldImpl* This, Value*& data, int& rank, int*& sh
rank = This->shapef().size();
}

template <typename Value>
void atlas__Field__device_data_specf(FieldImpl* This, Value*& data, int& rank, int*& shapef, int*& stridesf) {
ATLAS_ASSERT(This != nullptr, "Cannot access data of uninitialised atlas_Field");
if (This->datatype() != array::make_datatype<Value>()) {
throw_Exception("Datatype mismatch for accessing field data");
}
data = This->array().device_data<Value>();
shapef = const_cast<int*>(This->shapef().data());
stridesf = const_cast<int*>(This->device_stridesf().data());
rank = This->shapef().size();
}

template <typename Value>
FieldImpl* atlas__Field__wrap_specf(const char* name, Value data[], int rank, int shapef[], int stridesf[]) {
array::ArrayShape shape;
Expand Down Expand Up @@ -189,6 +201,22 @@ void atlas__Field__data_double_specf(FieldImpl* This, double*& data, int& rank,
atlas__Field__data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_int_specf(FieldImpl* This, int*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_long_specf(FieldImpl* This, long*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_float_specf(FieldImpl* This, float*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_double_specf(FieldImpl* This, double*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

int atlas__Field__host_needs_update(const FieldImpl* This) {
return This->hostNeedsUpdate();
}
Expand Down
8 changes: 8 additions & 0 deletions src/atlas/field/detail/FieldInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ void atlas__Field__data_float_specf(FieldImpl* This, float*& field_data, int& ra
int*& field_stridesf);
void atlas__Field__data_double_specf(FieldImpl* This, double*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_int_specf(FieldImpl* This, int*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_long_specf(FieldImpl* This, long*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_float_specf(FieldImpl* This, float*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_double_specf(FieldImpl* This, double*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
util::Metadata* atlas__Field__metadata(FieldImpl* This);
const functionspace::FunctionSpaceImpl* atlas__Field__functionspace(FieldImpl* This);
void atlas__Field__rename(FieldImpl* This, const char* name);
Expand Down
Loading
Loading