diff --git a/ramannoodle/polarizability/torch/gnn.py b/ramannoodle/polarizability/torch/gnn.py index 4decc52..8c5cabf 100644 --- a/ramannoodle/polarizability/torch/gnn.py +++ b/ramannoodle/polarizability/torch/gnn.py @@ -37,7 +37,7 @@ # pylint: disable=not-callable -class GaussianFilter(torch.nn.Module): +class _GaussianFilter(torch.nn.Module): """Gaussian filter. Parameters should be chosen such that all expected inputs are between lower and @@ -75,7 +75,7 @@ def forward(self, x: Tensor) -> Tensor: return torch.exp(self.coefficient * x.pow(2)) -class NodeBlock(torch.nn.Module): +class _NodeBlock(torch.nn.Module): """Edge to node message passer. Architecture and notation is based on equation (5) in https://doi.org/10.1038/ @@ -144,7 +144,7 @@ def forward( return typing.cast(Tensor, (node_embedding + c1_emb).tanh()) -class EdgeBlock(torch.nn.Module): +class _EdgeBlock(torch.nn.Module): """Node to edge message passer. Architecture and notation is based on equation (6) in https://doi.org/10.1038/ @@ -442,6 +442,21 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes gaussian_filter_end: float, ): super().__init__() + # Validate parameters + if cutoff <= 0: + raise ValueError(f"invalid cutoff: {cutoff} <= 0") + if size_node_embedding <= 0: + raise ValueError(f"invalid size_node_embedding: {size_node_embedding} <= 0") + if size_edge_embedding <= 0: + raise ValueError(f"invalid size_edge_embedding: {size_edge_embedding} <= 0") + if num_message_passes <= 0: + raise ValueError(f"invalid num_message_passes: {num_message_passes} <= 0") + if gaussian_filter_start <= 0: + inequality = f"{gaussian_filter_start} <= 0" + raise ValueError(f"invalid gaussian_filter_start: {inequality}") + if gaussian_filter_end <= gaussian_filter_start: + inequality = f"{gaussian_filter_end} <= gaussian_filter_start" + raise ValueError(f"invalid gaussian_filter_end: {inequality}") self._ref_structure = ref_structure self._cutoff = cutoff @@ -472,19 +487,19 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes ShiftedSoftplus(), Linear(size_node_embedding, size_node_embedding), ) - self._edge_embedding = GaussianFilter( + self._edge_embedding = _GaussianFilter( gaussian_filter_start, gaussian_filter_end, size_edge_embedding ) self._node_blocks = ModuleList( [ - NodeBlock(size_node_embedding, size_edge_embedding) + _NodeBlock(size_node_embedding, size_edge_embedding) for _ in range(num_message_passes) ] ) self._edge_blocks = ModuleList( [ - EdgeBlock(size_node_embedding, size_edge_embedding) + _EdgeBlock(size_node_embedding, size_edge_embedding) for _ in range(num_message_passes) ] )