From 47d4d90b6c6c364d47f0603e792a559d1491af6e Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Tue, 1 Oct 2024 11:25:34 +0200 Subject: [PATCH] review Christian --- .../mol2mol/filter.py | 30 +-- molpipeline/mol2mol/filter.py | 64 +++-- molpipeline/utils/molpipeline_types.py | 22 +- molpipeline/utils/value_conversions.py | 12 +- .../test_mol2mol/test_mol2mol_filter.py | 221 +++++++++--------- 5 files changed, 201 insertions(+), 148 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 951fec8a..50caecc5 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -14,12 +14,12 @@ OptionalMol, RDKitMol, ) -from molpipeline.utils.value_conversions import ( +from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, IntOrIntCountRange, - count_value_to_tuple, ) +from molpipeline.utils.value_conversions import count_value_to_tuple # possible mode types for a KeepMatchesFilter: # - "any" means one match is enough @@ -28,7 +28,7 @@ def _within_boundaries( - lower_bound: Optional[float], upper_bound: Optional[float], value: float + lower_bound: Optional[float], upper_bound: Optional[float], property: float ) -> bool: """Check if a value is within the specified boundaries. @@ -40,17 +40,17 @@ def _within_boundaries( Lower boundary. upper_bound: Optional[float] Upper boundary. - value: float - Value to check. + property: float + Property to check. Returns ------- bool True if the value is within the boundaries, else False. """ - if lower_bound is not None and value < lower_bound: + if lower_bound is not None and property < lower_bound: return False - if upper_bound is not None and value > upper_bound: + if upper_bound is not None and property > upper_bound: return False return True @@ -167,13 +167,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params = super().get_params(deep=deep) params["keep_matches"] = self.keep_matches params["mode"] = self.mode - if deep: - params["filter_elements"] = { - element: (count_tuple[0], count_tuple[1]) - for element, count_tuple in self.filter_elements.items() - } - else: - params["filter_elements"] = self.filter_elements + params["filter_elements"] = self.filter_elements return params def pretransform_single(self, value: RDKitMol) -> OptionalMol: @@ -195,9 +189,9 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule that matches defined filter elements, else InvalidInstance. """ - for filter_element, (min_count, max_count) in self.filter_elements.items(): - count = self._calculate_single_element_value(filter_element, value) - if _within_boundaries(min_count, max_count, count): + for filter_element, (lower_limit, upper_limit) in self.filter_elements.items(): + property = self._calculate_single_element_value(filter_element, value) + if _within_boundaries(lower_limit, upper_limit, property): # For "any" mode we can return early if a match is found if self.mode == "any": if not self.keep_matches: @@ -265,7 +259,7 @@ def _calculate_single_element_value( class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): """Filter to keep or remove molecules based on patterns. - Parameters + Attributes ---------- filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]] List of patterns to allow in molecules. diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index f16d15d7..6d6b6f8a 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -5,11 +5,14 @@ from collections import Counter from typing import Any, Mapping, Optional, Sequence, Union +from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries + try: from typing import Self # type: ignore[attr-defined] except ImportError: from typing_extensions import Self +from loguru import logger from rdkit import Chem from rdkit.Chem import Descriptors @@ -23,13 +26,14 @@ from molpipeline.abstract_pipeline_elements.mol2mol import ( BasePatternsFilter as _BasePatternsFilter, ) -from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol -from molpipeline.utils.value_conversions import ( +from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, IntOrIntCountRange, - count_value_to_tuple, + OptionalMol, + RDKitMol, ) +from molpipeline.utils.value_conversions import count_value_to_tuple class ElementFilter(_MolToMolPipelineElement): @@ -60,6 +64,7 @@ def __init__( allowed_element_numbers: Optional[ Union[list[int], dict[int, IntOrIntCountRange]] ] = None, + add_hydrogens: bool = True, name: str = "ElementFilter", n_jobs: int = 1, uuid: Optional[str] = None, @@ -72,6 +77,8 @@ def __init__( List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum + add_hydrogens: bool, optional (default: True) + If True, in case Hydrogens are in allowed_element_list, add hydrogens to the molecule before filtering. name: str, optional (default: "ElementFilterPipe") Name of the pipeline element. n_jobs: int, optional (default: 1) @@ -81,6 +88,32 @@ def __init__( """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.allowed_element_numbers = allowed_element_numbers # type: ignore + self.add_hydrogens = add_hydrogens + + @property + def add_hydrogens(self) -> bool: + """Get add_hydrogens.""" + return self._add_hydrogens + + @add_hydrogens.setter + def add_hydrogens(self, add_hydrogens: bool) -> None: + """Set add_hydrogens. + + Parameters + ---------- + add_hydrogens: bool + If True, in case Hydrogens are in allowed_element_list, add hydrogens to the molecule before filtering. + """ + self._add_hydrogens = add_hydrogens + if self.add_hydrogens and 1 in self.allowed_element_numbers: + self.process_hydrogens = True + else: + if 1 in self.allowed_element_numbers: + logger.warning( + "Hydrogens are included in allowed_element_numbers, but add_hydrogens is set to False. " + "Thus hydrogens are NOT added before filtering. You might receive unexpected results." + ) + self.process_hydrogens = False @property def allowed_element_numbers(self) -> dict[int, IntCountRange]: @@ -135,6 +168,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } else: params["allowed_element_numbers"] = self.allowed_element_numbers + params["add_hydrogens"] = self.add_hydrogens return params def set_params(self, **parameters: Any) -> Self: @@ -153,6 +187,8 @@ def set_params(self, **parameters: Any) -> Self: parameter_copy = dict(parameters) if "allowed_element_numbers" in parameter_copy: self.allowed_element_numbers = parameter_copy.pop("allowed_element_numbers") + if "add_hydrogens" in parameter_copy: + self.add_hydrogens = parameter_copy.pop("add_hydrogens") super().set_params(**parameter_copy) return self @@ -169,10 +205,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule if it contains only allowed elements, else InvalidInstance. """ - to_process_value = ( - Chem.AddHs(value) if 1 in self.allowed_element_numbers else value - ) - + to_process_value = Chem.AddHs(value) if self.process_hydrogens else value elements_list = [atom.GetAtomicNum() for atom in to_process_value.GetAtoms()] elements_counter = Counter(elements_list) if any( @@ -181,11 +214,9 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: return InvalidInstance( self.uuid, "Molecule contains forbidden chemical element.", self.name ) - for element, (min_count, max_count) in self.allowed_element_numbers.items(): + for element, (lower_limit, upper_limit) in self.allowed_element_numbers.items(): count = elements_counter[element] - if (min_count is not None and count < min_count) or ( - max_count is not None and count > max_count - ): + if not _within_boundaries(lower_limit, upper_limit, count): return InvalidInstance( self.uuid, f"Molecule contains forbidden number of element {element}.", @@ -225,6 +256,11 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: class SmilesFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMILES patterns. + In contrast to the SMARTSFilter, which also can match SMILES, the SmilesFilter + sanitizes the molecules and, e.g. checks kekulized bonds for aromaticity and + then sets it to aromatic while the SmartsFilter detects alternating single and + double bonds. + Notes ----- There are four possible scenarios: @@ -253,7 +289,7 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: class ComplexFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on multiple filter elements. - Parameters + Attributes ---------- filter_elements: Sequence[_MolToMolPipelineElement] MolToMol elements to use as filters. @@ -317,7 +353,7 @@ def _calculate_single_element_value( class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on RDKit descriptors. - Parameters + Attributes ---------- filter_elements: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. @@ -347,11 +383,11 @@ def filter_elements(self, descriptors: dict[str, FloatCountRange]) -> None: descriptors: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. """ - self._filter_elements = descriptors if not all(hasattr(Descriptors, descriptor) for descriptor in descriptors): raise ValueError( "You are trying to use an invalid descriptor. Use RDKit Descriptors module." ) + self._filter_elements = descriptors def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol diff --git a/molpipeline/utils/molpipeline_types.py b/molpipeline/utils/molpipeline_types.py index ff59fdec..e17b48b3 100644 --- a/molpipeline/utils/molpipeline_types.py +++ b/molpipeline/utils/molpipeline_types.py @@ -3,7 +3,17 @@ from __future__ import annotations from numbers import Number -from typing import Any, List, Literal, Optional, Protocol, Tuple, TypeVar, Union +from typing import ( + Any, + List, + Literal, + Optional, + Protocol, + Tuple, + TypeAlias, + TypeVar, + Union, +) try: from typing import Self # type: ignore[attr-defined] @@ -47,6 +57,16 @@ TypeConserverdIterable = TypeVar("TypeConserverdIterable", List[_T], npt.NDArray[_T]) +FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] + +IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] + +# IntOrIntCountRange for Typing of count ranges +# - a single int for an exact value match +# - a range given as a tuple with a lower and upper bound +# - both limits are optional +IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] + class AnySklearnEstimator(Protocol): """Protocol for sklearn estimators.""" diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index df595a84..4206e97f 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -1,16 +1,8 @@ """Module for utilities converting values.""" -from typing import Optional, Sequence, TypeAlias, Union +from typing import Sequence -FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] - -IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] - -# IntOrIntCountRange for Typing of count ranges -# - a single int for an exact value match -# - a range given as a tuple with a lower and upper bound -# - both limits are optional -IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] +from molpipeline.utils.molpipeline_types import IntCountRange, IntOrIntCountRange def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange: diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 5e353b7a..d814c4e9 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -14,7 +14,7 @@ SmartsFilter, SmilesFilter, ) -from molpipeline.utils.value_conversions import FloatCountRange, IntOrIntCountRange +from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -63,15 +63,29 @@ def test_element_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual( - filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] - ) - pipeline.set_params( - ElementFilter__allowed_element_numbers={6: 6, 1: (5, 6), 17: (0, 1)} - ) - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR], + }, + { + "params": { + "ElementFilter__allowed_element_numbers": { + 6: 6, + 1: (5, 6), + 17: (0, 1), + } + }, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], + }, + {"params": {"ElementFilter__add_hydrogens": False}, "result": []}, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) def test_complex_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" @@ -120,34 +134,44 @@ def test_smarts_smiles_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual( - filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] - ) - - pipeline.set_params(SmartsFilter__keep_matches=False) - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, [SMILES_ANTIMONY, SMILES_METAL_AU]) - pipeline.set_params( - SmartsFilter__mode="all", SmartsFilter__keep_matches=True - ) - filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_3, [SMILES_CHLOROBENZENE]) - - pipeline.set_params( - SmartsFilter__keep_matches=True, SmartsFilter__filter_elements=["I"] - ) - filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_4, []) + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR], + }, + { + "params": {"SmartsFilter__keep_matches": False}, + "result": [SMILES_ANTIMONY, SMILES_METAL_AU], + }, + { + "params": { + "SmartsFilter__mode": "all", + "SmartsFilter__keep_matches": True, + }, + "result": [SMILES_CHLOROBENZENE], + }, + { + "params": { + "SmartsFilter__keep_matches": True, + "SmartsFilter__filter_elements": ["I"], + }, + "result": [], + }, + { + "params": { + "SmartsFilter__keep_matches": False, + "SmartsFilter__mode": "any", + "SmartsFilter__filter_elements": new_input_as_list, + }, + "result": [SMILES_ANTIMONY, SMILES_METAL_AU], + }, + ] - pipeline.set_params( - SmartsFilter__keep_matches=False, - SmartsFilter__mode="any", - SmartsFilter__filter_elements=new_input_as_list, - ) - filtered_smiles_5 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_5, [SMILES_ANTIMONY, SMILES_METAL_AU]) + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" @@ -191,75 +215,62 @@ def test_descriptor_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles, SMILES_LIST) - - pipeline.set_params(DescriptorsFilter__mode="all") - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, [SMILES_CL_BR]) - - pipeline.set_params(DescriptorsFilter__keep_matches=False) - filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual( - filtered_smiles_3, - [SMILES_ANTIMONY, SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_METAL_AU], - ) - - pipeline.set_params(DescriptorsFilter__mode="any") - filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_4, []) - - pipeline.set_params( - DescriptorsFilter__mode="any", DescriptorsFilter__keep_matches=True - ) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (2.00, 4), - } - ) - result_lower_exact = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_lower_exact, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1.99, 4), - } - ) - result_lower_in_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_lower_in_bound, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (2.01, 4), - } - ) - result_lower_out_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_lower_out_bound, []) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1, 2.00), - } - ) - result_upper_exact = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_upper_exact, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1, 2.01), - } - ) - result_upper_in_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_upper_in_bound, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1, 1.99), - } - ) - result_upper_out_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_upper_out_bound, []) + test_params_list_with_results = [ + {"params": {}, "result": SMILES_LIST}, + {"params": {"DescriptorsFilter__mode": "all"}, "result": [SMILES_CL_BR]}, + { + "params": {"DescriptorsFilter__keep_matches": False}, + "result": [ + SMILES_ANTIMONY, + SMILES_BENZENE, + SMILES_CHLOROBENZENE, + SMILES_METAL_AU, + ], + }, + {"params": {"DescriptorsFilter__mode": "any"}, "result": []}, + { + "params": { + "DescriptorsFilter__keep_matches": True, + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.00, 4)}, + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)} + }, + "result": [], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)} + }, + "result": [], + }, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) def test_invalidate_mixtures(self) -> None: """Test if mixtures are correctly invalidated."""