diff --git a/molpipeline/mol2any/mol2morgan_fingerprint.py b/molpipeline/mol2any/mol2morgan_fingerprint.py index 1c93295d..79fa46c1 100644 --- a/molpipeline/mol2any/mol2morgan_fingerprint.py +++ b/molpipeline/mol2any/mol2morgan_fingerprint.py @@ -151,12 +151,11 @@ def _explain_rdmol(self, mol_obj: RDKitMol) -> dict[int, list[tuple[int, int]]]: dict[int, list[tuple[int, int]]] Dictionary with bit position as key and list of tuples with atom index and radius as value. """ - bit_info: dict[int, list[tuple[int, int]]] = {} - _ = AllChem.GetMorganFingerprintAsBitVect( - mol_obj, - self.radius, - useFeatures=self._use_features, - bitInfo=bit_info, - nBits=self._n_bits, + fp_generator = self._get_fp_generator() + additional_output = AllChem.AdditionalOutput() + additional_output.AllocateBitInfoMap() + _ = fp_generator.GetSparseFingerprint( + mol_obj, additionalOutput=additional_output ) + bit_info = additional_output.GetBitInfoMap() return bit_info diff --git a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py index 3a5e94a9..6fae46b4 100644 --- a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py +++ b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py @@ -10,6 +10,7 @@ from molpipeline import Pipeline from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToMorganFP +from tests.utils.fingerprints import fingerprints_to_numpy test_smiles = [ "c1ccccc1", @@ -128,6 +129,34 @@ def test_setter_getter_error_handling(self) -> None: } self.assertRaises(ValueError, mol_fp.set_params, **params) + def test_bit2atom_mapping(self) -> None: + """Test that the mapping from bits to atom weights works as intended. + + Notes + ----- + lower n_bit values, e.g. 2048, will lead to a bit clash during folding, + for the test smiles "NCCOCCCC(=O)O". + We want no folding clashes in this test to check the correct length + of the bit-to-atom mapping. + """ + n_bits = 2100 + sparse_morgan = MolToMorganFP(radius=2, n_bits=n_bits, return_as="sparse") + dense_morgan = MolToMorganFP(radius=2, n_bits=n_bits, return_as="dense") + explicit_bit_vect_morgan = MolToMorganFP( + radius=2, n_bits=n_bits, return_as="explicit_bit_vect" + ) + + smi2mol = SmilesToMol() + for test_smi in test_smiles: + for fp_gen in [sparse_morgan, dense_morgan, explicit_bit_vect_morgan]: + for counted in [False, True]: + mol = smi2mol.transform([test_smi])[0] + fp_gen.set_params(counted=counted) + fp = fp_gen.transform([mol]) + mapping = fp_gen.bit2atom_mapping(mol) + np_fp = fingerprints_to_numpy(fp) + self.assertEqual(np.nonzero(np_fp)[0].shape[0], len(mapping)) # type: ignore + if __name__ == "__main__": unittest.main() diff --git a/tests/utils/fingerprints.py b/tests/utils/fingerprints.py index 5973d004..1ca392a4 100644 --- a/tests/utils/fingerprints.py +++ b/tests/utils/fingerprints.py @@ -8,7 +8,7 @@ # pylint: disable=no-name-in-module from rdkit.Chem import rdFingerprintGenerator as rdkit_fp -from rdkit.DataStructs import ExplicitBitVect +from rdkit.DataStructs import ExplicitBitVect, UIntSparseIntVect from scipy import sparse @@ -59,6 +59,8 @@ def fingerprints_to_numpy( """ if all(isinstance(fp, ExplicitBitVect) for fp in fingerprints): return np.array(fingerprints) + if all(isinstance(fp, UIntSparseIntVect) for fp in fingerprints): + return np.array([fp.ToList() for fp in fingerprints]) if isinstance(fingerprints, sparse.csr_matrix): return fingerprints.toarray() if isinstance(fingerprints, np.ndarray):