diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py index 7de2cbab..e9352ebd 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py @@ -33,6 +33,7 @@ class MolToFingerprintPipelineElement(MolToAnyPipelineElement, abc.ABC): """Abstract class for PipelineElements which transform molecules to integer vectors.""" _n_bits: int + _feature_names: list[str] _output_type = "binary" _return_as: OutputDatatype @@ -71,6 +72,11 @@ def n_bits(self) -> int: """Get number of bits in (or size of) fingerprint.""" return self._n_bits + @property + def feature_names(self) -> list[str]: + """Get feature names.""" + return self._feature_names[:] + @overload def assemble_output( # type: ignore self, value_list: Iterable[npt.NDArray[np.int_]] diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py index 3c0711c3..d820e41f 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py @@ -29,6 +29,7 @@ class MolToDescriptorPipelineElement(MolToAnyPipelineElement): _standardizer: Optional[AnyTransformer] _output_type = "float" + _feature_names: list[str] def __init__( self, @@ -66,6 +67,11 @@ def __init__( def n_features(self) -> int: """Return the number of features.""" + @property + def feature_names(self) -> list[str]: + """Return a copy of the feature names.""" + return self._feature_names[:] + def assemble_output( self, value_list: Iterable[npt.NDArray[np.float64]], diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py index 09c2e3db..f1db3742 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -11,6 +11,7 @@ import numpy as np import numpy.typing as npt +from loguru import logger from sklearn.base import clone from molpipeline.abstract_pipeline_elements.core import ( @@ -32,6 +33,7 @@ class MolToConcatenatedVector(MolToAnyPipelineElement): def __init__( self, element_list: list[tuple[str, MolToAnyPipelineElement]], + use_feature_names_prefix: bool = True, name: str = "MolToConcatenatedVector", n_jobs: int = 1, uuid: Optional[str] = None, @@ -43,6 +45,10 @@ def __init__( ---------- element_list: list[MolToAnyPipelineElement] List of Pipeline Elements of which the output is concatenated. + use_feature_names_prefix: bool, optional (default=True) + If True, will add the pipeline element's name as prefix to feature names. + If False, only the feature names are used. This can lead to duplicate + feature names. name: str, optional (default="MolToConcatenatedVector") name of pipeline. n_jobs: int, optional (default=1) @@ -53,17 +59,15 @@ def __init__( Additional keyword arguments. Can be used to set parameters of the pipeline elements. """ self._element_list = element_list + if len(element_list) == 0: + raise ValueError("element_list must contain at least one element.") + self._use_feature_names_prefix = use_feature_names_prefix super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - output_types = set() - for _, element in self._element_list: - element.n_jobs = self.n_jobs - output_types.add(element.output_type) - if len(output_types) == 1: - self._output_type = output_types.pop() - else: - self._output_type = "mixed" - self._requires_fitting = any( - element[1]._requires_fitting for element in element_list + # set element execution details + self._set_element_execution_details(self._element_list) + # set feature names + self._feature_names = self._create_feature_names( + self._element_list, self._use_feature_names_prefix ) self.set_params(**kwargs) @@ -82,11 +86,88 @@ def n_features(self) -> int: elif hasattr(element, "n_bits"): feature_count += element.n_bits else: - raise AssertionError( + raise ValueError( f"Element {element} does not have n_features or n_bits." ) return feature_count + @property + def feature_names(self) -> list[str]: + """Return the feature names of concatenated elements.""" + return self._feature_names[:] + + @staticmethod + def _create_feature_names( + element_list: list[tuple[str, MolToAnyPipelineElement]], + use_feature_names_prefix: bool, + ) -> list[str]: + """Create feature names for concatenated vector from its elements. + + Parameters + ---------- + element_list: list[tuple[str, MolToAnyPipelineElement]] + List of pipeline elements. + use_feature_names_prefix: bool + If True, will add the pipeline element's name as prefix to feature names. + If False, only the feature names are used. This can lead to duplicate + feature names. + + Raises + ------ + ValueError + If element does not have feature_names attribute. + + Returns + ------- + list[str] + List of feature names. + """ + feature_names = [] + for name, element in element_list: + if not hasattr(element, "feature_names"): + raise ValueError( + f"Element {element} does not have feature_names attribute." + ) + + if use_feature_names_prefix: + # use element name as prefix + feature_names.extend( + [f"{name}__{feature}" for feature in element.feature_names] # type: ignore[attr-defined] + ) + else: + feature_names.extend(element.feature_names) # type: ignore[attr-defined] + + if len(feature_names) != len(set(feature_names)): + logger.warning( + "Feature names in MolToConcatenatedVector are not unique." + " Set use_feature_names_prefix=True and use unique pipeline element" + " names to avoid this." + ) + return feature_names + + def _set_element_execution_details( + self, element_list: list[tuple[str, MolToAnyPipelineElement]] + ) -> None: + """Set output type and requires fitting for the concatenated vector. + + Parameters + ---------- + element_list: list[tuple[str, MolToAnyPipelineElement]] + List of pipeline elements. + """ + output_types = set() + for _, element in self._element_list: + element.n_jobs = self.n_jobs + output_types.add(element.output_type) + if len(output_types) == 1: + self._output_type = output_types.pop() + else: + self._output_type = "mixed" + self._requires_fitting = any( + element[1]._requires_fitting # pylint: disable=protected-access + for element in element_list + ) + def get_params(self, deep: bool = True) -> dict[str, Any]: """Return all parameters defining the object. @@ -105,31 +186,47 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parameters["element_list"] = [ (str(name), clone(ele)) for name, ele in self.element_list ] + parameters["use_feature_names_prefix"] = bool( + self._use_feature_names_prefix + ) else: parameters["element_list"] = self.element_list + parameters["use_feature_names_prefix"] = self._use_feature_names_prefix for name, element in self.element_list: for key, value in element.get_params(deep=deep).items(): parameters[f"{name}__{key}"] = value return parameters - def set_params(self, **parameters: Any) -> Self: - """Set parameters. + def _set_element_list( + self, parameter_copy: dict[str, Any], **parameters: Any + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Set the element list and run necessary configurations. Parameters ---------- + parameter_copy: dict[str, Any] + Copy of parameters. parameters: Any - Parameters to set. + Original parameters. + + Raises + ------ + ValueError + If element_list is empty. Returns ------- - Self - Mol2ConcatenatedVector object with updated parameters. + tuple[dict[str, Any], dict[str, Any]] + Updated parameter_copy and parameters. """ - parameter_copy = dict(parameters) element_list = parameter_copy.pop("element_list", None) if element_list is not None: self._element_list = element_list + if len(element_list) == 0: + raise ValueError("element_list must contain at least one element.") + # reset element execution details + self._set_element_execution_details(self._element_list) step_params: dict[str, dict[str, Any]] = {} step_dict = dict(self._element_list) to_delete_list = [] @@ -150,6 +247,39 @@ def set_params(self, **parameters: Any) -> Self: _ = parameter_copy.pop(to_delete, None) for step, params in step_params.items(): step_dict[step].set_params(**params) + return parameter_copy, parameters + + def set_params(self, **parameters: Any) -> Self: + """Set parameters. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Mol2ConcatenatedVector object with updated parameters. + """ + parameter_copy = dict(parameters) + + # handle element_list + parameter_copy, parameters = self._set_element_list( + parameter_copy, **parameters + ) + + # handle use_feature_names_prefix + use_feature_names_prefix = parameter_copy.pop("use_feature_names_prefix", None) + if use_feature_names_prefix is not None: + self._use_feature_names_prefix = use_feature_names_prefix + # reset feature names + self._feature_names = self._create_feature_names( + self._element_list, + self._use_feature_names_prefix, # type: ignore[arg-type] + ) + + # set parameters of super super().set_params(**parameter_copy) return self diff --git a/molpipeline/mol2any/mol2maccs_key_fingerprint.py b/molpipeline/mol2any/mol2maccs_key_fingerprint.py index 9b70773e..e59006bb 100644 --- a/molpipeline/mol2any/mol2maccs_key_fingerprint.py +++ b/molpipeline/mol2any/mol2maccs_key_fingerprint.py @@ -21,6 +21,7 @@ class MolToMACCSFP(MolToFingerprintPipelineElement): """ _n_bits = 167 # MACCS keys have 166 bits + 1 bit for an all-zero vector (bit 0) + _feature_names = [f"maccs_{i}" for i in range(_n_bits)] def pretransform_single( self, value: RDKitMol diff --git a/molpipeline/mol2any/mol2morgan_fingerprint.py b/molpipeline/mol2any/mol2morgan_fingerprint.py index 36fca193..04e1941a 100644 --- a/molpipeline/mol2any/mol2morgan_fingerprint.py +++ b/molpipeline/mol2any/mol2morgan_fingerprint.py @@ -83,6 +83,7 @@ def __init__( f"Number of bits has to be a positve integer, which is > 0! (Received: {n_bits})" ) self._n_bits = n_bits + self._feature_names = [f"morgan_{i}" for i in range(self._n_bits)] def get_params(self, deep: bool = True) -> dict[str, Any]: """Return all parameters defining the object. diff --git a/molpipeline/mol2any/mol2net_charge.py b/molpipeline/mol2any/mol2net_charge.py index 759ec84b..4c86c3e8 100644 --- a/molpipeline/mol2any/mol2net_charge.py +++ b/molpipeline/mol2any/mol2net_charge.py @@ -56,6 +56,7 @@ def __init__( UUID of the pipeline element, by default None """ self._descriptor_list = ["NetCharge"] + self._feature_names = self._descriptor_list self._charge_method = charge_method # pylint: disable=R0801 super().__init__( @@ -72,7 +73,7 @@ def n_features(self) -> int: @property def descriptor_list(self) -> list[str]: - """Return a copy of the descriptor list.""" + """Return a copy of the descriptor list. Alias of `feature_names`.""" return self._descriptor_list[:] def _get_net_charge_gasteiger( diff --git a/molpipeline/mol2any/mol2path_fingerprint.py b/molpipeline/mol2any/mol2path_fingerprint.py index 12934806..916ab2db 100644 --- a/molpipeline/mol2any/mol2path_fingerprint.py +++ b/molpipeline/mol2any/mol2path_fingerprint.py @@ -21,7 +21,7 @@ class Mol2PathFP( MolToRDKitGenFPElement ): # pylint: disable=too-many-instance-attributes - """Folded Morgan Fingerprint. + """Folded Path Fingerprint. Feature-mapping to vector-positions is arbitrary. @@ -99,9 +99,10 @@ def __init__( ) if not isinstance(n_bits, int) or n_bits < 1: raise ValueError( - f"Number of bits has to be a positve integer, which is > 0! (Received: {n_bits})" + f"Number of bits has to be a positive integer, which is > 0! (Received: {n_bits})" ) self._n_bits = n_bits + self._feature_names = [f"path_{i}" for i in range(self._n_bits)] self._min_path = min_path self._max_path = max_path self._use_hs = use_hs diff --git a/molpipeline/mol2any/mol2rdkit_phys_chem.py b/molpipeline/mol2any/mol2rdkit_phys_chem.py index 251b6b3c..9bb6ff9a 100644 --- a/molpipeline/mol2any/mol2rdkit_phys_chem.py +++ b/molpipeline/mol2any/mol2rdkit_phys_chem.py @@ -72,6 +72,7 @@ def __init__( UUID of the PipelineElement. If None, a new UUID is generated. """ self.descriptor_list = descriptor_list # type: ignore + self._feature_names = self._descriptor_list self._return_with_errors = return_with_errors self._log_exceptions = log_exceptions super().__init__( @@ -88,7 +89,7 @@ def n_features(self) -> int: @property def descriptor_list(self) -> list[str]: - """Return a copy of the descriptor list.""" + """Return a copy of the descriptor list. Alias of `feature_names`.""" return self._descriptor_list[:] @descriptor_list.setter diff --git a/tests/test_elements/test_mol2any/test_mol2concatenated.py b/tests/test_elements/test_mol2any/test_mol2concatenated.py index 5bb57742..1915a251 100644 --- a/tests/test_elements/test_mol2any/test_mol2concatenated.py +++ b/tests/test_elements/test_mol2any/test_mol2concatenated.py @@ -2,6 +2,7 @@ from __future__ import annotations +import itertools import unittest from typing import Any, Literal, get_args @@ -10,14 +11,18 @@ from sklearn.preprocessing import StandardScaler from molpipeline import Pipeline +from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import ( + Mol2PathFP, MolToConcatenatedVector, + MolToMACCSFP, MolToMorganFP, MolToNetCharge, MolToRDKitPhysChem, ) from tests.utils.fingerprints import fingerprints_to_numpy +from tests.utils.logging import capture_logs class TestConcatenatedFingerprint(unittest.TestCase): @@ -91,6 +96,24 @@ def test_generation(self) -> None: self.assertTrue(np.allclose(output, output2)) self.assertTrue(np.allclose(output, output3)) + def test_empty_element_list(self) -> None: + """Test if an empty element list raises an error.""" + # test constructor + with self.assertRaises(ValueError): + MolToConcatenatedVector([]) + + # test setter + concat_elem = MolToConcatenatedVector( + [ + ( + "RDKitPhysChem", + MolToRDKitPhysChem(), + ) + ] + ) + with self.assertRaises(ValueError): + concat_elem.set_params(element_list=[]) + def test_n_features(self) -> None: """Test getting the number of features in the concatenated vector.""" @@ -131,6 +154,167 @@ def test_n_features(self) -> None: net_charge_elem[1].n_features + 16 + physchem_elem[1].n_features, ) + def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals + """Test getting the names of features in the concatenated vector.""" + + physchem_elem = ( + "RDKitPhysChem", + MolToRDKitPhysChem(), + ) + net_charge_elem = ("NetCharge", MolToNetCharge()) + morgan_elem = ( + "MorganFP", + MolToMorganFP(n_bits=17), + ) + path_elem = ( + "PathFP", + Mol2PathFP(n_bits=15), + ) + maccs_elem = ( + "MACCSFP", + MolToMACCSFP(), + ) + + elements = [ + physchem_elem, + net_charge_elem, + morgan_elem, + path_elem, + maccs_elem, + ] + + for use_feature_names_prefix in [False, True]: + # test all subsets are compatible + powerset = itertools.chain.from_iterable( + itertools.combinations(elements, r) for r in range(len(elements) + 1) + ) + # skip empty subset + next(powerset) + + for elements_subset in powerset: + conc_elem = MolToConcatenatedVector( + list(elements_subset), + use_feature_names_prefix=use_feature_names_prefix, + ) + feature_names = conc_elem.feature_names + + if use_feature_names_prefix: + # test feature names are unique if prefix is used or only one element is used + self.assertEqual( + len(feature_names), + len(set(feature_names)), + ) + + # test a feature names and n_features are consistent + self.assertEqual( + len(feature_names), + conc_elem.n_features, + ) + + seen_names = 0 + for elem_name, elem in elements_subset: + self.assertTrue(hasattr(elem, "feature_names")) + elem_feature_names = elem.feature_names # type: ignore[attr-defined] + elem_n_features = len(elem_feature_names) + relevant_names = feature_names[ + seen_names : seen_names + elem_n_features + ] + + if use_feature_names_prefix: + # feature_names should be prefixed with element name + prefixes, feat_names = map( + list, zip(*[name.split("__") for name in relevant_names]) + ) + # test feature names are the same + self.assertListEqual(elem_feature_names, feat_names) + # test prefixes are the same as element names + self.assertTrue(all(prefix == elem_name for prefix in prefixes)) + else: + # feature_names should be the same as element feature names + self.assertListEqual(elem_feature_names, relevant_names) + + seen_names += elem_n_features + + def test_logging_feature_names_uniqueness(self) -> None: + """Test that a warning is logged when feature names are not unique.""" + elements: list[tuple[str, MolToAnyPipelineElement]] = [ + ( + "MorganFP", + MolToMorganFP(n_bits=17), + ), + ( + "MorganFP_with_feats", + MolToMorganFP(n_bits=16, use_features=True), + ), + ] + + # First test is with no prefix + use_feature_names_prefix = False + with capture_logs() as output: + conc_elem = MolToConcatenatedVector( + elements, + use_feature_names_prefix=use_feature_names_prefix, + ) + feature_names = conc_elem.feature_names + + # test log message + self.assertEqual(len(output), 1) + message = output[0] + self.assertIn( + "Feature names in MolToConcatenatedVector are not unique", message + ) + self.assertEqual(message.record["level"].name, "WARNING") + + # test feature names are NOT unique + self.assertNotEqual(len(feature_names), len(set(feature_names))) + + # Second test is with prefix + use_feature_names_prefix = True + with capture_logs() as output: + conc_elem = MolToConcatenatedVector( + elements, + use_feature_names_prefix=use_feature_names_prefix, + ) + feature_names = conc_elem.feature_names + + # test log message + self.assertEqual(len(output), 0) + + # test feature names are unique + self.assertEqual(len(feature_names), len(set(feature_names))) + + def test_getter_setter(self) -> None: + """Test getter and setter methods.""" + elements: list[tuple[str, MolToAnyPipelineElement]] = [ + ( + "MorganFP", + MolToMorganFP(n_bits=17), + ), + ( + "MorganFP_with_feats", + MolToMorganFP(n_bits=16, use_features=True), + ), + ] + concat_elem = MolToConcatenatedVector( + elements, + use_feature_names_prefix=True, + ) + self.assertEqual(len(concat_elem.get_params()["element_list"]), 2) + self.assertEqual(concat_elem.get_params()["use_feature_names_prefix"], True) + # test that there are no duplicates in feature names + self.assertEqual( + len(concat_elem.feature_names), len(set(concat_elem.feature_names)) + ) + params: dict[str, Any] = { + "use_feature_names_prefix": False, + } + concat_elem.set_params(**params) + self.assertEqual(concat_elem.get_params()["use_feature_names_prefix"], False) + # test that there are duplicates in feature names + self.assertNotEqual( + len(concat_elem.feature_names), len(set(concat_elem.feature_names)) + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py index 553c28ec..6ec46b19 100644 --- a/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py +++ b/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py @@ -99,6 +99,14 @@ def test_setter_getter_error_handling(self) -> None: } self.assertRaises(ValueError, mol_fp.set_params, **params) + def test_feature_names(self) -> None: + """Test if the feature names are correct.""" + mol_fp = MolToMACCSFP() + feature_names = mol_fp.feature_names + self.assertEqual(len(feature_names), mol_fp.n_bits) + # feature names should be unique + self.assertEqual(len(feature_names), len(set(feature_names))) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py index 2b765949..5b3977d2 100644 --- a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py +++ b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py @@ -149,6 +149,14 @@ def test_bit2atom_mapping(self) -> None: np_fp = fingerprints_to_numpy(fp) self.assertEqual(np.nonzero(np_fp)[0].shape[0], len(mapping)) # type: ignore + def test_feature_names(self) -> None: + """Test if the feature names are correct.""" + mol_fp = MolToMorganFP(n_bits=1024) + feature_names = mol_fp.feature_names + self.assertEqual(len(feature_names), 1024) + # feature names should be unique + self.assertEqual(len(feature_names), len(set(feature_names))) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py index 691abfb9..ad21a7e0 100644 --- a/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py +++ b/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py @@ -142,6 +142,14 @@ def test_setter_getter_error_handling(self) -> None: } self.assertRaises(ValueError, mol_fp.set_params, **params) + def test_feature_names(self) -> None: + """Test if the feature names are correct.""" + mol_fp = Mol2PathFP(n_bits=1024) + feature_names = mol_fp.feature_names + self.assertEqual(len(feature_names), 1024) + # feature names should be unique + self.assertEqual(len(feature_names), len(set(feature_names))) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils/logging.py b/tests/utils/logging.py new file mode 100644 index 00000000..cef2f471 --- /dev/null +++ b/tests/utils/logging.py @@ -0,0 +1,41 @@ +"""Test utils for logging.""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import Generator + +import loguru +from loguru import logger + + +@contextmanager +def capture_logs( + level: str = "INFO", log_format: str = "{level}:{name}:{message}" +) -> Generator[list[loguru.Message], None, None]: + """Capture loguru-based logs. + + Custom context manager to test loguru-based logs. For details and usage examples, + see https://loguru.readthedocs.io/en/latest/resources/migration.html#replacing-assertlogs-method-from-unittest-library + + Parameters + ---------- + level : str, optional + Log level, by default "INFO" + log_format : str, optional + Log format, by default "{level}:{name}:{message}" + + Yields + ------- + list[loguru.Message] + List of log messages + + Returns + ------- + Generator[list[loguru.Message], None, None] + List of log messages + """ + output: list[loguru.Message] = [] + handler_id = logger.add(output.append, level=level, format=log_format) + yield output + logger.remove(handler_id)