Skip to content

Commit

Permalink
parameter validation and private classes
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 19, 2024
1 parent 3f500d0 commit 9c4d578
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions ramannoodle/polarizability/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# pylint: disable=not-callable


class GaussianFilter(torch.nn.Module):
class _GaussianFilter(torch.nn.Module):
"""Gaussian filter.
Parameters should be chosen such that all expected inputs are between lower and
Expand Down Expand Up @@ -75,7 +75,7 @@ def forward(self, x: Tensor) -> Tensor:
return torch.exp(self.coefficient * x.pow(2))


class NodeBlock(torch.nn.Module):
class _NodeBlock(torch.nn.Module):
"""Edge to node message passer.
Architecture and notation is based on equation (5) in https://doi.org/10.1038/
Expand Down Expand Up @@ -144,7 +144,7 @@ def forward(
return typing.cast(Tensor, (node_embedding + c1_emb).tanh())


class EdgeBlock(torch.nn.Module):
class _EdgeBlock(torch.nn.Module):
"""Node to edge message passer.
Architecture and notation is based on equation (6) in https://doi.org/10.1038/
Expand Down Expand Up @@ -442,6 +442,21 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes
gaussian_filter_end: float,
):
super().__init__()
# Validate parameters
if cutoff <= 0:
raise ValueError(f"invalid cutoff: {cutoff} <= 0")
if size_node_embedding <= 0:
raise ValueError(f"invalid size_node_embedding: {size_node_embedding} <= 0")
if size_edge_embedding <= 0:
raise ValueError(f"invalid size_edge_embedding: {size_edge_embedding} <= 0")
if num_message_passes <= 0:
raise ValueError(f"invalid num_message_passes: {num_message_passes} <= 0")
if gaussian_filter_start <= 0:
inequality = f"{gaussian_filter_start} <= 0"
raise ValueError(f"invalid gaussian_filter_start: {inequality}")
if gaussian_filter_end <= gaussian_filter_start:
inequality = f"{gaussian_filter_end} <= gaussian_filter_start"
raise ValueError(f"invalid gaussian_filter_end: {inequality}")

self._ref_structure = ref_structure
self._cutoff = cutoff
Expand Down Expand Up @@ -472,19 +487,19 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes
ShiftedSoftplus(),
Linear(size_node_embedding, size_node_embedding),
)
self._edge_embedding = GaussianFilter(
self._edge_embedding = _GaussianFilter(
gaussian_filter_start, gaussian_filter_end, size_edge_embedding
)

self._node_blocks = ModuleList(
[
NodeBlock(size_node_embedding, size_edge_embedding)
_NodeBlock(size_node_embedding, size_edge_embedding)
for _ in range(num_message_passes)
]
)
self._edge_blocks = ModuleList(
[
EdgeBlock(size_node_embedding, size_edge_embedding)
_EdgeBlock(size_node_embedding, size_edge_embedding)
for _ in range(num_message_passes)
]
)
Expand Down

0 comments on commit 9c4d578

Please sign in to comment.