Skip to content

Commit

Permalink
properly account for num_state_out
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Aug 24, 2023
1 parent f902eea commit 39981e2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,7 @@ def mol_to_graph(self, mol: Mol) -> Graph:
The corresponding Graph representation of that molecule.
"""
raise NotImplementedError()

def object_to_log_repr(self, g: Graph) -> str:
"""Convert a Graph to a string representation for logging purposes"""
return ""
14 changes: 8 additions & 6 deletions src/gflownet/models/seq_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
):
super().__init__()
self.ctx = env_ctx
self.num_state_out = num_state_out
num_hid = cfg.model.num_emb
num_outs = env_ctx.num_actions + num_state_out
mc = cfg.model
Expand Down Expand Up @@ -82,24 +83,25 @@ def forward(self, xs: SeqBatch, cond, batched=False):
final_rep = x if batched else pooled_x

out: torch.Tensor = self.output(final_rep)
ns = self.num_state_out
if batched:
# out is (time, batch, nout)
out = out.transpose(1, 0).contiguous().reshape((-1, out.shape[2])) # (batch * time, nout)
# logit_idx tells us where (in the flattened array of outputs) the non-masked outputs are.
# E.g. if the batch is [["ABC", "VWXYZ"]], logit_idx would be [0, 1, 2, 5, 6, 7, 8, 9]
stop_logits = out[xs.logit_idx, 0:1] # (proper_time, 1)
state_preds = out[xs.logit_idx, 1:2] # (proper_time, 1)
add_node_logits = out[xs.logit_idx, 2:] # (proper_time, nout - 1)
state_preds = out[xs.logit_idx, 0:ns] # (proper_time, num_state_out)
stop_logits = out[xs.logit_idx, ns : ns + 1] # (proper_time, 1)
add_node_logits = out[xs.logit_idx, ns + 1 :] # (proper_time, nout - 1)
# `time` above is really max_time, whereas proper_time = sum(len(traj) for traj in xs))
# which is what we need to give to GraphActionCategorical
else:
# The default num_graphs is computed for the batched case, so we need to fix it here so that
# GraphActionCategorical knows how many "graphs" (sequence inputs) there are
xs.num_graphs = out.shape[0]
# out is (batch, nout)
stop_logits = out[:, 0:1]
state_preds = out[:, 1:2]
add_node_logits = out[:, 2:]
state_preds = out[:, 0:ns]
stop_logits = out[:, ns : ns + 1]
add_node_logits = out[:, ns + 1 :]

return (
GraphActionCategorical(
Expand Down

0 comments on commit 39981e2

Please sign in to comment.