From d95e13df3812f8b4c505bbb887f9e763fbcdd563 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Tue, 26 Nov 2024 20:35:51 +0100 Subject: [PATCH] Allocatable classes for factories --- .gitignore | 2 + examples/Makefile | 5 +- examples/issue227_allocatable/run.py | 65 ++++----- .../issue235_allocatable_classes/Makefile | 25 ++++ .../Makefile.meson | 6 + .../issue235_allocatable_classes/myclass.f90 | 39 ++++++ .../myclass_factory.f90 | 18 +++ .../issue235_allocatable_classes/mytype.f90 | 31 +++++ examples/issue235_allocatable_classes/run.py | 83 ++++++++++++ f90wrap/f90wrapgen.py | 128 ++++++++++++++---- f90wrap/fortran.py | 11 ++ f90wrap/pywrapgen.py | 12 +- f90wrap/transform.py | 5 +- 13 files changed, 356 insertions(+), 74 deletions(-) create mode 100644 examples/issue235_allocatable_classes/Makefile create mode 100644 examples/issue235_allocatable_classes/Makefile.meson create mode 100644 examples/issue235_allocatable_classes/myclass.f90 create mode 100644 examples/issue235_allocatable_classes/myclass_factory.f90 create mode 100644 examples/issue235_allocatable_classes/mytype.f90 create mode 100644 examples/issue235_allocatable_classes/run.py diff --git a/.gitignore b/.gitignore index 72f788f3..67e51f91 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build *.mod *.a *.so +*.x f90wrap*.f90 *.pyc .pydevproject @@ -20,3 +21,4 @@ src.* .ipynb_checkpoints .idea/ *.swp +itest/ diff --git a/examples/Makefile b/examples/Makefile index a3cac481..4b7d809a 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -25,7 +25,10 @@ EXAMPLES = arrayderivedtypes \ derivedtypes_procedure \ optional_string \ long_subroutine_name \ - kind_map_default + kind_map_default \ + issue206_subroutine_oldstyle \ + issue227_allocatable \ + issue235_allocatable_classes PYTHON = python diff --git a/examples/issue227_allocatable/run.py b/examples/issue227_allocatable/run.py index ddba7afc..ae4f91ce 100644 --- a/examples/issue227_allocatable/run.py +++ b/examples/issue227_allocatable/run.py @@ -1,54 +1,41 @@ #!/usr/bin/env python -import os +import unittest import gc import tracemalloc import itest - -def main(): - test_type_output_is_wrapped() - test_intrinsic_output_is_not_wrapped() - test_array_output_is_not_wrapped() - test_type_output_wrapper() - test_memory_leak() - - -def test_type_output_is_wrapped(): - assert hasattr(itest.alloc_output, 'alloc_output_type_func') - - -def test_intrinsic_output_is_not_wrapped(): - assert (not hasattr(itest.alloc_output, 'alloc_output_intrinsic_func')) - - -def test_array_output_is_not_wrapped(): - assert (not hasattr(itest.alloc_output, 'alloc_output_array_func')) - - VAL = 10.0 TOL = 1e-13 +class TestAllocOutput(unittest.TestCase): + + def test_type_output_is_wrapped(self): + self.assertTrue(hasattr(itest.alloc_output, 'alloc_output_type_func')) -def test_type_output_wrapper(): - t = itest.alloc_output.alloc_output_type_func(VAL) - assert(abs(t.a - VAL) < TOL) + def test_intrinsic_output_is_not_wrapped(self): + self.assertFalse(hasattr(itest.alloc_output, 'alloc_output_intrinsic_func')) + def test_array_output_is_not_wrapped(self): + self.assertFalse(hasattr(itest.alloc_output, 'alloc_output_array_func')) -def test_memory_leak(): - gc.collect() - t = [] - tracemalloc.start() - start_snapshot = tracemalloc.take_snapshot() - for i in range(2048): - t.append(itest.alloc_output.alloc_output_type_func(VAL)) - del t - gc.collect() - end_snapshot = tracemalloc.take_snapshot() - tracemalloc.stop() - stats = end_snapshot.compare_to(start_snapshot, 'lineno') - assert sum(stat.size_diff for stat in stats) < 1024 + def test_type_output_wrapper(self): + t = itest.alloc_output.alloc_output_type_func(VAL) + self.assertAlmostEqual(t.a, VAL, delta=TOL) + def test_memory_leak(self): + gc.collect() + t = [] + tracemalloc.start() + start_snapshot = tracemalloc.take_snapshot() + for i in range(8192): + t.append(itest.alloc_output.alloc_output_type_func(VAL)) + del t + gc.collect() + end_snapshot = tracemalloc.take_snapshot() + tracemalloc.stop() + stats = end_snapshot.compare_to(start_snapshot, 'lineno') + self.assertLess(sum(stat.size_diff for stat in stats), 4096) if __name__ == '__main__': - main() + unittest.main() diff --git a/examples/issue235_allocatable_classes/Makefile b/examples/issue235_allocatable_classes/Makefile new file mode 100644 index 00000000..2e61ae90 --- /dev/null +++ b/examples/issue235_allocatable_classes/Makefile @@ -0,0 +1,25 @@ +FC = gfortran +FCFLAGS = -fPIC +PYTHON = python + +all: wrapper + +test: wrapper + $(PYTHON) run.py + +wrapper: f90wrapper mytype.o myclass.o myclass_factory.o + f2py-f90wrap --build-dir . -c -m _itest --opt="-O0 -g" \ + f90wrap_mytype.f90 f90wrap_myclass.f90 f90wrap_myclass_factory.f90 \ + mytype.o myclass.o myclass_factory.o + +f90wrapper: mytype.f90 myclass.f90 myclass_factory.f90 + f90wrap -m itest -P mytype.f90 myclass.f90 myclass_factory.f90 -v + +%.o : %.f90 + $(FC) $(FCFLAGS) -c -g -O0 $< -o $@ + +clean: + rm -f *.o f90wrap*.f90 *.so *.mod + rm -rf src.*/ + rm -rf itest/ + -rm -rf src.*/ .f2py_f2cmap .libs/ __pycache__/ diff --git a/examples/issue235_allocatable_classes/Makefile.meson b/examples/issue235_allocatable_classes/Makefile.meson new file mode 100644 index 00000000..1fe6c182 --- /dev/null +++ b/examples/issue235_allocatable_classes/Makefile.meson @@ -0,0 +1,6 @@ +include ../make.meson.inc + +NAME := itest + +test: build + $(PYTHON) run.py diff --git a/examples/issue235_allocatable_classes/myclass.f90 b/examples/issue235_allocatable_classes/myclass.f90 new file mode 100644 index 00000000..58c2e1b0 --- /dev/null +++ b/examples/issue235_allocatable_classes/myclass.f90 @@ -0,0 +1,39 @@ +module myclass + +implicit none + +integer :: create_count = 0 +integer :: destroy_count = 0 + +type :: myclass_t + real :: val +contains + procedure :: get_val => myclass_get_val + procedure :: set_val => myclass_set_val + final :: myclass_destroy +end type myclass_t + +contains + +subroutine myclass_get_val(self, val) + class(myclass_t), intent(in) :: self + real, intent(out) :: val + + val = self%val +end subroutine myclass_get_val + +subroutine myclass_set_val(self, val) + class(myclass_t), intent(inout) :: self + real, intent(in) :: val + + self%val = val +end subroutine myclass_set_val + +subroutine myclass_destroy(self) + type(myclass_t), intent(inout) :: self + + destroy_count = destroy_count + 1 + print *, 'Destroying class_t with val = ', self%val +end subroutine myclass_destroy + +end module myclass diff --git a/examples/issue235_allocatable_classes/myclass_factory.f90 b/examples/issue235_allocatable_classes/myclass_factory.f90 new file mode 100644 index 00000000..7515f582 --- /dev/null +++ b/examples/issue235_allocatable_classes/myclass_factory.f90 @@ -0,0 +1,18 @@ +module myclass_factory + +use myclass, only: myclass_t, create_count +implicit none + +contains + +function myclass_create(val) result(myobject) + class(myclass_t), allocatable :: myobject + real, intent(in) :: val + + allocate(myclass_t :: myobject) + call myobject%set_val(val) + create_count = create_count + 1 + +end function myclass_create + +end module myclass_factory diff --git a/examples/issue235_allocatable_classes/mytype.f90 b/examples/issue235_allocatable_classes/mytype.f90 new file mode 100644 index 00000000..d033871c --- /dev/null +++ b/examples/issue235_allocatable_classes/mytype.f90 @@ -0,0 +1,31 @@ +module mytype + + implicit none + + integer :: create_count = 0 + integer :: destroy_count = 0 + + type :: mytype_t + real :: val + contains + final :: mytype_destroy + end type mytype_t + + contains + + function mytype_create(val) result(self) + type(mytype_t) :: self + real, intent(in) :: val + + self%val = val + create_count = create_count + 1 + end function mytype_create + + subroutine mytype_destroy(self) + type(mytype_t), intent(inout) :: self + + destroy_count = destroy_count + 1 + print *, 'Destroying mytype_t with val = ', self%val + end subroutine mytype_destroy + +end module mytype diff --git a/examples/issue235_allocatable_classes/run.py b/examples/issue235_allocatable_classes/run.py new file mode 100644 index 00000000..8b60a8da --- /dev/null +++ b/examples/issue235_allocatable_classes/run.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +import unittest +from itest import mytype, myclass, myclass_factory + +REF = 3.1415 +TOL = 1.0e-6 + +class TestMyType(unittest.TestCase): + + def test_create_destroy_type_object(self): + """Object creation and destruction should happen only once.""" + mytype.set_create_count(0) + mytype.set_destroy_count(0) + + obj = mytype.mytype_create(REF) + + self.assertEqual(mytype.get_create_count(), 1) + + self.assertTrue(abs(obj.val - REF) < TOL) + + del obj + + self.assertEqual(mytype.get_create_count(), 1) + self.assertGreaterEqual(mytype.get_destroy_count(), 1) + + def test_type_member_access(self): + """Direct access of member variables.""" + obj = mytype.mytype_create(REF) + + self.assertTrue(abs(obj.val - REF) < TOL) + + obj.val = 2.0 * REF + + self.assertTrue(abs(obj.val - 2.0 * REF) < TOL) + + del obj + + +class TestMyClass(unittest.TestCase): + + def test_create_destroy_class_object(self): + """Object creation and destruction should happen only once.""" + myclass.set_create_count(0) + myclass.set_destroy_count(0) + + obj = myclass_factory.myclass_create(REF) + + self.assertEqual(myclass.get_create_count(), 1) + + self.assertTrue(abs(obj.get_val() - REF) < TOL) + + del obj + + self.assertEqual(myclass.get_create_count(), 1) + self.assertGreaterEqual(myclass.get_destroy_count(), 1) + + def test_class_getter_setter(self): + """Getters and setters defined in Fortran should work.""" + obj = myclass_factory.myclass_create(REF) + + self.assertTrue(abs(obj.get_val() - REF) < TOL) + + obj.set_val(2.0 * REF) + + self.assertTrue(abs(obj.get_val() - 2.0 * REF) < TOL) + + del obj + + def test_class_member_access(self): + """Direct access of member variables.""" + obj = myclass_factory.myclass_create(REF) + + self.assertTrue(abs(obj.val - REF) < TOL) + + obj.val = 2.0 * REF + + self.assertTrue(abs(obj.val - 2.0 * REF) < TOL) + + del obj + + +if __name__ == "__main__": + unittest.main() diff --git a/f90wrap/f90wrapgen.py b/f90wrap/f90wrapgen.py index 69c527e8..810c3faa 100644 --- a/f90wrap/f90wrapgen.py +++ b/f90wrap/f90wrapgen.py @@ -218,7 +218,7 @@ def write_super_type_lines(self, ty): self.write("end type " + ty.name) self.write() - def write_type_lines(self, tname, recursive=False): + def write_type_lines(self, tname, recursive=False, tname_inner=None): """ Write a pointer type for a given type name @@ -231,21 +231,58 @@ def write_type_lines(self, tname, recursive=False): Adjusts array pointer for recursive derived type array """ tname = ft.strip_type(tname) + if tname_inner is None: + tname_inner = tname if not recursive: self.write( - """type %(typename)s_ptr_type - type(%(typename)s), pointer :: p => NULL() -end type %(typename)s_ptr_type""" - % {"typename": tname} + "type %(typename)s_ptr_type\n" + " type(%(typename_inner)s), pointer :: p => NULL()\n" + "end type %(typename)s_ptr_type" % { + "typename": tname, + "typename_inner": tname_inner + } ) else: self.write( - """type %(typename)s_rec_ptr_type - type(%(typename)s), pointer :: p => NULL() -end type %(typename)s_rec_ptr_type""" - % {"typename": tname} + "type %(typename)s_rec_ptr_type\n" + " type(%(typename_inner)s), pointer :: p => NULL()\n" + "end type %(typename)s_rec_ptr_type" % { + "typename": tname, + "typename_inner": tname_inner + } ) + def write_class_lines(self, cname, recursive=False): + """ + Write a pointer type for a given class name + + Parameters + ---------- + tname : `str` + Should be the name of a class in the wrapped code. + """ + cname = ft.strip_type(cname) + self.write( + "type %(classname)s_wrapper_type\n" + " class(%(classname)s), allocatable :: obj\n" + "end type %(classname)s_wrapper_type" % {"classname": cname} + ) + self.write_type_lines(cname, recursive, f"{cname}_wrapper_type") + + def is_class(self, tname): + if not tname in self.types: + return False + if "used_as_class" in self.types[tname].attributes: + return True + return False + + def write_type_or_class_lines(self, tname, recursive=False): + if self.is_class(tname): + self.write_class_lines(tname, recursive) + else: + self.write_type_lines(tname, recursive) + + def write_arg_decl_lines(self, node): """ Write argument declaration lines to the code @@ -331,6 +368,9 @@ def write_init_lines(self, node): """ for alloc in node.allocate: self.write("allocate(%s_ptr%%p)" % alloc) # (self.prefix, alloc)) + if (self.is_class(node.type_name) and "constructor" in node.attributes + and "skip_call" in node.attributes): + self.write("allocate(this_ptr%p%obj)") for arg in node.arguments: if not hasattr(arg, "init_lines"): continue @@ -361,12 +401,23 @@ def write_call_lines(self, node, func_name): def dummy_arg_name(arg): return arg.orig_name + def is_type_a_class(arg_type): + if arg_type.startswith("class") and arg_type[6:-1]: + return True + if arg_type.startswith("type") and arg_type[5:-1]: + tname = arg_type[5:-1] + if self.is_class(tname): + return True + return False + def actual_arg_name(arg): name = arg.name if (hasattr(node, "transfer_in") and arg.name in node.transfer_in) or ( hasattr(node, "transfer_out") and arg.name in node.transfer_out ): name += "_ptr%p" + if is_type_a_class(arg.type): + name += "%obj" if "super-type" in arg.doc: name += "%items" return name @@ -489,7 +540,7 @@ def visit_Procedure(self, node): for tname in node.types: if tname in self.types and "super-type" in self.types[tname].doc: self.write_super_type_lines(self.types[tname]) - self.write_type_lines(tname) + self.write_type_or_class_lines(tname) self.write_arg_decl_lines(node) self.write_transfer_in_lines(node) self.write_init_lines(node) @@ -518,6 +569,11 @@ def visit_Type(self, node): return self.generic_visit(node) + def _get_type_member_array_name(self, t, element_name): + if (self.is_class(t.orig_name)): + return "this_ptr%%p%%obj%%%s" % element_name + return "this_ptr%%p%%%s" % element_name + def _write_sc_array_wrapper(self, t, el, dims, sizeof_fortran_t): """ Write wrapper for arrays of intrinsic types @@ -558,7 +614,7 @@ def _write_sc_array_wrapper(self, t, el, dims, sizeof_fortran_t): self.write("use, intrinsic :: iso_c_binding, only : c_int") self.write("implicit none") if isinstance(t, ft.Type): - self.write_type_lines(t.orig_name) + self.write_type_or_class_lines(t.orig_name) self.write("integer(c_int), intent(in) :: this(%d)" % sizeof_fortran_t) self.write("type(%s_ptr_type) :: this_ptr" % t.orig_name) else: @@ -579,7 +635,7 @@ def _write_sc_array_wrapper(self, t, el, dims, sizeof_fortran_t): self.write("dtype = %s" % ft.fortran_array_type(el.type, self.kind_map)) if isinstance(t, ft.Type): self.write("this_ptr = transfer(this, this_ptr)") - array_name = "this_ptr%%p%%%s" % el.orig_name + array_name = self._get_type_member_array_name(t, el.orig_name) else: array_name = "%s_%s" % (t.name, el.name) @@ -727,13 +783,13 @@ def _write_array_getset_item(self, t, el, sizeof_fortran_t, getset): same_type = ft.strip_type(t.name) == ft.strip_type(el.type) if isinstance(t, ft.Type): - self.write_type_lines(t.name) - self.write_type_lines(el.type, same_type) + self.write_type_or_class_lines(t.name) + self.write_type_or_class_lines(el.type, same_type) self.write("integer, intent(in) :: %s(%d)" % (this, sizeof_fortran_t)) if isinstance(t, ft.Type): self.write("type(%s_ptr_type) :: this_ptr" % t.name) - array_name = "this_ptr%%p%%%s" % el.name + array_name = self._get_type_member_array_name(t, el.name) else: array_name = "%s_%s" % (t.name, el.name) self.write("integer, intent(in) :: %s" % (safe_i)) @@ -847,15 +903,15 @@ def _write_array_len(self, t, el, sizeof_fortran_t): # Check if the type has recursive definition: same_type = ft.strip_type(t.name) == ft.strip_type(el.type) if isinstance(t, ft.Type): - self.write_type_lines(t.name) - self.write_type_lines(el.type, same_type) + self.write_type_or_class_lines(t.name) + self.write_type_or_class_lines(el.type, same_type) self.write("integer, intent(out) :: %s" % (safe_n)) self.write("integer, intent(in) :: %s(%d)" % (this, sizeof_fortran_t)) if isinstance(t, ft.Type): self.write("type(%s_ptr_type) :: this_ptr" % t.name) self.write() self.write("this_ptr = transfer(%s, this_ptr)" % (this)) - array_name = "this_ptr%%p%%%s" % el.name + array_name = self._get_type_member_array_name(t, el.name) else: array_name = "%s_%s" % (t.name, el.name) @@ -946,10 +1002,10 @@ def _write_scalar_wrapper(self, t, el, sizeof_fortran_t, getset): self.write("implicit none") if isinstance(t, ft.Type): - self.write_type_lines(t.orig_name) + self.write_type_or_class_lines(t.orig_name) if el.type.startswith("type") and not (el.type == "type(" + t.orig_name + ")"): - self.write_type_lines(el.type) + self.write_type_or_class_lines(el.type) if isinstance(t, ft.Type): self.write("integer, intent(in) :: this(%d)" % sizeof_fortran_t) @@ -976,9 +1032,14 @@ def _write_scalar_wrapper(self, t, el, sizeof_fortran_t, getset): self.write("this_ptr = transfer(this, this_ptr)") if getset == "get": if isinstance(t, ft.Type): - self.write( - "%s_ptr%%p => this_ptr%%p%%%s" % (el.orig_name, el.orig_name) - ) + if (self.is_class(t.orig_name)): + self.write( + "%s_ptr%%p%%obj = this_ptr%%p%%%s" % (el.orig_name, el.orig_name) + ) + else: + self.write( + "%s_ptr%%p => this_ptr%%p%%%s" % (el.orig_name, el.orig_name) + ) else: self.write( "%s_ptr%%p => %s_%s" % (el.orig_name, t.name, el.orig_name) @@ -992,9 +1053,14 @@ def _write_scalar_wrapper(self, t, el, sizeof_fortran_t, getset): % (el.orig_name, localvar, el.orig_name) ) if isinstance(t, ft.Type): - self.write( - "this_ptr%%p%%%s = %s_ptr%%p" % (el.orig_name, el.orig_name) - ) + if (self.is_class(t.orig_name)): + self.write( + "this_ptr%%p%%obj%%%s = %s_ptr%%p" % (el.orig_name, el.orig_name) + ) + else: + self.write( + "this_ptr%%p%%%s = %s_ptr%%p" % (el.orig_name, el.orig_name) + ) else: self.write( "%s_%s = %s_ptr%%p" % (t.name, el.orig_name, el.orig_name) @@ -1012,12 +1078,18 @@ def _write_scalar_wrapper(self, t, el, sizeof_fortran_t, getset): self.write("this_ptr = transfer(this, this_ptr)") if getset == "get": if isinstance(t, ft.Type): - self.write("%s = this_ptr%%p%%%s" % (localvar, el.orig_name)) + if (self.is_class(t.orig_name)): + self.write("%s = this_ptr%%p%%obj%%%s" % (localvar, el.orig_name)) + else: + self.write("%s = this_ptr%%p%%%s" % (localvar, el.orig_name)) else: self.write("%s = %s_%s" % (localvar, t.name, el.orig_name)) else: if isinstance(t, ft.Type): - self.write("this_ptr%%p%%%s = %s" % (el.orig_name, localvar)) + if (self.is_class(t.orig_name)): + self.write("this_ptr%%p%%obj%%%s = %s" % (el.orig_name, localvar)) + else: + self.write("this_ptr%%p%%%s = %s" % (el.orig_name, localvar)) else: self.write("%s_%s = %s" % (t.name, el.orig_name, localvar)) self.dedent() diff --git a/f90wrap/fortran.py b/f90wrap/fortran.py index ca0c9842..a71d7d9a 100644 --- a/f90wrap/fortran.py +++ b/f90wrap/fortran.py @@ -609,6 +609,17 @@ def find_types(tree, skipped_types=None): else: log.info('Skipping type %s defined in module %s' % (node.name, mod.name)) + for mod in walk_modules(tree): + for node in walk(mod): + if not 'type' in node.__dict__: + continue + if node.type.startswith('class('): + class_name = derived_typename(node.type) + if not class_name in types or class_name in skipped_types: + continue + if 'used_as_class' not in types[class_name].attributes: + types[class_name].attributes.append('used_as_class') + return types def fix_argument_attributes(node): diff --git a/f90wrap/pywrapgen.py b/f90wrap/pywrapgen.py index 497df048..410a3c06 100644 --- a/f90wrap/pywrapgen.py +++ b/f90wrap/pywrapgen.py @@ -521,7 +521,7 @@ def visit_Procedure(self, node): if isinstance(node, ft.Function): # convert any derived type return values to Python objects for ret_val in node.ret_val: - if ret_val.type.startswith("type"): + if ret_val.type.startswith("type") or ret_val.type.startswith("class"): cls_name = normalise_class_name( ft.strip_type(ret_val.type), self.class_names ) @@ -626,6 +626,13 @@ def visit_Type(self, node): self.write(format_doc_string(node)) self.generic_visit(node) + self.write_member_variables(node) + + self.write() + self.dedent() + self.write() + + def write_member_variables(self, node): properties = [] for el in node.elements: dims = list(filter(lambda x: x.startswith("dimension"), el.attributes)) @@ -643,9 +650,6 @@ def visit_Type(self, node): self.write( "_dt_array_initialisers = [%s]" % (", ".join(node.dt_array_initialisers)) ) - self.write() - self.dedent() - self.write() def write_scalar_wrappers(self, node, el, properties): dct = dict( diff --git a/f90wrap/transform.py b/f90wrap/transform.py index ef832120..1f584a6b 100644 --- a/f90wrap/transform.py +++ b/f90wrap/transform.py @@ -252,7 +252,8 @@ def visit_Procedure(self, node): continue else: # allocatable arguments only allowed for derived types - if 'allocatable' in arg.attributes and not arg.type.startswith('type'): + if 'allocatable' in arg.attributes and not ( + arg.type.startswith('type') or arg.type.startswith('class')): log.warning('removing routine %s due to allocatable intrinsic type arguments' % node.name) return None # no pointer arguments @@ -854,9 +855,9 @@ def add_missing_destructors(tree): for child in ft.iter_child_nodes(node): if 'destructor' in child.attributes: log.info('found destructor %s', child.name) + child.attributes.append('skip_call') break else: - log.info('adding missing destructor for %s', node.name) new_node = ft.Subroutine('%s_finalise' % node.name, node.filename,