From cf32f9a93aff515e90e9adb361116fac02be783d Mon Sep 17 00:00:00 2001 From: David Ormrod Morley Date: Mon, 9 Dec 2024 15:40:54 +0100 Subject: [PATCH] Fix mypy issues SCMSUITE-10131 SO107 --- core/basejob.py | 12 ++++++------ interfaces/adfsuite/ams.py | 6 +++--- mol/molecule.py | 4 ++-- unit_tests/test_molecule.py | 9 +++++++++ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/core/basejob.py b/core/basejob.py index 9a43d51d..e53efc35 100644 --- a/core/basejob.py +++ b/core/basejob.py @@ -4,7 +4,7 @@ import threading import time from os.path import join as opj -from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Union, Callable, Any +from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Union from abc import ABC, abstractmethod import traceback @@ -30,7 +30,7 @@ __all__ = ["SingleJob", "MultiJob"] -def _fail_on_exception(func: Callable[["Job", Any], Any]) -> Callable[["Job", Any], Any]: +def _fail_on_exception(func): """Decorator to wrap a job method and mark the job as failed on any exception.""" def wrapper(self: "Job", *args, **kwargs): @@ -39,11 +39,11 @@ def wrapper(self: "Job", *args, **kwargs): except: # Mark job status as failed and the results as complete self.status = JobStatus.FAILED - self.results.finished.set() - self.results.done.set() + self.results.finished.set() # type: ignore + self.results.done.set() # type: ignore # Notify any parent multi-job of the failure - if self.parent and self in self.parent: - self.parent._notify() + if self.parent and self in self.parent: # type: ignore + self.parent._notify() # type: ignore # Store the exception message to be accessed from get_errormsg self._error_msg = traceback.format_exc() diff --git a/interfaces/adfsuite/ams.py b/interfaces/adfsuite/ams.py index 762f3229..e7480651 100644 --- a/interfaces/adfsuite/ams.py +++ b/interfaces/adfsuite/ams.py @@ -1802,8 +1802,8 @@ def get_density_along_axis( start_step, end_step, every, _ = self._get_integer_start_end_every_max(start_fs, end_fs, every_fs, None) nEntries = self.readrkf("History", "nEntries") - coords = np.array(self.get_history_property("Coords")).reshape(nEntries, -1, 3) - coords = coords[start_step:end_step:every] + history_coords = np.array(self.get_history_property("Coords")).reshape(nEntries, -1, 3) + coords = history_coords[start_step:end_step:every] nEntries = len(coords) axis2index = {"x": 0, "y": 1, "z": 2} @@ -2507,7 +2507,7 @@ def get_errormsg(self) -> Optional[str]: try: log_err_lines = self.results.grep_file("ams.log", "ERROR: ") if log_err_lines: - self._error_msg = log_err_lines[-1].partition("ERROR: ")[2] + self._error_msg: Optional[str] = log_err_lines[-1].partition("ERROR: ")[2] return self._error_msg except FileError: pass diff --git a/mol/molecule.py b/mol/molecule.py index 4f82d1db..922740ae 100644 --- a/mol/molecule.py +++ b/mol/molecule.py @@ -1487,7 +1487,7 @@ def get_fragment(self, indices): return ret - def get_complete_molecules_within_threshold(self, atom_indices, threshold: float): + def get_complete_molecules_within_threshold(self, atom_indices: List[int], threshold: float): """ Returns a new molecule containing complete submolecules for any molecules that are closer than ``threshold`` to any of the atoms in ``atom_indices``. @@ -1508,7 +1508,7 @@ def get_complete_molecules_within_threshold(self, atom_indices, threshold: float zero_based_indices = [x - 1 for x in atom_indices] D = distance_array(solvated_coords, solvated_coords)[zero_based_indices] less_equal = np.less_equal(D, threshold) - within_threshold = np.any(less_equal, axis=0) + within_threshold = np.any(less_equal, axis=0) # type: ignore good_indices = [i for i, value in enumerate(within_threshold) if value] complete_indices: Set[int] = set() diff --git a/unit_tests/test_molecule.py b/unit_tests/test_molecule.py index 564f36ce..1159db9c 100644 --- a/unit_tests/test_molecule.py +++ b/unit_tests/test_molecule.py @@ -216,6 +216,15 @@ def test_system_and_atomic_charge(self, mol): with pytest.raises(MoleculeError): assert mol.guess_atomic_charges() == [1, 1, 0, 0] + def test_get_complete_molecules_within_threshold(self, mol): + m0 = mol.get_complete_molecules_within_threshold([2], 0) + m1 = mol.get_complete_molecules_within_threshold([2], 1) + m2 = mol.get_complete_molecules_within_threshold([2], 2) + + assert m0.get_formula() == "H" + assert m1.get_formula() == "HO" + assert m2.get_formula() == "H2O" + class TestNiO(MoleculeTestBase): """