diff --git a/jraph/examples/basic.py b/jraph/examples/basic.py
index 0fa4741..66f335c 100644
--- a/jraph/examples/basic.py
+++ b/jraph/examples/basic.py
@@ -57,10 +57,10 @@ def run():
       senders=np.array([0, 1]), receivers=np.array([2, 2]))
   logging.info("Nested graph %r", nested_graph)
 
-  # Creates a GraphsTuple from scratch containing a 2 graphs using an implicit
+  # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
   # batch dimension.
   # The first graph has 3 nodes and 2 edges.
-  # The second graph has 1 nodes and 1 edges.
+  # The second graph has 1 node and 1 edge.
   # Each node has a 4-dimensional feature vector.
   # Each edge has a 5-dimensional feature vector.
   # The graph itself has a 6-dimensional feature vector.
@@ -93,7 +93,7 @@ def run():
   # Creates a padded GraphsTuple from an existing GraphsTuple.
   # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
   # Three graphs are added for the padding.
-  # First an dummy graph which contains the padding nodes and edges and secondly
+  # First a dummy graph which contains the padding nodes and edges and secondly
   # two empty graphs without nodes or edges to pad out the graphs.
   padded_graph = jraph.pad_with_graphs(
       single_graph, n_node=10, n_edge=5, n_graph=4)
@@ -104,7 +104,7 @@ def run():
   single_graph = jraph.unpad_with_graphs(padded_graph)
   logging.info("Unpadded graph %r", single_graph)
 
-  # Creates a GraphsTuple containing a 2 graphs using an explicit batch
+  # Creates a GraphsTuple containing 2 graphs using an explicit batch
   # dimension.
   # An explicit batch dimension requires more memory, but can simplify
   # the definition of functions operating on the graph.
@@ -113,7 +113,7 @@ def run():
   # Using an explicit batch requires padding all feature vectors to
   # the maximum size of nodes and edges.
   # The first graph has 3 nodes and 2 edges.
-  # The second graph has 1 nodes and 1 edges.
+  # The second graph has 1 node and 1 edge.
   # Each node has a 4-dimensional feature vector.
   # Each edge has a 5-dimensional feature vector.
   # The graph itself has a 6-dimensional feature vector.
@@ -125,7 +125,7 @@ def run():
       receivers=np.array([[2, 2], [0, -1]]))
   logging.info("Explicitly batched graph %r", explicitly_batched_graph)
 
-  # Running a graph propagation steps.
+  # Running a graph propagation step.
   # First define the update functions for the edges, nodes and globals.
   # In this example we use the identity everywhere.
   # For Graph neural networks, each update function is typically a neural
@@ -156,6 +156,7 @@ def update_globals_fn(
       aggregated_node_features,
       aggregated_edge_features,
       globals_):
+    """Returns the global features."""
     del aggregated_node_features
     del aggregated_edge_features
     return globals_
@@ -166,8 +167,8 @@ def update_globals_fn(
   aggregate_nodes_for_globals_fn = jraph.segment_sum
   aggregate_edges_for_globals_fn = jraph.segment_sum
 
-  # Optionally define attention logit function and attention reduce function.
-  # This can be used for graph attention.
+  # Optionally define an attention logit function and an attention reduce
+  # function. This can be used for graph attention.
   # The attention function calculates attention weights, and the apply
   # attention function calculates the new edge feature given the weights.
   # We don't use graph attention here, and just pass the defaults.
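
For reference, the padding and unpadding calls touched by the hunks above come straight from basic.py; a minimal, self-contained sketch of that usage looks like the following, where the concrete shapes simply mirror the sizes the example's comments describe (3 nodes, 2 edges, 4-/5-/6-dimensional node/edge/global features) and are otherwise illustrative:

```python
# Minimal sketch of the jraph padding utilities referenced in the diff.
import jraph
import numpy as np

# A single graph with 3 nodes and 2 edges; each node has a 4-dimensional
# feature, each edge a 5-dimensional feature, and the graph a 6-dimensional
# global feature.
single_graph = jraph.GraphsTuple(
    n_node=np.array([3]), n_edge=np.array([2]),
    nodes=np.ones((3, 4)), edges=np.ones((2, 5)), globals=np.ones((1, 6)),
    senders=np.array([0, 1]), receivers=np.array([2, 2]))

# Pad to 10 nodes, 5 edges and 4 graphs: one dummy graph absorbs the extra
# nodes and edges, and empty graphs make up the remaining graph count.
padded_graph = jraph.pad_with_graphs(
    single_graph, n_node=10, n_edge=5, n_graph=4)

# Strip the padding again to recover the original single graph.
single_graph = jraph.unpad_with_graphs(padded_graph)
```

The usual reason for padding to fixed n_node/n_edge/n_graph totals is to keep array shapes static across batches, which avoids recompilation when the surrounding computation is jit-compiled with JAX.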