Skip to content

Commit

Permalink
Fix docstring in GraphTransformer (#97) (#104)
Browse files Browse the repository at this point in the history
This PR fixes the docstring in GraphTransformer as mentioned in #97.

* Fix docstring in GraphTransformer (#97)

* trim trailing whitespace
  • Loading branch information
hohyun312 authored Aug 22, 2023
1 parent 86eccea commit 4babde7
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/gflownet/models/graph_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ class GraphTransformer(nn.Module):
conditional information, since they condition the output). The graph features are projected to
virtual nodes (one per graph), which are fully connected.
The per node outputs are the concatenation of the final (post graph-convolution) node embeddings
and of the final virtual node embedding of the graph each node corresponds to.
The per node outputs are the final (post graph-convolution) node embeddings.
The per graph outputs are the concatenation of a global mean pooling operation, of the final
virtual node embeddings, and of the conditional information embedding.
node embeddings, and of the final virtual node embeddings.
"""

def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre"):
Expand Down Expand Up @@ -134,8 +133,8 @@ def forward(self, g: gd.Batch, cond: torch.Tensor):
o = o + l_h * scale + shift
o = o + ff(norm2(o, aug_batch))

glob = torch.cat([gnn.global_mean_pool(o[: -c.shape[0]], g.batch), o[-c.shape[0] :]], 1)
o_final = torch.cat([o[: -c.shape[0]]], 1)
o_final = o[: -c.shape[0]]
glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1)
return o_final, glob


Expand Down

0 comments on commit 4babde7

Please sign in to comment.