diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0bbb6347..821d4633 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -160,6 +160,7 @@ jobs: - run: python benchmark/benchmark_big_model.py name: Benchmark against big model + changelog-test: runs-on: ubuntu-latest if: github.ref != 'refs/heads/main' diff --git a/benchmark/benchmark_big_model.py b/benchmark/benchmark_big_model.py index 89d1d4ad..26af2a9b 100644 --- a/benchmark/benchmark_big_model.py +++ b/benchmark/benchmark_big_model.py @@ -3,7 +3,7 @@ import time import tracemalloc -FAIL_THRESHOLD = 500 +FAIL_THRESHOLD = 30 tracemalloc.start() start = time.time() diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst index 0cda72cb..530c536c 100644 --- a/doc/source/changelog.rst +++ b/doc/source/changelog.rst @@ -1,6 +1,13 @@ MontePy Changelog ================= +#Next Version# +-------------- + +**Performance Improvement** + +* Fixed method of linking ``Material`` to ``ThermalScattering`` objects, avoiding a very expensive O(N:sup:`2`) (:issue:`510`). + 0.4.0 -------------- diff --git a/montepy/data_inputs/data_input.py b/montepy/data_inputs/data_input.py index 80b0fd59..f6cf97f9 100644 --- a/montepy/data_inputs/data_input.py +++ b/montepy/data_inputs/data_input.py @@ -179,6 +179,8 @@ def update_pointers(self, data_inputs): :param data_inputs: a list of the data inputs in the problem :type data_inputs: list + :returns: True iff this input should be removed from ``problem.data_inputs`` + :rtype: bool, None """ pass diff --git a/montepy/data_inputs/material.py b/montepy/data_inputs/material.py index a6fe2c04..0c754e95 100644 --- a/montepy/data_inputs/material.py +++ b/montepy/data_inputs/material.py @@ -177,18 +177,7 @@ def update_pointers(self, data_inputs): :param data_inputs: a list of the data inputs in the problem :type data_inputs: list """ - for input in list(data_inputs): - if isinstance(input, thermal_scattering.ThermalScatteringLaw): - if input.old_number == self.number: - if not self._thermal_scattering: - self._thermal_scattering = input - input._parent_material = self - data_inputs.remove(input) - else: - raise MalformedInputError( - self, - f"Multiple MT inputs were specified for this material: {self.number}.", - ) + pass @staticmethod def _class_prefix(): diff --git a/montepy/data_inputs/thermal_scattering.py b/montepy/data_inputs/thermal_scattering.py index 4f5bd5d6..c52b083a 100644 --- a/montepy/data_inputs/thermal_scattering.py +++ b/montepy/data_inputs/thermal_scattering.py @@ -122,19 +122,40 @@ def update_pointers(self, data_inputs): :param data_inputs: a list of the data inputs in the problem :type data_inputs: list + :returns: True iff this input should be removed from ``problem.data_inputs`` + :rtype: bool """ + # use caching first + if self._problem: + try: + mat = self._problem.materials[self.old_number] + except KeyError: + raise MalformedInputError( + self._input, "MT input is detached from a parent material" + ) + # brute force it found = False - for input in data_inputs: - if isinstance(input, montepy.data_inputs.material.Material): - if input.number == self.old_number: + for data_input in data_inputs: + if isinstance(data_input, montepy.data_inputs.material.Material): + if data_input.number == self.old_number: + mat = data_input found = True - self._parent_material = input - + break + # actually update things if not found: raise MalformedInputError( self._input, "MT input is detached from a parent material" ) + if mat.thermal_scattering: + raise MalformedInputError( + self, + f"Multiple MT inputs were specified for this material: {self.old_number}.", + ) + mat.thermal_scattering = self + self._parent_material = mat + return True + def __str__(self): return f"THERMAL SCATTER: {self.thermal_scattering_laws}" diff --git a/montepy/mcnp_problem.py b/montepy/mcnp_problem.py index cbdd05ed..2f1ae98b 100644 --- a/montepy/mcnp_problem.py +++ b/montepy/mcnp_problem.py @@ -332,9 +332,11 @@ def handle_error(e): ParticleTypeNotInCell, ) as e: handle_error(e) - for input in self._data_inputs: + to_delete = [] + for data_index, data_input in enumerate(self._data_inputs): try: - input.update_pointers(self._data_inputs) + if data_input.update_pointers(self._data_inputs): + to_delete.append(data_index) except ( BrokenObjectLinkError, MalformedInputError, @@ -343,6 +345,8 @@ def handle_error(e): ) as e: handle_error(e) continue + for delete_index in to_delete[::-1]: + del self._data_inputs[delete_index] def remove_duplicate_surfaces(self, tolerance): """Finds duplicate surfaces in the problem, and remove them. diff --git a/prof/dump_results.py b/prof/dump_results.py index fa9f9a53..75beede4 100644 --- a/prof/dump_results.py +++ b/prof/dump_results.py @@ -2,5 +2,5 @@ from pstats import SortKey stats = pstats.Stats("prof/combined.prof") -stats.sort_stats(SortKey.CUMULATIVE, SortKey.TIME).print_stats(300, "montepy") -stats.sort_stats(SortKey.CUMULATIVE, SortKey.TIME).print_stats(100, "sly") +stats.sort_stats(SortKey.CUMULATIVE, SortKey.TIME).print_stats("montepy", 50) +stats.sort_stats(SortKey.CUMULATIVE, SortKey.TIME).print_stats("sly", 20) diff --git a/prof/profile_big_model.py b/prof/profile_big_model.py index b9701f39..8998d069 100644 --- a/prof/profile_big_model.py +++ b/prof/profile_big_model.py @@ -10,6 +10,6 @@ stats = pstats.Stats("prof/big.prof") stats.sort_stats(pstats.SortKey.CUMULATIVE, pstats.SortKey.TIME).print_stats( - 100, "montepy" + "montepy", 70 ) -stats.sort_stats(pstats.SortKey.CUMULATIVE, pstats.SortKey.TIME).print_stats(100, "sly") +stats.sort_stats(pstats.SortKey.CUMULATIVE, pstats.SortKey.TIME).print_stats("sly", 20) diff --git a/tests/test_material.py b/tests/test_material.py index 86d15bad..603b92c7 100644 --- a/tests/test_material.py +++ b/tests/test_material.py @@ -340,7 +340,8 @@ def test_thermal_scattering_format_mcnp(self): in_str = "M20 1001.80c 0.5 8016.80c 0.5" input_card = Input([in_str], BlockType.DATA) material = Material(input_card) - material.update_pointers([card]) + material.thermal_scattering = card + card._parent_material = material material.thermal_scattering.thermal_scattering_laws = ["grph.20t"] self.assertEqual(card.format_for_mcnp_input((6, 2, 0)), ["Mt20 grph.20t "])