Skip to content

Commit

Permalink
missing layer args
Browse files Browse the repository at this point in the history
  • Loading branch information
gerkone committed Jul 27, 2023
1 parent 217915b commit 203d203
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions segnn_jax/segnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def _decoder(st_graph: SteerableGraphsTuple):
nodes = st_graph.graph.nodes
# pre pool block
for i in range(blocks):
nodes = O3TensorProductGate(latent_irreps, name=f"prepool_{i}")(
nodes, st_graph.node_attributes
)
nodes = O3TensorProductGate(
latent_irreps, name=f"prepool_{i}", o3_layer=O3Layer
)(nodes, st_graph.node_attributes)

if task == "node":
nodes = O3Layer(output_irreps, name="output")(
Expand Down Expand Up @@ -185,9 +185,9 @@ def _message(
msg = e3nn.concatenate([msg, additional_message_features], axis=-1)
# message mlp (phi_m in the paper) steered by edge attributeibutes
for i in range(self._blocks):
msg = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")(
msg, edge_attribute
)
msg = O3TensorProductGate(
self._output_irreps, name=f"tp_{i}", o3_layer=self._O3Layer
)(msg, edge_attribute)
# NOTE: original implementation only applied batch norm to messages
if self._norm == "batch":
msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg)
Expand All @@ -207,9 +207,9 @@ def _update(
x = e3nn.concatenate([nodes, msg], axis=-1)
# update mlp (phi_f in the paper) steered by node attributeibutes
for i in range(self._blocks - 1):
x = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")(
x, node_attribute
)
x = O3TensorProductGate(
self._output_irreps, name=f"tp_{i}", o3_layer=self._O3Layer
)(x, node_attribute)
# last update layer without activation
update = self._O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")(
x, node_attribute
Expand Down

0 comments on commit 203d203

Please sign in to comment.