Skip to content

Commit

Permalink
Merge pull request #28 from speckhard:patch-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 439292921
Change-Id: Ia4b86525a772dc360631c24d2969bc3edefd3c93
  • Loading branch information
jg8610 committed Apr 5, 2022
2 parents 4ec7f1e + 74ad137 commit 4b235ab
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions jraph/examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def run():
senders=np.array([0, 1]), receivers=np.array([2, 2]))
logging.info("Nested graph %r", nested_graph)

# Creates a GraphsTuple from scratch containing a 2 graphs using an implicit
# Creates a GraphsTuple from scratch containing 2 graphs using an implicit
# batch dimension.
# The first graph has 3 nodes and 2 edges.
# The second graph has 1 nodes and 1 edges.
# The second graph has 1 node and 1 edge.
# Each node has a 4-dimensional feature vector.
# Each edge has a 5-dimensional feature vector.
# The graph itself has a 6-dimensional feature vector.
Expand Down Expand Up @@ -93,7 +93,7 @@ def run():
# Creates a padded GraphsTuple from an existing GraphsTuple.
# The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
# Three graphs are added for the padding.
# First an dummy graph which contains the padding nodes and edges and secondly
# First a dummy graph which contains the padding nodes and edges and secondly
# two empty graphs without nodes or edges to pad out the graphs.
padded_graph = jraph.pad_with_graphs(
single_graph, n_node=10, n_edge=5, n_graph=4)
Expand All @@ -104,7 +104,7 @@ def run():
single_graph = jraph.unpad_with_graphs(padded_graph)
logging.info("Unpadded graph %r", single_graph)

# Creates a GraphsTuple containing a 2 graphs using an explicit batch
# Creates a GraphsTuple containing 2 graphs using an explicit batch
# dimension.
# An explicit batch dimension requires more memory, but can simplify
# the definition of functions operating on the graph.
Expand All @@ -113,7 +113,7 @@ def run():
# Using an explicit batch requires padding all feature vectors to
# the maximum size of nodes and edges.
# The first graph has 3 nodes and 2 edges.
# The second graph has 1 nodes and 1 edges.
# The second graph has 1 node and 1 edge.
# Each node has a 4-dimensional feature vector.
# Each edge has a 5-dimensional feature vector.
# The graph itself has a 6-dimensional feature vector.
Expand All @@ -125,7 +125,7 @@ def run():
receivers=np.array([[2, 2], [0, -1]]))
logging.info("Explicitly batched graph %r", explicitly_batched_graph)

# Running a graph propagation steps.
# Running a graph propagation step.
# First define the update functions for the edges, nodes and globals.
# In this example we use the identity everywhere.
# For graph neural networks, each update function is typically a neural
Expand Down Expand Up @@ -156,6 +156,7 @@ def update_globals_fn(
aggregated_node_features,
aggregated_edge_features,
globals_):
"""Returns the global features."""
del aggregated_node_features
del aggregated_edge_features
return globals_
Expand All @@ -166,8 +167,8 @@ def update_globals_fn(
aggregate_nodes_for_globals_fn = jraph.segment_sum
aggregate_edges_for_globals_fn = jraph.segment_sum

# Optionally define attention logit function and attention reduce function.
# This can be used for graph attention.
# Optionally define an attention logit function and an attention reduce
# function. This can be used for graph attention.
# The attention function calculates attention weights, and the apply
# attention function calculates the new edge feature given the weights.
# We don't use graph attention here, and just pass the defaults.
Expand Down

0 comments on commit 4b235ab

Please sign in to comment.