From ef338d337d27b858cddf21e25f70e795cde5d79a Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Wed, 20 Nov 2024 14:09:57 +0100 Subject: [PATCH] explainability: use all atoms instead of heavy atoms --- .../explainability/visualization/visualization.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/molpipeline/explainability/visualization/visualization.py b/molpipeline/explainability/visualization/visualization.py index b4bd4f15..76a6ca20 100644 --- a/molpipeline/explainability/visualization/visualization.py +++ b/molpipeline/explainability/visualization/visualization.py @@ -113,7 +113,7 @@ def _add_gaussians_for_atoms( ValueGrid ValueGrid object with added functions. """ - for i in range(mol.GetNumHeavyAtoms()): + for i in range(mol.GetNumAtoms()): if atom_weights[i] == 0: continue pos = conf.GetAtomPosition(i) @@ -233,7 +233,7 @@ def make_sum_of_gaussians_grid( """ # assign default values and convert to numpy array if atom_weights is None: - atom_weights = np.zeros(mol.GetNumHeavyAtoms()) + atom_weights = np.zeros(mol.GetNumAtoms()) elif not isinstance(atom_weights, np.ndarray): atom_weights = np.array(atom_weights) @@ -243,10 +243,8 @@ def make_sum_of_gaussians_grid( bond_weights = np.array(bond_weights) # validate input - if not len(atom_weights) == mol.GetNumHeavyAtoms(): - raise ValueError( - "len(atom_weights) is not equal to number of heavy atoms in mol" - ) + if not len(atom_weights) == mol.GetNumAtoms(): + raise ValueError("len(atom_weights) is not equal to number of atoms in mol") if not len(bond_weights) == len(mol.GetBonds()): raise ValueError("len(bond_weights) is not equal to number of bonds in mol")