Skip to content

Commit

Permalink
fix dimension of output of mean_polarizability and std_polarizability
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 19, 2024
1 parent 9c4d578 commit dbf85d5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def mean_polarizability(self) -> NDArray[np.float64]:
:
2D array with shape (3,3).
"""
return self._polarizabilities.mean(0, keepdim=True).clone().numpy()
return self._polarizabilities.mean(0).clone().numpy()

@property
def stddev_polarizability(self) -> NDArray[np.float64]:
Expand All @@ -196,7 +196,7 @@ def stddev_polarizability(self) -> NDArray[np.float64]:
:
2D array with shape (3,3).
"""
result = self._polarizabilities.std(0, unbiased=False, keepdim=True)
result = self._polarizabilities.std(0, unbiased=False)
return result.clone().numpy()

def scale_polarizabilities(
Expand Down
4 changes: 2 additions & 2 deletions ramannoodle/polarizability/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes
raise ValueError(f"invalid size_edge_embedding: {size_edge_embedding} <= 0")
if num_message_passes <= 0:
raise ValueError(f"invalid num_message_passes: {num_message_passes} <= 0")
if gaussian_filter_start <= 0:
inequality = f"{gaussian_filter_start} <= 0"
if gaussian_filter_start < 0:
inequality = f"{gaussian_filter_start} < 0"
raise ValueError(f"invalid gaussian_filter_start: {inequality}")
if gaussian_filter_end <= gaussian_filter_start:
inequality = f"{gaussian_filter_end} <= gaussian_filter_start"
Expand Down

0 comments on commit dbf85d5

Please sign in to comment.