Skip to content

Commit

Permalink
Merge branch 'main' into 81-postpredictionwrapper-handle-set_params-w…
Browse files Browse the repository at this point in the history
…ithout-wrapped_estimator
  • Loading branch information
c-w-feldmann authored Sep 23, 2024
2 parents c81b048 + 22efd0d commit 4ae5da2
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
pip install pylint
- name: Analysing the code with pylint
run: |
pylint -d C0301,R0913,W1202 $(git ls-files '*.py') --ignored-modules "rdkit"
pylint -d C0301,R0913,W1202 $(git ls-files '*.py') --ignored-modules "rdkit" --max-positional-arguments 10
mypy:
runs-on: ubuntu-latest
steps:
Expand Down
4 changes: 2 additions & 2 deletions molpipeline/estimators/chemprop/component_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class PredictorWrapper(_Predictor, BaseEstimator, abc.ABC): # type: ignore
_T_default_criterion: LossFunction
_T_default_metric: Metric

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
n_tasks: int = 1,
input_dim: int = DEFAULT_HIDDEN_DIM,
Expand Down Expand Up @@ -327,7 +327,7 @@ class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN
_T_default_criterion = CrossEntropyLoss
_T_default_metric = CrossEntropyMetric

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
n_classes: int,
n_tasks: int = 1,
Expand Down
13 changes: 6 additions & 7 deletions molpipeline/mol2any/mol2morgan_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,11 @@ def _explain_rdmol(self, mol_obj: RDKitMol) -> dict[int, list[tuple[int, int]]]:
dict[int, list[tuple[int, int]]]
Dictionary with bit position as key and list of tuples with atom index and radius as value.
"""
bit_info: dict[int, list[tuple[int, int]]] = {}
_ = AllChem.GetMorganFingerprintAsBitVect(
mol_obj,
self.radius,
useFeatures=self._use_features,
bitInfo=bit_info,
nBits=self._n_bits,
fp_generator = self._get_fp_generator()
additional_output = AllChem.AdditionalOutput()
additional_output.AllocateBitInfoMap()
_ = fp_generator.GetSparseFingerprint(
mol_obj, additionalOutput=additional_output
)
bit_info = additional_output.GetBitInfoMap()
return bit_info
2 changes: 1 addition & 1 deletion molpipeline/mol2any/mol2path_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Mol2PathFP(
"""

# pylint: disable=too-many-arguments,too-many-locals
# pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
def __init__(
self,
min_path: int = 1,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from molpipeline import Pipeline
from molpipeline.any2mol import SmilesToMol
from molpipeline.mol2any import MolToMorganFP
from tests.utils.fingerprints import fingerprints_to_numpy

test_smiles = [
"c1ccccc1",
Expand Down Expand Up @@ -128,6 +129,34 @@ def test_setter_getter_error_handling(self) -> None:
}
self.assertRaises(ValueError, mol_fp.set_params, **params)

def test_bit2atom_mapping(self) -> None:
"""Test that the mapping from bits to atom weights works as intended.
Notes
-----
lower n_bit values, e.g. 2048, will lead to a bit clash during folding,
for the test smiles "NCCOCCCC(=O)O".
We want no folding clashes in this test to check the correct length
of the bit-to-atom mapping.
"""
n_bits = 2100
sparse_morgan = MolToMorganFP(radius=2, n_bits=n_bits, return_as="sparse")
dense_morgan = MolToMorganFP(radius=2, n_bits=n_bits, return_as="dense")
explicit_bit_vect_morgan = MolToMorganFP(
radius=2, n_bits=n_bits, return_as="explicit_bit_vect"
)

smi2mol = SmilesToMol()
for test_smi in test_smiles:
for fp_gen in [sparse_morgan, dense_morgan, explicit_bit_vect_morgan]:
for counted in [False, True]:
mol = smi2mol.transform([test_smi])[0]
fp_gen.set_params(counted=counted)
fp = fp_gen.transform([mol])
mapping = fp_gen.bit2atom_mapping(mol)
np_fp = fingerprints_to_numpy(fp)
self.assertEqual(np.nonzero(np_fp)[0].shape[0], len(mapping)) # type: ignore


if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion tests/utils/fingerprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# pylint: disable=no-name-in-module
from rdkit.Chem import rdFingerprintGenerator as rdkit_fp
from rdkit.DataStructs import ExplicitBitVect
from rdkit.DataStructs import ExplicitBitVect, UIntSparseIntVect
from scipy import sparse


Expand Down Expand Up @@ -59,6 +59,8 @@ def fingerprints_to_numpy(
"""
if all(isinstance(fp, ExplicitBitVect) for fp in fingerprints):
return np.array(fingerprints)
if all(isinstance(fp, UIntSparseIntVect) for fp in fingerprints):
return np.array([fp.ToList() for fp in fingerprints])
if isinstance(fingerprints, sparse.csr_matrix):
return fingerprints.toarray()
if isinstance(fingerprints, np.ndarray):
Expand Down

0 comments on commit 4ae5da2

Please sign in to comment.