From 58d2dccafabcc7c51bd3530b5a75642d143812de Mon Sep 17 00:00:00 2001 From: ddundo Date: Thu, 28 Nov 2024 21:01:46 +0000 Subject: [PATCH 1/8] #241: Add method to --- goalie/function_data.py | 52 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/goalie/function_data.py b/goalie/function_data.py index 7a564bd..ad7490d 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -280,6 +280,58 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None): f = self._data[field][field_type][i][j] outfile.save_function(f, name=name, idx=j) + def transfer(self, target, method="interpolate"): + """ + Interpolate or project all functions from this :class:`~FunctionData` object to the target :class:`~FunctionData` object. + + :arg target: the target :class:`~FunctionData` object to which to transfer the data + :type target: :class:`.FunctionData` + :arg method: the transfer method to use, either 'interpolate' or 'project' + :type method: :class:`str` + """ + if method not in ["interpolate", "project"]: + raise ValueError( + f"Transfer method '{method}' not supported. Supported methods are 'interpolate' or 'project'." + ) + + if ( + self.time_partition.num_subintervals + != target.time_partition.num_subintervals + ): + raise ValueError( + "Source and target have different numbers of subintervals." + ) + if ( + self.time_partition.num_exports_per_subinterval + != target.time_partition.num_exports_per_subinterval + ): + raise ValueError( + "Source and target have different numbers of exports per subinterval." + ) + + common_fields = set(self.time_partition.field_names) & set( + target.time_partition.field_names + ) + if not common_fields: + raise ValueError("No common fields between source and target.") + + common_labels = set(self.labels) & set(target.labels) + if not common_labels: + raise ValueError("No common labels between source and target.") + + for field in common_fields: + for label in common_labels: + for i in range(self.time_partition.num_subintervals): + for j in range( + self.time_partition.num_exports_per_subinterval[i] - 1 + ): + source_function = self._data[field][label][i][j] + target_function = target._data[field][label][i][j] + if method == "interpolate": + target_function.interpolate(source_function) + elif method == "project": + target_function.project(source_function) + class ForwardSolutionData(FunctionData): """ From 2b7a72f4a398c2b7ad344979d3b9fadbf8010bff Mon Sep 17 00:00:00 2001 From: ddundo Date: Thu, 28 Nov 2024 21:02:15 +0000 Subject: [PATCH 2/8] #241: Add tests --- test/test_function_data.py | 152 +++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/test/test_function_data.py b/test/test_function_data.py index 36a2789..a79d8f9 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -267,5 +267,157 @@ def test_export_h5(self): self.assertTrue(os.path.exists(export_filepath)) +class TestTransferFunctionData(BaseTestCases.TestFunctionData): + """ + Unit tests for transferring data from one :class:`~.FunctionData` to another. + """ + + def setUp(self): + super().setUpUnsteady() + self.labels = ("forward", "forward_old") + + def _create_function_data(self): + self.solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + self.solution_data._create_data() + + def test_transfer_method_error(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="invalid_method") + self.assertEqual( + str(cm.exception), + "Transfer method 'invalid_method' not supported. Supported methods are 'interpolate' or 'project'.", + ) + + def test_transfer_subintervals_error(self): + target_time_partition = TimePartition( + 1.5 * self.time_partition.end_time, + self.time_partition.num_subintervals + 1, + self.time_partition.timesteps + [0.25], + self.time_partition.field_names, + ) + target_function_spaces = { + self.field: [ + FunctionSpace(self.mesh, "DG", 0) + for _ in range(target_time_partition.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + target_time_partition, target_function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), + "Source and target have different numbers of subintervals.", + ) + + def test_transfer_exports_error(self): + target_time_partition = TimePartition( + self.time_partition.end_time, + self.time_partition.num_subintervals, + self.time_partition.timesteps, + self.time_partition.field_names, + num_timesteps_per_export=[1, 2], + ) + target_function_spaces = { + self.field: [ + FunctionSpace(self.mesh, "DG", 0) + for _ in range(target_time_partition.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + target_time_partition, target_function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), + "Source and target have different numbers of exports per subinterval.", + ) + + def test_transfer_common_fields_error(self): + target_time_partition = TimePartition( + self.time_partition.end_time, + self.time_partition.num_subintervals, + self.time_partition.timesteps, + ["different_field"], + ) + target_function_spaces = { + "different_field": [ + FunctionSpace(self.mesh, "DG", 0) + for _ in range(target_time_partition.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + target_time_partition, target_function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), "No common fields between source and target." + ) + + def test_transfer_common_labels_error(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + target_solution_data.labels = ("different_label",) + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), "No common labels between source and target." + ) + + def test_transfer_interpolate(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + self.solution_data.transfer(target_solution_data, method="interpolate") + for field in self.solution_data.time_partition.field_names: + for label in self.solution_data.labels: + for i in range(self.solution_data.time_partition.num_subintervals): + for j in range( + self.solution_data.time_partition.num_exports_per_subinterval[i] + - 1 + ): + source_function = self.solution_data._data[field][label][i][j] + target_function = target_solution_data._data[field][label][i][j] + self.assertTrue( + source_function.dat.data.all() + == target_function.dat.data.all() + ) + + def test_transfer_project(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + self.solution_data.transfer(target_solution_data, method="project") + for field in self.solution_data.time_partition.field_names: + for label in self.solution_data.labels: + for i in range(self.solution_data.time_partition.num_subintervals): + for j in range( + self.solution_data.time_partition.num_exports_per_subinterval[i] + - 1 + ): + source_function = self.solution_data._data[field][label][i][j] + target_function = target_solution_data._data[field][label][i][j] + self.assertTrue( + source_function.dat.data.all() + == target_function.dat.data.all() + ) + + if __name__ == "__main__": unittest.main() From 64f46835db1c364cf67c9f5fb6cb581e519b683d Mon Sep 17 00:00:00 2001 From: ddundo Date: Thu, 28 Nov 2024 21:04:00 +0000 Subject: [PATCH 3/8] #241: Fix line lengths --- goalie/function_data.py | 9 ++++++--- test/test_function_data.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/goalie/function_data.py b/goalie/function_data.py index ad7490d..2194bde 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -282,16 +282,19 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None): def transfer(self, target, method="interpolate"): """ - Interpolate or project all functions from this :class:`~FunctionData` object to the target :class:`~FunctionData` object. + Interpolate or project all functions from this :class:`~FunctionData` object to + the target :class:`~FunctionData` object. - :arg target: the target :class:`~FunctionData` object to which to transfer the data + :arg target: the target :class:`~FunctionData` object to which to transfer the + data :type target: :class:`.FunctionData` :arg method: the transfer method to use, either 'interpolate' or 'project' :type method: :class:`str` """ if method not in ["interpolate", "project"]: raise ValueError( - f"Transfer method '{method}' not supported. Supported methods are 'interpolate' or 'project'." + f"Transfer method '{method}' not supported." + " Supported methods are 'interpolate' or 'project'." ) if ( diff --git a/test/test_function_data.py b/test/test_function_data.py index a79d8f9..bfebbc3 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -291,7 +291,8 @@ def test_transfer_method_error(self): self.solution_data.transfer(target_solution_data, method="invalid_method") self.assertEqual( str(cm.exception), - "Transfer method 'invalid_method' not supported. Supported methods are 'interpolate' or 'project'.", + "Transfer method 'invalid_method' not supported." + " Supported methods are 'interpolate' or 'project'.", ) def test_transfer_subintervals_error(self): From 5ad30de42d254ce95a84ca599dea8626bd7abb52 Mon Sep 17 00:00:00 2001 From: ddundo Date: Thu, 28 Nov 2024 21:13:12 +0000 Subject: [PATCH 4/8] #241: Shorten variable names --- goalie/function_data.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/goalie/function_data.py b/goalie/function_data.py index 2194bde..ec3f4e4 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -291,30 +291,24 @@ def transfer(self, target, method="interpolate"): :arg method: the transfer method to use, either 'interpolate' or 'project' :type method: :class:`str` """ + stp = self.time_partition + ttp = target.time_partition + if method not in ["interpolate", "project"]: raise ValueError( f"Transfer method '{method}' not supported." " Supported methods are 'interpolate' or 'project'." ) - - if ( - self.time_partition.num_subintervals - != target.time_partition.num_subintervals - ): + if stp.num_subintervals != ttp.num_subintervals: raise ValueError( "Source and target have different numbers of subintervals." ) - if ( - self.time_partition.num_exports_per_subinterval - != target.time_partition.num_exports_per_subinterval - ): + if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval: raise ValueError( "Source and target have different numbers of exports per subinterval." ) - common_fields = set(self.time_partition.field_names) & set( - target.time_partition.field_names - ) + common_fields = set(stp.field_names) & set(ttp.field_names) if not common_fields: raise ValueError("No common fields between source and target.") @@ -324,10 +318,8 @@ def transfer(self, target, method="interpolate"): for field in common_fields: for label in common_labels: - for i in range(self.time_partition.num_subintervals): - for j in range( - self.time_partition.num_exports_per_subinterval[i] - 1 - ): + for i in range(stp.num_subintervals): + for j in range(stp.num_exports_per_subinterval[i] - 1): source_function = self._data[field][label][i][j] target_function = target._data[field][label][i][j] if method == "interpolate": From 87ff9783e15b5becf8ac97ef49ce50cba83aedb1 Mon Sep 17 00:00:00 2001 From: ddundo Date: Sun, 8 Dec 2024 11:35:16 +0000 Subject: [PATCH 5/8] #241: Fix API link --- goalie/function_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/goalie/function_data.py b/goalie/function_data.py index ec3f4e4..0bb9252 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -282,12 +282,12 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None): def transfer(self, target, method="interpolate"): """ - Interpolate or project all functions from this :class:`~FunctionData` object to - the target :class:`~FunctionData` object. + Interpolate or project all functions from this :class:`~.FunctionData` object to + the target :class:`~.FunctionData` object. - :arg target: the target :class:`~FunctionData` object to which to transfer the + :arg target: the target :class:`~.FunctionData` object to which to transfer the data - :type target: :class:`.FunctionData` + :type target: :class:`~.FunctionData` :arg method: the transfer method to use, either 'interpolate' or 'project' :type method: :class:`str` """ From 3952b60e2bd32789fb99d49a82f694a4aaadbc0f Mon Sep 17 00:00:00 2001 From: ddundo Date: Sun, 8 Dec 2024 12:08:29 +0000 Subject: [PATCH 6/8] #241: Add prolong option --- goalie/function_data.py | 14 +++++++++----- test/test_function_data.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/goalie/function_data.py b/goalie/function_data.py index 0bb9252..6d38207 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -6,6 +6,7 @@ import firedrake.function as ffunc import firedrake.functionspace as ffs +from firedrake import TransferManager from firedrake.checkpointing import CheckpointFile from firedrake.output.vtk_output import VTKFile @@ -282,22 +283,23 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None): def transfer(self, target, method="interpolate"): """ - Interpolate or project all functions from this :class:`~.FunctionData` object to - the target :class:`~.FunctionData` object. + Transfer all functions from this :class:`~.FunctionData` object to the target + :class:`~.FunctionData` object by interpolation, projection or prolongation. :arg target: the target :class:`~.FunctionData` object to which to transfer the data :type target: :class:`~.FunctionData` - :arg method: the transfer method to use, either 'interpolate' or 'project' + :arg method: the transfer method to use. Either 'interpolate', 'project' or + 'prolong' :type method: :class:`str` """ stp = self.time_partition ttp = target.time_partition - if method not in ["interpolate", "project"]: + if method not in ["interpolate", "project", "prolong"]: raise ValueError( f"Transfer method '{method}' not supported." - " Supported methods are 'interpolate' or 'project'." + " Supported methods are 'interpolate', 'project', and 'prolong'." ) if stp.num_subintervals != ttp.num_subintervals: raise ValueError( @@ -326,6 +328,8 @@ def transfer(self, target, method="interpolate"): target_function.interpolate(source_function) elif method == "project": target_function.project(source_function) + elif method == "prolong": + TransferManager().prolong(source_function, target_function) class ForwardSolutionData(FunctionData): diff --git a/test/test_function_data.py b/test/test_function_data.py index bfebbc3..831684c 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -292,7 +292,7 @@ def test_transfer_method_error(self): self.assertEqual( str(cm.exception), "Transfer method 'invalid_method' not supported." - " Supported methods are 'interpolate' or 'project'.", + " Supported methods are 'interpolate', 'project', and 'prolong'.", ) def test_transfer_subintervals_error(self): @@ -419,6 +419,33 @@ def test_transfer_project(self): == target_function.dat.data.all() ) + def test_transfer_prolong(self): + enriched_mesh = MeshHierarchy(self.mesh, 1)[-1] + target_function_spaces = { + self.field: [ + FunctionSpace(enriched_mesh, "DG", 0) + for _ in range(self.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + self.time_partition, target_function_spaces + ) + target_solution_data._create_data() + self.solution_data.transfer(target_solution_data, method="prolong") + for field in self.solution_data.time_partition.field_names: + for label in self.solution_data.labels: + for i in range(self.solution_data.time_partition.num_subintervals): + for j in range( + self.solution_data.time_partition.num_exports_per_subinterval[i] + - 1 + ): + source_function = self.solution_data._data[field][label][i][j] + target_function = target_solution_data._data[field][label][i][j] + self.assertTrue( + source_function.dat.data.all() + == target_function.dat.data.all() + ) + if __name__ == "__main__": unittest.main() From 33102ec5b77a452182f39f01f1e55addb358afe9 Mon Sep 17 00:00:00 2001 From: ddundo Date: Sun, 8 Dec 2024 12:22:25 +0000 Subject: [PATCH 7/8] #241: Assign 1 to all fns --- test/test_function_data.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_function_data.py b/test/test_function_data.py index 831684c..240e2ee 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -282,6 +282,14 @@ def _create_function_data(self): ) self.solution_data._create_data() + def _assign_one_to_FunctionData(self, fn): + tp = self.solution_data.time_partition + for field in tp.field_names: + for label in self.solution_data.labels: + for i in range(tp.num_subintervals): + for j in range(tp.num_exports_per_subinterval[i] - 1): + fn._data[field][label][i][j].assign(1) + def test_transfer_method_error(self): target_solution_data = ForwardSolutionData( self.time_partition, self.function_spaces @@ -384,6 +392,7 @@ def test_transfer_interpolate(self): self.time_partition, self.function_spaces ) target_solution_data._create_data() + self._assign_one_to_FunctionData(target_solution_data) self.solution_data.transfer(target_solution_data, method="interpolate") for field in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: @@ -404,6 +413,7 @@ def test_transfer_project(self): self.time_partition, self.function_spaces ) target_solution_data._create_data() + self._assign_one_to_FunctionData(target_solution_data) self.solution_data.transfer(target_solution_data, method="project") for field in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: @@ -431,6 +441,7 @@ def test_transfer_prolong(self): self.time_partition, target_function_spaces ) target_solution_data._create_data() + self._assign_one_to_FunctionData(target_solution_data) self.solution_data.transfer(target_solution_data, method="prolong") for field in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: From 48a448e6683b6b1c4e90047ddf280e347238c25d Mon Sep 17 00:00:00 2001 From: ddundo Date: Sun, 8 Dec 2024 12:28:17 +0000 Subject: [PATCH 8/8] #241: Assign 1 to all source functions --- test/test_function_data.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/test_function_data.py b/test/test_function_data.py index 240e2ee..b88375f 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -282,13 +282,13 @@ def _create_function_data(self): ) self.solution_data._create_data() - def _assign_one_to_FunctionData(self, fn): + # Assign 1 to all functions tp = self.solution_data.time_partition for field in tp.field_names: for label in self.solution_data.labels: for i in range(tp.num_subintervals): for j in range(tp.num_exports_per_subinterval[i] - 1): - fn._data[field][label][i][j].assign(1) + self.solution_data._data[field][label][i][j].assign(1) def test_transfer_method_error(self): target_solution_data = ForwardSolutionData( @@ -392,7 +392,6 @@ def test_transfer_interpolate(self): self.time_partition, self.function_spaces ) target_solution_data._create_data() - self._assign_one_to_FunctionData(target_solution_data) self.solution_data.transfer(target_solution_data, method="interpolate") for field in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: @@ -413,7 +412,6 @@ def test_transfer_project(self): self.time_partition, self.function_spaces ) target_solution_data._create_data() - self._assign_one_to_FunctionData(target_solution_data) self.solution_data.transfer(target_solution_data, method="project") for field in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: @@ -441,7 +439,6 @@ def test_transfer_prolong(self): self.time_partition, target_function_spaces ) target_solution_data._create_data() - self._assign_one_to_FunctionData(target_solution_data) self.solution_data.transfer(target_solution_data, method="prolong") for field in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: