Skip to content

Commit

Permalink
Init refactoring (#70)
Browse files Browse the repository at this point in the history
* remove unnecessary inits and refactor

* Fix wrong typing that caused thousands of type ignores

* linting

* set correct typing
  • Loading branch information
frederik-sandfort1 authored Aug 21, 2024
1 parent 0a5a53e commit 5b52a3b
Show file tree
Hide file tree
Showing 23 changed files with 67 additions and 508 deletions.
75 changes: 10 additions & 65 deletions molpipeline/abstract_pipeline_elements/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,23 @@ class ABCPipelineElement(abc.ABC):

def __init__(
self,
name: str = "ABCPipelineElement",
name: Optional[str] = None,
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize ABCPipelineElement.
Parameters
----------
name: str
name: Optional[str], optional (default=None)
Name of PipelineElement
n_jobs: int
Number of cores used for processing.
uuid: Optional[str]
Unique identifier of the PipelineElement.
"""
if name is None:
name = self.__class__.__name__
self.name = name
self.n_jobs = n_jobs
if uuid is None:
Expand Down Expand Up @@ -182,12 +184,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
"uuid": self.uuid,
}

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""As the setter function cannot be assessed with super(), this method is implemented for inheritance.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Parameters to be set.
Returns
Expand Down Expand Up @@ -338,15 +340,15 @@ class TransformingPipelineElement(ABCPipelineElement):

def __init__(
self,
name: str = "ABCPipelineElement",
name: Optional[str] = None,
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize ABCPipelineElement.
Parameters
----------
name: str
name: Optional[str], optional (default=None)
Name of PipelineElement
n_jobs: int
Number of cores used for processing.
Expand Down Expand Up @@ -377,12 +379,12 @@ def parameters(self) -> dict[str, Any]:
return self.get_params()

@parameters.setter
def parameters(self, **parameters: dict[str, Any]) -> None:
def parameters(self, **parameters: Any) -> None:
"""Set the parameters of the object.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Object parameters as a dictionary.
Returns
Expand Down Expand Up @@ -616,25 +618,6 @@ class MolToMolPipelineElement(TransformingPipelineElement, abc.ABC):
_input_type = "RDKitMol"
_output_type = "RDKitMol"

def __init__(
self,
name: str = "MolToMolPipelineElement",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize MolToMolPipelineElement.
Parameters
----------
name: str
Name of the PipelineElement.
n_jobs: int
Number of cores used for processing.
uuid: Optional[str]
Unique identifier of the PipelineElement.
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)

def transform(self, values: list[OptionalMol]) -> list[OptionalMol]:
"""Transform list of molecules to list of molecules.
Expand Down Expand Up @@ -700,25 +683,6 @@ class AnyToMolPipelineElement(TransformingPipelineElement, abc.ABC):

_output_type = "RDKitMol"

def __init__(
self,
name: str = "AnyToMolPipelineElement",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize AnyToMolPipelineElement.
Parameters
----------
name: str
Name of the PipelineElement.
n_jobs: int
Number of cores used for processing.
uuid: Optional[str]
Unique identifier of the PipelineElement.
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)

def transform(self, values: Any) -> list[OptionalMol]:
"""Transform list of instances to list of molecules.
Expand Down Expand Up @@ -756,25 +720,6 @@ class MolToAnyPipelineElement(TransformingPipelineElement, abc.ABC):

_input_type = "RDKitMol"

def __init__(
self,
name: str = "MolToAnyPipelineElement",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize MolToAnyPipelineElement.
Parameters
----------
name: str
Name of the PipelineElement.
n_jobs: int
Number of cores used for processing.
uuid: Optional[str]
Unique identifier of the PipelineElement.
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)

@abc.abstractmethod
def pretransform_single(self, value: RDKitMol) -> Any:
"""Transform the molecule, but skip parameters learned during fitting.
Expand Down
16 changes: 8 additions & 8 deletions molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:

return parameters

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set object parameters relevant for copying the class.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Dictionary of parameter names and values.
Returns
Expand All @@ -160,7 +160,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self:
raise ValueError(
f"return_as has to be one of {get_args(OutputDatatype)}! (Received: {return_as})"
)
self._return_as = return_as # type: ignore
self._return_as = return_as
super().set_params(**parameter_dict_copy)
return self

Expand Down Expand Up @@ -300,12 +300,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:

return parameters

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set object parameters relevant for copying the class.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Dictionary of parameter names and values.
Returns
Expand Down Expand Up @@ -398,12 +398,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
parameters.pop("fill_value", None)
return parameters

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set parameters.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Dictionary of parameter names and values.
Returns
Expand All @@ -417,7 +417,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self:

# explicitly check for None, since 0 is a valid value
if radius is not None:
self._radius = radius # type: ignore
self._radius = radius
# explicitly check for None, since False is a valid value
if use_features is not None:
self._use_features = bool(use_features)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
params["standardizer"] = self._standardizer
return params

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set parameters.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Dictionary with parameter names and corresponding values.
Returns
Expand All @@ -123,7 +123,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self:
parameter_copy = dict(parameters)
standardizer = parameter_copy.pop("standardizer", None)
if standardizer is not None:
self._standardizer = standardizer # type: ignore
self._standardizer = standardizer
super().set_params(**parameter_copy)
return self

Expand Down
21 changes: 0 additions & 21 deletions molpipeline/any2mol/bin2mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from typing import Optional

from rdkit import Chem

from molpipeline.abstract_pipeline_elements.core import (
Expand All @@ -16,25 +14,6 @@
class BinaryToMol(AnyToMolPipelineElement):
"""Transforms binary string representation to RDKit Mol objects."""

def __init__(
self,
name: str = "bin2mol",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize BinaryToMol.
Parameters
----------
name: str, optional (default="bin2mol")
Name of PipelineElement.
n_jobs: int, optional (default=1)
Number of cores used.
uuid: str | None, optional (default=None)
UUID of the pipeline element. If None, a random UUID is generated.
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)

def pretransform_single(self, value: str) -> OptionalMol:
"""Transform binary string to molecule.
Expand Down
6 changes: 3 additions & 3 deletions molpipeline/any2mol/sdf2mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
params["identifier"] = self.identifier
return params

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set parameters of the object.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Dictionary containing all parameters defining the object.
Returns
Expand All @@ -86,7 +86,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self:
"""
super().set_params(**parameters)
if "identifier" in parameters:
self.identifier = parameters["identifier"] # type: ignore
self.identifier = parameters["identifier"]
return self

def finish(self) -> None:
Expand Down
21 changes: 0 additions & 21 deletions molpipeline/any2mol/smiles2mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from typing import Optional

from rdkit import Chem

from molpipeline.abstract_pipeline_elements.any2mol.string2mol import (
Expand All @@ -16,25 +14,6 @@
class SmilesToMol(_StringToMolPipelineElement):
"""Transforms Smiles to RDKit Mol objects."""

def __init__(
self,
name: str = "smiles2mol",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize SmilesToMol.
Parameters
----------
name: str, optional (default="smiles2mol")
Name of PipelineElement.
n_jobs: int, optional (default=1)
Number of cores used.
uuid: str | None, optional (default=None)
UUID of the pipeline element. If None, a random UUID is generated.
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)

def pretransform_single(self, value: str) -> OptionalMol:
"""Transform Smiles string to molecule.
Expand Down
8 changes: 4 additions & 4 deletions molpipeline/error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
params["element_ids"] = self.element_ids
return params

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set parameters for this element.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Dict of arameters to set.
Returns
Expand Down Expand Up @@ -508,12 +508,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
params["fill_value"] = self.fill_value
return params

def set_params(self, **parameters: dict[str, Any]) -> Self:
def set_params(self, **parameters: Any) -> Self:
"""Set parameters for this element.
Parameters
----------
parameters: dict[str, Any]
parameters: Any
Parameter dict.
Returns
Expand Down
Loading

0 comments on commit 5b52a3b

Please sign in to comment.