diff --git a/jraph/examples/hamiltonian_graph_network.py b/jraph/examples/hamiltonian_graph_network.py index 943165c..65da383 100644 --- a/jraph/examples/hamiltonian_graph_network.py +++ b/jraph/examples/hamiltonian_graph_network.py @@ -199,7 +199,8 @@ def set_system_state( position: np.ndarray, momentum: np.ndarray) -> jraph.GraphsTuple: """Sets the non-static parameters of the graph (momentum, position).""" - nodes = static_graph.nodes.copy(position=position, momentum=momentum) + nodes = static_graph.nodes.set("position", position) + nodes = nodes.set("momentum", momentum) return static_graph._replace(nodes=nodes)