Skip to content

Commit

Permalink
Review Field GPU API
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Sep 12, 2023
1 parent a922f4b commit 1b5b17c
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 47 deletions.
14 changes: 10 additions & 4 deletions src/atlas/array/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,27 @@ class Array : public util::Object {

virtual const void* storage() const { return data_store_->voidDataStore(); }

bool valid() const { return data_store_->valid(); }

void updateDevice() const { data_store_->updateDevice(); }

void updateHost() const { data_store_->updateHost(); }

bool valid() const { return data_store_->valid(); }

void syncHostDevice() const { data_store_->syncHostDevice(); }

bool hostNeedsUpdate() const { return data_store_->hostNeedsUpdate(); }

bool deviceNeedsUpdate() const { return data_store_->deviceNeedsUpdate(); }

void reactivateDeviceWriteViews() const { data_store_->reactivateDeviceWriteViews(); }
void setHostNeedsUpdate(bool v) const { return data_store_->setHostNeedsUpdate(v); }

void setDeviceNeedsUpdate(bool v) const { return data_store_->setDeviceNeedsUpdate(v); }

bool deviceAllocated() { return data_store_->deviceAllocated(); }

void allocateDevice() { data_store_->allocateDevice(); }

void reactivateHostWriteViews() const { data_store_->reactivateHostWriteViews(); }
void deallocateDevice() { data_store_->deallocateDevice(); }

const ArraySpec& spec() const { return spec_; }

Expand Down
25 changes: 14 additions & 11 deletions src/atlas/array/ArrayDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,20 @@ struct add_const<T const> {
class ArrayDataStore {
public:
virtual ~ArrayDataStore() {}
virtual void updateDevice() const = 0;
virtual void updateHost() const = 0;
virtual bool valid() const = 0;
virtual void syncHostDevice() const = 0;
virtual bool hostNeedsUpdate() const = 0;
virtual bool deviceNeedsUpdate() const = 0;
virtual void reactivateDeviceWriteViews() const = 0;
virtual void reactivateHostWriteViews() const = 0;
virtual void* voidDataStore() = 0;
virtual void* voidHostData() = 0;
virtual void* voidDeviceData() = 0;
virtual void updateDevice() const = 0;
virtual void updateHost() const = 0;
virtual bool valid() const = 0;
virtual void syncHostDevice() const = 0;
virtual void allocateDevice() const = 0;
virtual void deallocateDevice() const = 0;
virtual bool deviceAllocated() const = 0;
virtual bool hostNeedsUpdate() const = 0;
virtual bool deviceNeedsUpdate() const = 0;
virtual void setHostNeedsUpdate(bool) const = 0;
virtual void setDeviceNeedsUpdate(bool) const = 0;
virtual void* voidDataStore() = 0;
virtual void* voidHostData() = 0;
virtual void* voidDeviceData() = 0;
template <typename Value>
Value* hostData() {
return static_cast<Value*>(voidHostData());
Expand Down
2 changes: 1 addition & 1 deletion src/atlas/array/ArraySpec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ ArraySpec::ArraySpec(DataType datatype, const ArrayShape& shape): ArraySpec(shap

ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayAlignment& alignment): datatype_(DataType::KIND_REAL64) {
ArrayShape aligned_shape = shape;
aligned_shape.back() = compute_aligned_size(aligned_shape.back(), alignment);
aligned_shape.back() = compute_aligned_size(aligned_shape.back(), size_t(alignment));

rank_ = static_cast<int>(shape.size());
size_ = 1;
Expand Down
38 changes: 27 additions & 11 deletions src/atlas/array/gridtools/GridToolsDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,48 @@ struct GridToolsDataStore : ArrayDataStore {
delete data_store_;
}

void updateDevice() const {
void updateDevice() const override {
assert(data_store_);
data_store_->clone_to_device();
}

void updateHost() const { data_store_->clone_from_device(); }
void updateHost() const override { data_store_->clone_from_device(); }

bool valid() const { return data_store_->valid(); }
bool valid() const override { return data_store_->valid(); }

void syncHostDevice() const { data_store_->sync(); }
void syncHostDevice() const override { data_store_->sync(); }

bool hostNeedsUpdate() const { return data_store_->host_needs_update(); }
void allocateDevice() const override {}

bool deviceNeedsUpdate() const { return data_store_->device_needs_update(); }
void deallocateDevice() const override {}

void reactivateDeviceWriteViews() const { data_store_->reactivate_target_write_views(); }
bool deviceAllocated() const override { return ATLAS_GRIDTOOLS_STORAGE_BACKEND_CUDA; }

void reactivateHostWriteViews() const { data_store_->reactivate_host_write_views(); }
bool hostNeedsUpdate() const override { return data_store_->host_needs_update(); }

void* voidDataStore() { return static_cast<void*>(const_cast<gt_DataStore*>(data_store_)); }
bool deviceNeedsUpdate() const override { return data_store_->device_needs_update(); }

void* voidHostData() {
void setHostNeedsUpdate(bool v) const override {
auto state_machine = data_store_->get_storage_ptr()->get_state_machine_ptr_impl();
if (state_machine) {
state_machine->m_hnu = v;
}
}

void setDeviceNeedsUpdate(bool v) const override {
auto state_machine = data_store_->get_storage_ptr()->get_state_machine_ptr_impl();
if (state_machine) {
state_machine->m_dnu = v;
}
}

void* voidDataStore() override { return static_cast<void*>(const_cast<gt_DataStore*>(data_store_)); }

void* voidHostData() override {
return ::gridtools::make_host_view<::gridtools::access_mode::read_only>(*data_store_).data();
}

void* voidDeviceData() {
void* voidDeviceData() override {
#if ATLAS_GRIDTOOLS_STORAGE_BACKEND_CUDA
return ::gridtools::make_device_view<::gridtools::access_mode::read_only>(*data_store_).data();
#else
Expand Down
20 changes: 16 additions & 4 deletions src/atlas/array/native/NativeDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,19 @@ class DataStore : public ArrayDataStore {

virtual void syncHostDevice() const override {}

virtual bool deviceAllocated() const override { return false; }

virtual void allocateDevice() const override {}

virtual void deallocateDevice() const override {}

virtual bool hostNeedsUpdate() const override { return false; }

virtual bool deviceNeedsUpdate() const override { return false; }

virtual void reactivateDeviceWriteViews() const override {}
virtual void setHostNeedsUpdate(bool) const override {}

virtual void reactivateHostWriteViews() const override {}
virtual void setDeviceNeedsUpdate(bool) const override {}

virtual void* voidDataStore() override { return static_cast<void*>(data_store_); }

Expand Down Expand Up @@ -172,13 +178,19 @@ class WrappedDataStore : public ArrayDataStore {

virtual void syncHostDevice() const override {}

virtual bool deviceAllocated() const override { return false; }

virtual void allocateDevice() const override {}

virtual void deallocateDevice() const override {}

virtual bool hostNeedsUpdate() const override { return true; }

virtual bool deviceNeedsUpdate() const override { return false; }

virtual void reactivateDeviceWriteViews() const override {}
virtual void setHostNeedsUpdate(bool) const override {}

virtual void reactivateHostWriteViews() const override {}
virtual void setDeviceNeedsUpdate(bool) const override {}

virtual void* voidDataStore() override { return static_cast<void*>(data_store_); }

Expand Down
8 changes: 4 additions & 4 deletions src/atlas/field/Field.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@ bool Field::hostNeedsUpdate() const {
bool Field::deviceNeedsUpdate() const {
return get()->deviceNeedsUpdate();
}
void Field::reactivateDeviceWriteViews() const {
get()->reactivateDeviceWriteViews();
void Field::setHostNeedsUpdate(bool v) const {
return get()->setHostNeedsUpdate(v);
}
void Field::reactivateHostWriteViews() const {
get()->reactivateHostWriteViews();
void Field::setDeviceNeedsUpdate(bool v) const {
return get()->setDeviceNeedsUpdate(v);
}

// ------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions src/atlas/field/Field.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ class Field : DOXYGEN_HIDE(public util::ObjectHandle<field::FieldImpl>) {
void syncHostDevice() const;
bool hostNeedsUpdate() const;
bool deviceNeedsUpdate() const;
void reactivateDeviceWriteViews() const;
void reactivateHostWriteViews() const;
void setHostNeedsUpdate(bool) const;
void setDeviceNeedsUpdate(bool) const;
};

extern template Field::Field(const std::string&, float*, const array::ArraySpec&);
Expand Down
7 changes: 5 additions & 2 deletions src/atlas/field/detail/FieldImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,13 @@ class FieldImpl : public util::Object {
void updateHost() const { array_->updateHost(); }
void updateDevice() const { array_->updateDevice(); }
void syncHostDevice() const { array_->syncHostDevice(); }
bool deviceAllocated() const { return array_->deviceAllocated(); }
void allocateDevice() const { array_->allocateDevice(); }
void deallocateDevice() const { array_->deallocateDevice(); }
bool hostNeedsUpdate() const { return array_->hostNeedsUpdate(); }
bool deviceNeedsUpdate() const { return array_->deviceNeedsUpdate(); }
void reactivateDeviceWriteViews() const { array_->reactivateDeviceWriteViews(); }
void reactivateHostWriteViews() const { array_->reactivateHostWriteViews(); }
void setHostNeedsUpdate(bool v) const { return array_->setHostNeedsUpdate(v); }
void setDeviceNeedsUpdate(bool v) const { return array_->setDeviceNeedsUpdate(v); }

void haloExchange(bool on_device = false) const;
void adjointHaloExchange(bool on_device = false) const;
Expand Down
24 changes: 24 additions & 0 deletions src/atlas/field/detail/FieldInterface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ int atlas__Field__device_needs_update(const FieldImpl* This) {
return This->deviceNeedsUpdate();
}

int atlas__Field__device_allocated(const FieldImpl* This) {
return This->deviceAllocated();
}

void atlas__Field__rename(FieldImpl* This, const char* name) {
ATLAS_ASSERT(This, "Cannot rename uninitialised atlas_Field");
This->rename(std::string(name));
Expand All @@ -219,6 +223,16 @@ void atlas__Field__set_functionspace(FieldImpl* This, const functionspace::Funct
#endif
}

void atlas__Field__set_host_needs_update(const FieldImpl* This, int value) {
ATLAS_ASSERT(This != nullptr, "Cannot set value for uninitialised atlas_Field");
This->setHostNeedsUpdate(value);
}

void atlas__Field__set_device_needs_update(const FieldImpl* This, int value) {
ATLAS_ASSERT(This != nullptr, "Cannot set value for uninitialised atlas_Field");
This->setDeviceNeedsUpdate(value);
}

void atlas__Field__update_device(FieldImpl* This) {
ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field");
This->updateDevice();
Expand All @@ -234,6 +248,16 @@ void atlas__Field__sync_host_device(FieldImpl* This) {
This->syncHostDevice();
}

void atlas__Field__allocate_device(FieldImpl* This) {
ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field");
This->allocateDevice();
}

void atlas__Field__deallocate_device(FieldImpl* This) {
ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field");
This->deallocateDevice();
}

void atlas__Field__set_dirty(FieldImpl* This, int value) {
ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field");
This->set_dirty(value);
Expand Down
5 changes: 5 additions & 0 deletions src/atlas/field/detail/FieldInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,14 @@ void atlas__Field__set_levels(FieldImpl* This, int levels);
void atlas__Field__set_functionspace(FieldImpl* This, const functionspace::FunctionSpaceImpl* functionspace);
int atlas__Field__host_needs_update(const FieldImpl* This);
int atlas__Field__device_needs_update(const FieldImpl* This);
int atlas__Field__device_allocated(const FieldImpl* This);
void atlas__Field__set_host_needs_update(const FieldImpl* This, int value);
void atlas__Field__set_device_needs_update(const FieldImpl* This, int value);
void atlas__Field__update_device(FieldImpl* This);
void atlas__Field__update_host(FieldImpl* This);
void atlas__Field__sync_host_device(FieldImpl* This);
void atlas__Field__allocate_device(FieldImpl* This);
void atlas__Field__deallocate_device(FieldImpl* This);
void atlas__Field__set_dirty(FieldImpl* This, int value);
void atlas__Field__halo_exchange(FieldImpl* This, int on_device);
void atlas__Field__adjoint_halo_exchange(FieldImpl* This, int on_device);
Expand Down
66 changes: 66 additions & 0 deletions src/atlas_f/field/atlas_Field_module.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ contains
#:endfor
& dummy

procedure, public :: device_allocated
procedure, public :: allocate_device
procedure, public :: deallocate_device
procedure, public :: set_host_needs_update
procedure, public :: set_device_needs_update
procedure, public :: host_needs_update
procedure, public :: device_needs_update
procedure, public :: update_device
Expand Down Expand Up @@ -633,6 +638,38 @@ end subroutine

!-------------------------------------------------------------------------------

subroutine set_host_needs_update(this,value)
use atlas_field_c_binding
class(atlas_Field), intent(in) :: this
logical, optional, intent(in) :: value
integer :: value_int
value_int = 1
if (present(value)) then
if (.not. value) then
value_int = 0
endif
endif
call atlas__field__set_host_needs_update(this%CPTR_PGIBUG_A,value_int)
end subroutine

!-------------------------------------------------------------------------------

subroutine set_device_needs_update(this,value)
use atlas_field_c_binding
class(atlas_Field), intent(in) :: this
logical, optional, intent(in) :: value
integer :: value_int
value_int = 1
if (present(value)) then
if (.not. value) then
value_int = 0
endif
endif
call atlas__field__set_device_needs_update(this%CPTR_PGIBUG_A,value_int)
end subroutine

!-------------------------------------------------------------------------------

function host_needs_update(this)
use atlas_field_c_binding
logical :: host_needs_update
Expand Down Expand Up @@ -683,6 +720,35 @@ end subroutine

!-------------------------------------------------------------------------------

subroutine allocate_device(this)
use atlas_field_c_binding
class(atlas_Field), intent(inout) :: this
call atlas__Field__allocate_device(this%CPTR_PGIBUG_A)
end subroutine

!-------------------------------------------------------------------------------

subroutine deallocate_device(this)
use atlas_field_c_binding
class(atlas_Field), intent(inout) :: this
call atlas__Field__deallocate_device(this%CPTR_PGIBUG_A)
end subroutine

!-------------------------------------------------------------------------------

function device_allocated(this)
use atlas_field_c_binding
logical :: device_allocated
class(atlas_Field), intent(in) :: this
if( atlas__Field__device_allocated(this%CPTR_PGIBUG_A) == 1 ) then
device_allocated = .true.
else
device_allocated = .false.
endif
end function

!-------------------------------------------------------------------------------

subroutine halo_exchange(this,on_device)
use, intrinsic :: iso_c_binding, only : c_int
use atlas_field_c_binding
Expand Down
Loading

0 comments on commit 1b5b17c

Please sign in to comment.