diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py index bd0485f7..1ffff643 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -164,7 +164,7 @@ def _set_element_execution_details( else: self._output_type = "mixed" self._requires_fitting = any( - element[1]._requires_fitting for element in element_list + element[1]._requires_fitting for element in element_list # type: ignore[protected-access] ) def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -197,22 +197,21 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return parameters - def set_params(self, **parameters: Any) -> Self: - """Set parameters. + def _set_element_list( + self, parameter_copy: dict[str, Any], **parameters: Any + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Set the element list and run necessary configurations. Parameters ---------- - parameters: Any - Parameters to set. + element_list: list[tuple[str, MolToAnyPipelineElement]] + List of pipeline elements. - Returns - ------- - Self - Mol2ConcatenatedVector object with updated parameters. + Raises + ------ + ValueError + If element_list is empty. """ - parameter_copy = dict(parameters) - - # handle element_list element_list = parameter_copy.pop("element_list", None) if element_list is not None: self._element_list = element_list @@ -240,6 +239,27 @@ def set_params(self, **parameters: Any) -> Self: _ = parameter_copy.pop(to_delete, None) for step, params in step_params.items(): step_dict[step].set_params(**params) + return parameter_copy, parameters + + def set_params(self, **parameters: Any) -> Self: + """Set parameters. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Mol2ConcatenatedVector object with updated parameters. + """ + parameter_copy = dict(parameters) + + # handle element_list + parameter_copy, parameters = self._set_element_list( + parameter_copy, **parameters + ) # handle use_feature_names_prefix use_feature_names_prefix = parameter_copy.pop("use_feature_names_prefix", None) diff --git a/tests/test_elements/test_mol2any/test_mol2concatenated.py b/tests/test_elements/test_mol2any/test_mol2concatenated.py index 9212b18c..967e35d7 100644 --- a/tests/test_elements/test_mol2any/test_mol2concatenated.py +++ b/tests/test_elements/test_mol2any/test_mol2concatenated.py @@ -11,6 +11,7 @@ from sklearn.preprocessing import StandardScaler from molpipeline import Pipeline +from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import ( MolToConcatenatedVector, @@ -234,7 +235,7 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals def test_logging_feature_names_uniqueness(self) -> None: """Test that a warning is logged when feature names are not unique.""" - elements = [ + elements: list[tuple[str, MolToAnyPipelineElement]] = [ ( "MorganFP", MolToMorganFP(n_bits=17), @@ -282,7 +283,7 @@ def test_logging_feature_names_uniqueness(self) -> None: def test_getter_setter(self) -> None: """Test getter and setter methods.""" - elements = [ + elements: list[tuple[str, MolToAnyPipelineElement]] = [ ( "MorganFP", MolToMorganFP(n_bits=17), diff --git a/tests/utils/logging.py b/tests/utils/logging.py index 3d5948c6..89a425e9 100644 --- a/tests/utils/logging.py +++ b/tests/utils/logging.py @@ -1,15 +1,18 @@ """Test utils for logging.""" +from __future__ import annotations + from contextlib import contextmanager from typing import Generator +import loguru from loguru import logger @contextmanager def capture_logs( level="INFO", format="{level}:{name}:{message}" -) -> Generator[list[str], None, None]: +) -> Generator[list[loguru.Message], None, None]: # ign """Capture loguru-based logs. Custom context manager to test loguru-based logs. For details and usage examples, @@ -24,10 +27,15 @@ def capture_logs( Yields ------- - list[str] + list[loguru.Message] + List of log messages + + Returns + ------- + Generator[list[loguru.Message], None, None] List of log messages """ - output: list[str] = [] + output: list[loguru.Message] = [] handler_id = logger.add(output.append, level=level, format=format) yield output logger.remove(handler_id)