Skip to content

Commit

Permalink
linting and fix element number test
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-sandfort1 committed Aug 20, 2024
1 parent 476d65a commit f14b71a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 57 deletions.
2 changes: 1 addition & 1 deletion molpipeline/abstract_pipeline_elements/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def patterns(
List of patterns.
"""
self._patterns: dict[str, tuple[Optional[int], Optional[int]]]
if isinstance(patterns, list) or isinstance(patterns, set):
if isinstance(patterns, (list, set)):
self._patterns = {pat: (1, None) for pat in patterns}
else:
self._patterns = {}
Expand Down
85 changes: 42 additions & 43 deletions molpipeline/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ def allowed_element_numbers(
self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]]
if allowed_element_numbers is None:
allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS
if isinstance(allowed_element_numbers, list) or isinstance(
allowed_element_numbers, set
):
if isinstance(allowed_element_numbers, (list, set)):
self._allowed_element_numbers = {
atom_number: (1, None) for atom_number in allowed_element_numbers
}
Expand Down Expand Up @@ -190,6 +188,12 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
elements_list = [atom.GetAtomicNum() for atom in value.GetAtoms()]
elements_count_dict = _list_to_dict_with_counts(elements_list)
for element, count in elements_count_dict.items():
if element not in self.allowed_element_numbers:
return InvalidInstance(
self.uuid,
f"Molecule contains forbidden element {element}.",
self.name,
)
min_count, max_count = self.allowed_element_numbers[element]
if (min_count is not None and count < min_count) or (
max_count is not None and count > max_count
Expand Down Expand Up @@ -253,8 +257,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
self.name,
)
)
else:
match_counts += 1
match_counts += 1
if self.mode == "any":
return (
value
Expand All @@ -265,27 +268,25 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
self.name,
)
)
else:
if match_counts == len(self.patterns):
return (
value
if self.keep
else InvalidInstance(
self.uuid,
"Molecule matches one of the SmartsFilter patterns.",
self.name,
)
)
else:
return (
value
if not self.keep
else InvalidInstance(
self.uuid,
"Molecule does not match all of the SmartsFilter patterns.",
self.name,
)
if match_counts == len(self.patterns):
return (
value
if self.keep
else InvalidInstance(
self.uuid,
"Molecule matches one of the SmartsFilter patterns.",
self.name,
)
)
return (
value
if not self.keep
else InvalidInstance(
self.uuid,
"Molecule does not match all of the SmartsFilter patterns.",
self.name,
)
)


class SmilesFilter(_BasePatternsFilter):
Expand Down Expand Up @@ -340,16 +341,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
self.name,
)
)
else:
return (
value
if self.keep
else InvalidInstance(
self.uuid,
"Molecule does not match all of the SmilesFilter patterns.",
self.name,
)
return (
value
if self.keep
else InvalidInstance(
self.uuid,
"Molecule does not match all of the SmilesFilter patterns.",
self.name,
)
)


class DescriptorsFilter(_MolToMolPipelineElement):
Expand Down Expand Up @@ -484,16 +484,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
self.name,
)
)
else:
return (
value
if self.keep
else InvalidInstance(
self.uuid,
"Molecule does not match all of the DescriptorsFilter descriptors.",
self.name,
)
return (
value
if self.keep
else InvalidInstance(
self.uuid,
"Molecule does not match all of the DescriptorsFilter descriptors.",
self.name,
)
)


class MixtureFilter(_MolToMolPipelineElement):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_elements/test_mol2mol/test_mol2mol_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ class MolFilterTest(unittest.TestCase):
def test_element_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed chemical elements."""
default_atoms_dict = {
1: (None, None),
5: (None, None),
6: (None, None),
7: (None, None),
8: (None, None),
9: (None, None),
14: (None, None),
15: (None, None),
16: (None, None),
17: (None, None),
34: (None, None),
35: (None, None),
53: (None, None),
1: (1, None),
5: (1, None),
6: (1, None),
7: (1, None),
8: (1, None),
9: (1, None),
14: (1, None),
15: (1, None),
16: (1, None),
17: (1, None),
34: (1, None),
35: (1, None),
53: (1, None),
}

element_filter = ElementFilter()
Expand Down

0 comments on commit f14b71a

Please sign in to comment.