Skip to content

Commit

Permalink
explainability: use all atoms instead of heavy atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 20, 2024
1 parent a5effe4 commit ef338d3
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _add_gaussians_for_atoms(
ValueGrid
ValueGrid object with added functions.
"""
for i in range(mol.GetNumHeavyAtoms()):
for i in range(mol.GetNumAtoms()):
if atom_weights[i] == 0:
continue
pos = conf.GetAtomPosition(i)
Expand Down Expand Up @@ -233,7 +233,7 @@ def make_sum_of_gaussians_grid(
"""
# assign default values and convert to numpy array
if atom_weights is None:
atom_weights = np.zeros(mol.GetNumHeavyAtoms())
atom_weights = np.zeros(mol.GetNumAtoms())
elif not isinstance(atom_weights, np.ndarray):
atom_weights = np.array(atom_weights)

Expand All @@ -243,10 +243,8 @@ def make_sum_of_gaussians_grid(
bond_weights = np.array(bond_weights)

# validate input
if not len(atom_weights) == mol.GetNumHeavyAtoms():
raise ValueError(
"len(atom_weights) is not equal to number of heavy atoms in mol"
)
if not len(atom_weights) == mol.GetNumAtoms():
raise ValueError("len(atom_weights) is not equal to number of atoms in mol")

if not len(bond_weights) == len(mol.GetBonds()):
raise ValueError("len(bond_weights) is not equal to number of bonds in mol")
Expand Down

0 comments on commit ef338d3

Please sign in to comment.