// Patch summary (preamble only; nothing below the first "diff --git" header is altered).
//
// Unified diff across the Atlas array/field sources, the Fortran field bindings and a CUDA test.
//
// What the hunks do:
//  * Array, ArrayDataStore, GridToolsDataStore, native DataStore / WrappedDataStore, FieldImpl:
//    remove reactivateDeviceWriteViews() / reactivateHostWriteViews() and add
//    setHostNeedsUpdate(bool), setDeviceNeedsUpdate(bool), deviceAllocated(),
//    allocateDevice() and deallocateDevice(). In Array.h, valid() is also moved above updateDevice().
//  * Field (Field.h / Field.cc): the two reactivate* members are replaced by
//    setHostNeedsUpdate(bool) / setDeviceNeedsUpdate(bool) only; no allocate/deallocate
//    members are added to Field in these hunks.
//  * GridToolsDataStore: the overriding members gain `override`; allocateDevice() and
//    deallocateDevice() have empty bodies; deviceAllocated() returns
//    ATLAS_GRIDTOOLS_STORAGE_BACKEND_CUDA; the two setters write m_hnu / m_dnu on the
//    storage's state machine, and do nothing when that pointer is null.
//  * NativeDataStore.h (DataStore and WrappedDataStore): deviceAllocated() returns false and
//    the other new members have empty bodies.
//  * ArraySpec.cc: the alignment argument to compute_aligned_size is now written size_t(alignment).
//  * FieldInterface.{h,cc}: adds atlas__Field__device_allocated, atlas__Field__set_host_needs_update,
//    atlas__Field__set_device_needs_update, atlas__Field__allocate_device and
//    atlas__Field__deallocate_device. All but atlas__Field__device_allocated assert This != nullptr.
//  * atlas_Field_module.fypp: adds device_allocated, allocate_device, deallocate_device,
//    set_host_needs_update and set_device_needs_update. The two set_* subroutines take an
//    optional logical `value`; they pass 1 unless `value` is present and .false., in which case 0.
//  * test_array_kernel.cu: the Array is held by std::unique_ptr instead of a raw pointer with
//    an explicit delete, and the reactivateHostWriteViews() calls are dropped.
//
// NOTE(review): Array::deviceAllocated(), allocateDevice() and deallocateDevice() are declared
//   non-const, whereas the neighbouring setters and the ArrayDataStore virtuals they forward to
//   are const. FieldImpl calls them from const members through array_, which presumably is a
//   pointer to non-const Array -- confirm, and consider making the three Array members const.
// NOTE(review): template argument lists look stripped in this copy of the patch (e.g.
//   `static_cast(voidHostData())`, `std::unique_ptr(Array::create(...))`,
//   `template Value* hostData()`); presumably an extraction artefact -- verify against the
//   original patch before applying.
// NOTE(review): atlas__Field__device_allocated dereferences This without the ATLAS_ASSERT used
//   by the other new wrappers -- confirm this is intended.
// NOTE(review): the test now uses std::unique_ptr; no <memory> include is visible in these
//   hunks -- TODO confirm it is already included.
// NOTE(review): line breaks inside the hunks appear collapsed in this copy, so it is unlikely
//   to apply as-is.
diff --git a/src/atlas/array/Array.h b/src/atlas/array/Array.h index 9a7c22668..b55ec8297 100644 --- a/src/atlas/array/Array.h +++ b/src/atlas/array/Array.h @@ -115,21 +115,27 @@ class Array : public util::Object { virtual const void* storage() const { return data_store_->voidDataStore(); } + bool valid() const { return data_store_->valid(); } + void updateDevice() const { data_store_->updateDevice(); } void updateHost() const { data_store_->updateHost(); } - bool valid() const { return data_store_->valid(); } - void syncHostDevice() const { data_store_->syncHostDevice(); } bool hostNeedsUpdate() const { return data_store_->hostNeedsUpdate(); } bool deviceNeedsUpdate() const { return data_store_->deviceNeedsUpdate(); } - void reactivateDeviceWriteViews() const { data_store_->reactivateDeviceWriteViews(); } + void setHostNeedsUpdate(bool v) const { return data_store_->setHostNeedsUpdate(v); } + + void setDeviceNeedsUpdate(bool v) const { return data_store_->setDeviceNeedsUpdate(v); } + + bool deviceAllocated() { return data_store_->deviceAllocated(); } + + void allocateDevice() { data_store_->allocateDevice(); } - void reactivateHostWriteViews() const { data_store_->reactivateHostWriteViews(); } + void deallocateDevice() { data_store_->deallocateDevice(); } const ArraySpec& spec() const { return spec_; } diff --git a/src/atlas/array/ArrayDataStore.h b/src/atlas/array/ArrayDataStore.h index 3ef2bf58f..d60df80a9 100644 --- a/src/atlas/array/ArrayDataStore.h +++ b/src/atlas/array/ArrayDataStore.h @@ -46,17 +46,20 @@ struct add_const { class ArrayDataStore { public: virtual ~ArrayDataStore() {} - virtual void updateDevice() const = 0; - virtual void updateHost() const = 0; - virtual bool valid() const = 0; - virtual void syncHostDevice() const = 0; - virtual bool hostNeedsUpdate() const = 0; - virtual bool deviceNeedsUpdate() const = 0; - virtual void reactivateDeviceWriteViews() const = 0; - virtual void reactivateHostWriteViews() const = 0; - virtual void* 
voidDataStore() = 0; - virtual void* voidHostData() = 0; - virtual void* voidDeviceData() = 0; + virtual void updateDevice() const = 0; + virtual void updateHost() const = 0; + virtual bool valid() const = 0; + virtual void syncHostDevice() const = 0; + virtual void allocateDevice() const = 0; + virtual void deallocateDevice() const = 0; + virtual bool deviceAllocated() const = 0; + virtual bool hostNeedsUpdate() const = 0; + virtual bool deviceNeedsUpdate() const = 0; + virtual void setHostNeedsUpdate(bool) const = 0; + virtual void setDeviceNeedsUpdate(bool) const = 0; + virtual void* voidDataStore() = 0; + virtual void* voidHostData() = 0; + virtual void* voidDeviceData() = 0; template Value* hostData() { return static_cast(voidHostData()); diff --git a/src/atlas/array/ArraySpec.cc b/src/atlas/array/ArraySpec.cc index 074c23074..96ee8ba8e 100644 --- a/src/atlas/array/ArraySpec.cc +++ b/src/atlas/array/ArraySpec.cc @@ -39,7 +39,7 @@ ArraySpec::ArraySpec(DataType datatype, const ArrayShape& shape): ArraySpec(shap ArraySpec::ArraySpec(const ArrayShape& shape, const ArrayAlignment& alignment): datatype_(DataType::KIND_REAL64) { ArrayShape aligned_shape = shape; - aligned_shape.back() = compute_aligned_size(aligned_shape.back(), alignment); + aligned_shape.back() = compute_aligned_size(aligned_shape.back(), size_t(alignment)); rank_ = static_cast(shape.size()); size_ = 1; diff --git a/src/atlas/array/gridtools/GridToolsDataStore.h b/src/atlas/array/gridtools/GridToolsDataStore.h index 6758c74ba..4225453c4 100644 --- a/src/atlas/array/gridtools/GridToolsDataStore.h +++ b/src/atlas/array/gridtools/GridToolsDataStore.h @@ -28,32 +28,48 @@ struct GridToolsDataStore : ArrayDataStore { delete data_store_; } - void updateDevice() const { + void updateDevice() const override { assert(data_store_); data_store_->clone_to_device(); } - void updateHost() const { data_store_->clone_from_device(); } + void updateHost() const override { data_store_->clone_from_device(); } - bool 
valid() const { return data_store_->valid(); } + bool valid() const override { return data_store_->valid(); } - void syncHostDevice() const { data_store_->sync(); } + void syncHostDevice() const override { data_store_->sync(); } - bool hostNeedsUpdate() const { return data_store_->host_needs_update(); } + void allocateDevice() const override {} - bool deviceNeedsUpdate() const { return data_store_->device_needs_update(); } + void deallocateDevice() const override {} - void reactivateDeviceWriteViews() const { data_store_->reactivate_target_write_views(); } + bool deviceAllocated() const override { return ATLAS_GRIDTOOLS_STORAGE_BACKEND_CUDA; } - void reactivateHostWriteViews() const { data_store_->reactivate_host_write_views(); } + bool hostNeedsUpdate() const override { return data_store_->host_needs_update(); } - void* voidDataStore() { return static_cast(const_cast(data_store_)); } + bool deviceNeedsUpdate() const override { return data_store_->device_needs_update(); } - void* voidHostData() { + void setHostNeedsUpdate(bool v) const override { + auto state_machine = data_store_->get_storage_ptr()->get_state_machine_ptr_impl(); + if (state_machine) { + state_machine->m_hnu = v; + } + } + + void setDeviceNeedsUpdate(bool v) const override { + auto state_machine = data_store_->get_storage_ptr()->get_state_machine_ptr_impl(); + if (state_machine) { + state_machine->m_dnu = v; + } + } + + void* voidDataStore() override { return static_cast(const_cast(data_store_)); } + + void* voidHostData() override { return ::gridtools::make_host_view<::gridtools::access_mode::read_only>(*data_store_).data(); } - void* voidDeviceData() { + void* voidDeviceData() override { #if ATLAS_GRIDTOOLS_STORAGE_BACKEND_CUDA return ::gridtools::make_device_view<::gridtools::access_mode::read_only>(*data_store_).data(); #else diff --git a/src/atlas/array/native/NativeDataStore.h b/src/atlas/array/native/NativeDataStore.h index 3e02cb19c..f4d83573b 100644 --- 
a/src/atlas/array/native/NativeDataStore.h +++ b/src/atlas/array/native/NativeDataStore.h @@ -106,13 +106,19 @@ class DataStore : public ArrayDataStore { virtual void syncHostDevice() const override {} + virtual bool deviceAllocated() const override { return false; } + + virtual void allocateDevice() const override {} + + virtual void deallocateDevice() const override {} + virtual bool hostNeedsUpdate() const override { return false; } virtual bool deviceNeedsUpdate() const override { return false; } - virtual void reactivateDeviceWriteViews() const override {} + virtual void setHostNeedsUpdate(bool) const override {} - virtual void reactivateHostWriteViews() const override {} + virtual void setDeviceNeedsUpdate(bool) const override {} virtual void* voidDataStore() override { return static_cast(data_store_); } @@ -172,13 +178,19 @@ class WrappedDataStore : public ArrayDataStore { virtual void syncHostDevice() const override {} + virtual bool deviceAllocated() const override { return false; } + + virtual void allocateDevice() const override {} + + virtual void deallocateDevice() const override {} + virtual bool hostNeedsUpdate() const override { return true; } virtual bool deviceNeedsUpdate() const override { return false; } - virtual void reactivateDeviceWriteViews() const override {} + virtual void setHostNeedsUpdate(bool) const override {} - virtual void reactivateHostWriteViews() const override {} + virtual void setDeviceNeedsUpdate(bool) const override {} virtual void* voidDataStore() override { return static_cast(data_store_); } diff --git a/src/atlas/field/Field.cc b/src/atlas/field/Field.cc index befefbd0b..22996ac74 100644 --- a/src/atlas/field/Field.cc +++ b/src/atlas/field/Field.cc @@ -244,11 +244,11 @@ bool Field::hostNeedsUpdate() const { bool Field::deviceNeedsUpdate() const { return get()->deviceNeedsUpdate(); } -void Field::reactivateDeviceWriteViews() const { - get()->reactivateDeviceWriteViews(); +void Field::setHostNeedsUpdate(bool v) const { + 
return get()->setHostNeedsUpdate(v); } -void Field::reactivateHostWriteViews() const { - get()->reactivateHostWriteViews(); +void Field::setDeviceNeedsUpdate(bool v) const { + return get()->setDeviceNeedsUpdate(v); } // ------------------------------------------------------------------ diff --git a/src/atlas/field/Field.h b/src/atlas/field/Field.h index cde616ecb..c5e0f4ad4 100644 --- a/src/atlas/field/Field.h +++ b/src/atlas/field/Field.h @@ -190,8 +190,8 @@ class Field : DOXYGEN_HIDE(public util::ObjectHandle) { void syncHostDevice() const; bool hostNeedsUpdate() const; bool deviceNeedsUpdate() const; - void reactivateDeviceWriteViews() const; - void reactivateHostWriteViews() const; + void setHostNeedsUpdate(bool) const; + void setDeviceNeedsUpdate(bool) const; }; extern template Field::Field(const std::string&, float*, const array::ArraySpec&); diff --git a/src/atlas/field/detail/FieldImpl.h b/src/atlas/field/detail/FieldImpl.h index 91de0e17a..8d8b05ad6 100644 --- a/src/atlas/field/detail/FieldImpl.h +++ b/src/atlas/field/detail/FieldImpl.h @@ -201,10 +201,13 @@ class FieldImpl : public util::Object { void updateHost() const { array_->updateHost(); } void updateDevice() const { array_->updateDevice(); } void syncHostDevice() const { array_->syncHostDevice(); } + bool deviceAllocated() const { return array_->deviceAllocated(); } + void allocateDevice() const { array_->allocateDevice(); } + void deallocateDevice() const { array_->deallocateDevice(); } bool hostNeedsUpdate() const { return array_->hostNeedsUpdate(); } bool deviceNeedsUpdate() const { return array_->deviceNeedsUpdate(); } - void reactivateDeviceWriteViews() const { array_->reactivateDeviceWriteViews(); } - void reactivateHostWriteViews() const { array_->reactivateHostWriteViews(); } + void setHostNeedsUpdate(bool v) const { return array_->setHostNeedsUpdate(v); } + void setDeviceNeedsUpdate(bool v) const { return array_->setDeviceNeedsUpdate(v); } void haloExchange(bool on_device = false) const; 
void adjointHaloExchange(bool on_device = false) const; diff --git a/src/atlas/field/detail/FieldInterface.cc b/src/atlas/field/detail/FieldInterface.cc index 33435747d..12eeee691 100644 --- a/src/atlas/field/detail/FieldInterface.cc +++ b/src/atlas/field/detail/FieldInterface.cc @@ -198,6 +198,10 @@ int atlas__Field__device_needs_update(const FieldImpl* This) { return This->deviceNeedsUpdate(); } +int atlas__Field__device_allocated(const FieldImpl* This) { + return This->deviceAllocated(); +} + void atlas__Field__rename(FieldImpl* This, const char* name) { ATLAS_ASSERT(This, "Cannot rename uninitialised atlas_Field"); This->rename(std::string(name)); @@ -219,6 +223,16 @@ void atlas__Field__set_functionspace(FieldImpl* This, const functionspace::Funct #endif } +void atlas__Field__set_host_needs_update(const FieldImpl* This, int value) { + ATLAS_ASSERT(This != nullptr, "Cannot set value for uninitialised atlas_Field"); + This->setHostNeedsUpdate(value); +} + +void atlas__Field__set_device_needs_update(const FieldImpl* This, int value) { + ATLAS_ASSERT(This != nullptr, "Cannot set value for uninitialised atlas_Field"); + This->setDeviceNeedsUpdate(value); +} + void atlas__Field__update_device(FieldImpl* This) { ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field"); This->updateDevice(); @@ -234,6 +248,16 @@ void atlas__Field__sync_host_device(FieldImpl* This) { This->syncHostDevice(); } +void atlas__Field__allocate_device(FieldImpl* This) { + ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field"); + This->allocateDevice(); +} + +void atlas__Field__deallocate_device(FieldImpl* This) { + ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field"); + This->deallocateDevice(); +} + void atlas__Field__set_dirty(FieldImpl* This, int value) { ATLAS_ASSERT(This != nullptr, "Cannot access uninitialised atlas_Field"); This->set_dirty(value); diff --git a/src/atlas/field/detail/FieldInterface.h 
b/src/atlas/field/detail/FieldInterface.h index c0b0ebcc2..a47ffe5f8 100644 --- a/src/atlas/field/detail/FieldInterface.h +++ b/src/atlas/field/detail/FieldInterface.h @@ -59,9 +59,14 @@ void atlas__Field__set_levels(FieldImpl* This, int levels); void atlas__Field__set_functionspace(FieldImpl* This, const functionspace::FunctionSpaceImpl* functionspace); int atlas__Field__host_needs_update(const FieldImpl* This); int atlas__Field__device_needs_update(const FieldImpl* This); +int atlas__Field__device_allocated(const FieldImpl* This); +void atlas__Field__set_host_needs_update(const FieldImpl* This, int value); +void atlas__Field__set_device_needs_update(const FieldImpl* This, int value); void atlas__Field__update_device(FieldImpl* This); void atlas__Field__update_host(FieldImpl* This); void atlas__Field__sync_host_device(FieldImpl* This); +void atlas__Field__allocate_device(FieldImpl* This); +void atlas__Field__deallocate_device(FieldImpl* This); void atlas__Field__set_dirty(FieldImpl* This, int value); void atlas__Field__halo_exchange(FieldImpl* This, int on_device); void atlas__Field__adjoint_halo_exchange(FieldImpl* This, int on_device); diff --git a/src/atlas_f/field/atlas_Field_module.fypp b/src/atlas_f/field/atlas_Field_module.fypp index 68cc5f04a..c841a9a00 100644 --- a/src/atlas_f/field/atlas_Field_module.fypp +++ b/src/atlas_f/field/atlas_Field_module.fypp @@ -90,6 +90,11 @@ contains #:endfor & dummy + procedure, public :: device_allocated + procedure, public :: allocate_device + procedure, public :: deallocate_device + procedure, public :: set_host_needs_update + procedure, public :: set_device_needs_update procedure, public :: host_needs_update procedure, public :: device_needs_update procedure, public :: update_device @@ -633,6 +638,38 @@ end subroutine !------------------------------------------------------------------------------- +subroutine set_host_needs_update(this,value) + use atlas_field_c_binding + class(atlas_Field), intent(in) :: this + 
logical, optional, intent(in) :: value + integer :: value_int + value_int = 1 + if (present(value)) then + if (.not. value) then + value_int = 0 + endif + endif + call atlas__field__set_host_needs_update(this%CPTR_PGIBUG_A,value_int) +end subroutine + +!------------------------------------------------------------------------------- + +subroutine set_device_needs_update(this,value) + use atlas_field_c_binding + class(atlas_Field), intent(in) :: this + logical, optional, intent(in) :: value + integer :: value_int + value_int = 1 + if (present(value)) then + if (.not. value) then + value_int = 0 + endif + endif + call atlas__field__set_device_needs_update(this%CPTR_PGIBUG_A,value_int) +end subroutine + +!------------------------------------------------------------------------------- + function host_needs_update(this) use atlas_field_c_binding logical :: host_needs_update @@ -683,6 +720,35 @@ end subroutine !------------------------------------------------------------------------------- +subroutine allocate_device(this) + use atlas_field_c_binding + class(atlas_Field), intent(inout) :: this + call atlas__Field__allocate_device(this%CPTR_PGIBUG_A) +end subroutine + +!------------------------------------------------------------------------------- + +subroutine deallocate_device(this) + use atlas_field_c_binding + class(atlas_Field), intent(inout) :: this + call atlas__Field__deallocate_device(this%CPTR_PGIBUG_A) +end subroutine + +!------------------------------------------------------------------------------- + +function device_allocated(this) + use atlas_field_c_binding + logical :: device_allocated + class(atlas_Field), intent(in) :: this + if( atlas__Field__device_allocated(this%CPTR_PGIBUG_A) == 1 ) then + device_allocated = .true. + else + device_allocated = .false. 
+ endif +end function + +!------------------------------------------------------------------------------- + subroutine halo_exchange(this,on_device) use, intrinsic :: iso_c_binding, only : c_int use atlas_field_c_binding diff --git a/src/tests/array/test_array_kernel.cu b/src/tests/array/test_array_kernel.cu index 3d98b8005..e3aa9715e 100644 --- a/src/tests/array/test_array_kernel.cu +++ b/src/tests/array/test_array_kernel.cu @@ -45,7 +45,7 @@ CASE( "test_array" ) constexpr unsigned int dy = 6; constexpr unsigned int dz = 7; - Array* ds = Array::create(dx, dy, dz); + auto ds = std::unique_ptr(Array::create(dx, dy, dz)); auto hv = make_host_view(*ds); hv(3, 3, 3) = 4.5; @@ -58,11 +58,8 @@ CASE( "test_array" ) cudaDeviceSynchronize(); ds->updateHost(); - ds->reactivateHostWriteViews(); EXPECT( hv(3, 3, 3) == 4.5 + dx*dy*dz ); - - delete ds; } CASE( "test_array_loop" ) @@ -71,7 +68,8 @@ CASE( "test_array_loop" ) constexpr unsigned int dy = 6; constexpr unsigned int dz = 7; - Array* ds = Array::create(dx, dy, dz); + + auto ds = std::unique_ptr(Array::create(dx, dy, dz)); array::ArrayView hv = make_host_view(*ds); for(int i=0; i < dx; i++) { for(int j=0; j < dy; j++) { @@ -90,7 +88,6 @@ CASE( "test_array_loop" ) cudaDeviceSynchronize(); ds->updateHost(); - ds->reactivateHostWriteViews(); for(int i=0; i < dx; i++) { for(int j=0; j < dy; j++) { @@ -99,8 +96,6 @@ CASE( "test_array_loop" ) } } } - - delete ds; } } }