Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 20, 2024
1 parent c4c9904 commit 6e405b0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
44 changes: 32 additions & 12 deletions molpipeline/mol2any/mol2concatinated_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _set_element_execution_details(
else:
self._output_type = "mixed"
self._requires_fitting = any(
element[1]._requires_fitting for element in element_list
element[1]._requires_fitting for element in element_list # type: ignore[protected-access]
)

def get_params(self, deep: bool = True) -> dict[str, Any]:
Expand Down Expand Up @@ -197,22 +197,21 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:

return parameters

def set_params(self, **parameters: Any) -> Self:
"""Set parameters.
def _set_element_list(
self, parameter_copy: dict[str, Any], **parameters: Any
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Set the element list and run necessary configurations.
Parameters
----------
parameters: Any
Parameters to set.
element_list: list[tuple[str, MolToAnyPipelineElement]]
List of pipeline elements.
Returns
-------
Self
Mol2ConcatenatedVector object with updated parameters.
Raises
------
ValueError
If element_list is empty.
"""
parameter_copy = dict(parameters)

# handle element_list
element_list = parameter_copy.pop("element_list", None)
if element_list is not None:
self._element_list = element_list
Expand Down Expand Up @@ -240,6 +239,27 @@ def set_params(self, **parameters: Any) -> Self:
_ = parameter_copy.pop(to_delete, None)
for step, params in step_params.items():
step_dict[step].set_params(**params)
return parameter_copy, parameters

def set_params(self, **parameters: Any) -> Self:
"""Set parameters.
Parameters
----------
parameters: Any
Parameters to set.
Returns
-------
Self
Mol2ConcatenatedVector object with updated parameters.
"""
parameter_copy = dict(parameters)

# handle element_list
parameter_copy, parameters = self._set_element_list(
parameter_copy, **parameters
)

# handle use_feature_names_prefix
use_feature_names_prefix = parameter_copy.pop("use_feature_names_prefix", None)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_elements/test_mol2any/test_mol2concatenated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn.preprocessing import StandardScaler

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement
from molpipeline.any2mol import SmilesToMol
from molpipeline.mol2any import (
MolToConcatenatedVector,
Expand Down Expand Up @@ -234,7 +235,7 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals

def test_logging_feature_names_uniqueness(self) -> None:
"""Test that a warning is logged when feature names are not unique."""
elements = [
elements: list[tuple[str, MolToAnyPipelineElement]] = [
(
"MorganFP",
MolToMorganFP(n_bits=17),
Expand Down Expand Up @@ -282,7 +283,7 @@ def test_logging_feature_names_uniqueness(self) -> None:

def test_getter_setter(self) -> None:
"""Test getter and setter methods."""
elements = [
elements: list[tuple[str, MolToAnyPipelineElement]] = [
(
"MorganFP",
MolToMorganFP(n_bits=17),
Expand Down
14 changes: 11 additions & 3 deletions tests/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Test utils for logging."""

from __future__ import annotations

from contextlib import contextmanager
from typing import Generator

import loguru
from loguru import logger


@contextmanager
def capture_logs(
level="INFO", format="{level}:{name}:{message}"
) -> Generator[list[str], None, None]:
) -> Generator[list[loguru.Message], None, None]: # ign
"""Capture loguru-based logs.
Custom context manager to test loguru-based logs. For details and usage examples,
Expand All @@ -24,10 +27,15 @@ def capture_logs(
Yields
-------
list[str]
list[loguru.Message]
List of log messages
Returns
-------
Generator[list[loguru.Message], None, None]
List of log messages
"""
output: list[str] = []
output: list[loguru.Message] = []
handler_id = logger.add(output.append, level=level, format=format)
yield output
logger.remove(handler_id)

0 comments on commit 6e405b0

Please sign in to comment.