diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 05f9b0e4..8c3993f0 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -25,11 +25,10 @@ class GraphTransformer(nn.Module): conditional information, since they condition the output). The graph features are projected to virtual nodes (one per graph), which are fully connected. - The per node outputs are the concatenation of the final (post graph-convolution) node embeddings - and of the final virtual node embedding of the graph each node corresponds to. + The per node outputs are the final (post graph-convolution) node embeddings. The per graph outputs are the concatenation of a global mean pooling operation, of the final - virtual node embeddings, and of the conditional information embedding. + node embeddings, and of the final virtual node embeddings. """ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre"): @@ -134,8 +133,8 @@ def forward(self, g: gd.Batch, cond: torch.Tensor): o = o + l_h * scale + shift o = o + ff(norm2(o, aug_batch)) - glob = torch.cat([gnn.global_mean_pool(o[: -c.shape[0]], g.batch), o[-c.shape[0] :]], 1) - o_final = torch.cat([o[: -c.shape[0]]], 1) + o_final = o[: -c.shape[0]] + glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1) return o_final, glob