diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 191b3ddd..acfbf0c0 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -70,7 +70,7 @@ def patterns( List of patterns. """ self._patterns: dict[str, tuple[Optional[int], Optional[int]]] - if isinstance(patterns, list) or isinstance(patterns, set): + if isinstance(patterns, (list, set)): self._patterns = {pat: (1, None) for pat in patterns} else: self._patterns = {} diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 55edde5b..95adb85f 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -118,9 +118,7 @@ def allowed_element_numbers( self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]] if allowed_element_numbers is None: allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS - if isinstance(allowed_element_numbers, list) or isinstance( - allowed_element_numbers, set - ): + if isinstance(allowed_element_numbers, (list, set)): self._allowed_element_numbers = { atom_number: (1, None) for atom_number in allowed_element_numbers } @@ -190,6 +188,12 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: elements_list = [atom.GetAtomicNum() for atom in value.GetAtoms()] elements_count_dict = _list_to_dict_with_counts(elements_list) for element, count in elements_count_dict.items(): + if element not in self.allowed_element_numbers: + return InvalidInstance( + self.uuid, + f"Molecule contains forbidden element {element}.", + self.name, + ) min_count, max_count = self.allowed_element_numbers[element] if (min_count is not None and count < min_count) or ( max_count is not None and count > max_count @@ -253,8 +257,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - match_counts += 1 + match_counts += 1 if self.mode == "any": return ( value @@ -265,27 +268,25 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - if match_counts == len(self.patterns): - return ( - value - if self.keep - else InvalidInstance( - self.uuid, - "Molecule matches one of the SmartsFilter patterns.", - self.name, - ) - ) - else: - return ( - value - if not self.keep - else InvalidInstance( - self.uuid, - "Molecule does not match all of the SmartsFilter patterns.", - self.name, - ) + if match_counts == len(self.patterns): + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule matches one of the SmartsFilter patterns.", + self.name, ) + ) + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the SmartsFilter patterns.", + self.name, + ) + ) class SmilesFilter(_BasePatternsFilter): @@ -340,16 +341,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - return ( - value - if self.keep - else InvalidInstance( - self.uuid, - "Molecule does not match all of the SmilesFilter patterns.", - self.name, - ) + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the SmilesFilter patterns.", + self.name, ) + ) class DescriptorsFilter(_MolToMolPipelineElement): @@ -484,16 +484,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - return ( - value - if self.keep - else InvalidInstance( - self.uuid, - "Molecule does not match all of the DescriptorsFilter descriptors.", - self.name, - ) + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the DescriptorsFilter descriptors.", + self.name, ) + ) class MixtureFilter(_MolToMolPipelineElement): diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index ad79ea40..8b2ecf5b 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -37,19 +37,19 @@ class MolFilterTest(unittest.TestCase): def test_element_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" default_atoms_dict = { - 1: (None, None), - 5: (None, None), - 6: (None, None), - 7: (None, None), - 8: (None, None), - 9: (None, None), - 14: (None, None), - 15: (None, None), - 16: (None, None), - 17: (None, None), - 34: (None, None), - 35: (None, None), - 53: (None, None), + 1: (1, None), + 5: (1, None), + 6: (1, None), + 7: (1, None), + 8: (1, None), + 9: (1, None), + 14: (1, None), + 15: (1, None), + 16: (1, None), + 17: (1, None), + 34: (1, None), + 35: (1, None), + 53: (1, None), } element_filter = ElementFilter()