Skip to content

Commit

Permalink
more doc and such
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Aug 23, 2023
1 parent 6c477e7 commit 106efc2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/gflownet/algo/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def __init__(
self.reward_normalize_losses = False
self.sample_temp = 1
self.bootstrap_own_reward = self.cfg.bootstrap_own_reward
# When the model is autoregressive, we can avoid giving it ["A", "AB", "ABC", ...] as a sequence of inputs, and
# instead give "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using something
# like a transformer with causal self-attention.
self.model_is_autoregressive = False

self.graph_sampler = GraphSampler(
Expand Down
19 changes: 14 additions & 5 deletions src/gflownet/envs/seq_building_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, List, Tuple
from typing import Any, List, Sequence, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
Expand Down Expand Up @@ -63,9 +63,12 @@ def __init__(self, seqs: List[torch.Tensor], pad: int):
self.x = pad_sequence(seqs, batch_first=False, padding_value=pad)
self.mask = self.x.eq(pad).T
self.lens = torch.tensor([len(i) for i in seqs], dtype=torch.long)
# This tells where (in the flattened array of outputs) the non-masked outputs are.
# E.g. if the batch is [["ABC", "VWXYZ"]], logit_idx would be [0, 1, 2, 5, 6, 7, 8, 9]
self.logit_idx = self.x.ne(pad).flatten().nonzero().flatten()
# Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this
# is the total number of timesteps.
self.num_graphs = self.lens.sum().item()
# self.batch = torch.tensor([[i] * len(s) for i, s in enumerate(seqs)])

def to(self, device):
for name in dir(self):
Expand All @@ -75,8 +78,13 @@ def to(self, device):
return self


class GenericSeqBuildingContext(GraphBuildingEnvContext):
def __init__(self, alphabet, num_cond_dim=0):
class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext):
"""This class masquerades as a GraphBuildingEnvContext, but actually generates sequences of tokens.
This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion.
"""

def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
self.alphabet = alphabet
self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]

Expand All @@ -87,6 +95,7 @@ def __init__(self, alphabet, num_cond_dim=0):
self.num_cond_dim = num_cond_dim

def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
# Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0
act_type, _, act_col = [int(i) for i in action_idx]
t = self.action_type_order[act_type]
if t is GraphActionType.Stop:
Expand Down Expand Up @@ -118,7 +127,7 @@ def is_sane(self, g: Graph) -> bool:

def graph_to_mol(self, g: Graph):
s: Seq = g # type: ignore
return "".join(self.alphabet[i] for i in s.seq)
return "".join(self.alphabet[int(i)] for i in s.seq)

def object_to_log_repr(self, g: Graph):
return self.graph_to_mol(g)
30 changes: 23 additions & 7 deletions src/gflownet/models/seq_transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This code is adapted from https://github.com/MJ10/mo_gfn
import math

import torch
Expand All @@ -8,9 +9,9 @@
from gflownet.envs.seq_building_env import SeqBatch


class MLP(nn.Module):
class MLPWithDropout(nn.Module):
def __init__(self, in_dim, out_dim, hidden_layers, dropout_prob, init_drop=False):
super(MLP, self).__init__()
super(MLPWithDropout, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
layers = [nn.Linear(in_dim, hidden_layers[0]), nn.ReLU()]
Expand All @@ -20,11 +21,13 @@ def __init__(self, in_dim, out_dim, hidden_layers, dropout_prob, init_drop=False
layers.append(nn.Linear(hidden_layers[-1], out_dim))
self.model = nn.Sequential(*layers)

def forward(self, x, with_uncertainty=False):
def forward(self, x):
return self.model(x)


class SeqTransformerGFN(nn.Module):
"""A standard transformer-encoder based GFN model for sequences."""

ctx: GraphBuildingEnvContext

def __init__(
Expand All @@ -34,10 +37,9 @@ def __init__(
num_state_out=1,
):
super().__init__()
# num_hid, cond_dim, max_len, vocab_size, num_actions, dropout, num_layers, num_head, use_cond, **kwargs
self.ctx = env_ctx
num_hid = cfg.model.num_emb
num_outs = env_ctx.num_outputs + num_state_out
num_outs = env_ctx.num_actions + num_state_out
mc = cfg.model
self.pos = PositionalEncoding(num_hid, dropout=cfg.model.dropout, max_len=cfg.algo.max_len + 2)
self.use_cond = env_ctx.num_cond_dim > 0
Expand All @@ -48,13 +50,25 @@ def __init__(
self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers)
self.logZ = nn.Linear(env_ctx.num_cond_dim, 1)
if self.use_cond:
self.output = MLP(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout)
self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout)
self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid)
else:
self.output = MLP(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout)
self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout)
self.num_hid = num_hid

def forward(self, xs: SeqBatch, cond, batched=False):
"""Returns a GraphActionCategorical and a tensor of state predictions.
Parameters
----------
xs: SeqBatch
A batch of sequences.
cond: torch.Tensor
A tensor of conditional information.
batched: bool
If True, the it's assumed that the cond tensor is constant along a sequence, and the output is given
at each timestep (of the autoregressive process), which works because we are using causal self-attenion.
If False, only the last timesteps' output is returned, which one would use to sample the next token."""
x = self.embedding(xs.x)
x = self.pos(x) # (time, batch, nemb)
x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device))
Expand All @@ -71,6 +85,8 @@ def forward(self, xs: SeqBatch, cond, batched=False):
if batched:
# out is (time, batch, nout)
out = out.transpose(1, 0).contiguous().reshape((-1, out.shape[2])) # (batch * time, nout)
# logit_idx tells us where (in the flattened array of outputs) the non-masked outputs are.
# E.g. if the batch is [["ABC", "VWXYZ"]], logit_idx would be [0, 1, 2, 5, 6, 7, 8, 9]
stop_logits = out[xs.logit_idx, 0:1] # (proper_time, 1)
state_preds = out[xs.logit_idx, 1:2] # (proper_time, 1)
add_node_logits = out[xs.logit_idx, 2:] # (proper_time, nout - 1)
Expand Down
10 changes: 7 additions & 3 deletions src/gflownet/tasks/toy_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from torch import Tensor

from gflownet.config import Config
from gflownet.envs.seq_building_env import GenericSeqBuildingContext, SeqBuildingEnv
from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv
from gflownet.models.seq_transformer import SeqTransformerGFN
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
from gflownet.utils.conditioning import TemperatureConditional


class ToySeqTask(GFNTask):
"""Sets up a task where the reward is the number of times some sequences appear in the input"""
"""Sets up a task where the reward is the number of times some sequences appear in the input. Normalized to be
in [0,1]"""

def __init__(
self,
Expand Down Expand Up @@ -88,13 +89,16 @@ def setup_task(self):

def setup_env_context(self):
self.env = SeqBuildingEnv(None)
self.ctx = GenericSeqBuildingContext(
self.ctx = AutoregressiveSeqBuildingContext(
"abc",
self.task.num_cond_dim,
)

def setup_algo(self):
super().setup_algo()
# If the algo implements it, avoid giving, ["A", "AB", "ABC", ...] as a sequence of inputs, and instead give
# "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using a transformer with
# causal self-attention.
self.algo.model_is_autoregressive = True


Expand Down

0 comments on commit 106efc2

Please sign in to comment.