Commit 7ca54c5
fixed sequence setup to be masked & p(x) computable
bengioe committed Jan 10, 2024
1 parent d1982ec commit 7ca54c5
Showing 4 changed files with 239 additions and 151 deletions.
23 changes: 18 additions & 5 deletions src/gflownet/envs/seq_building_env.py
@@ -16,8 +16,8 @@

 # For typing's sake, we'll pretend that a sequence is a graph.
 class Seq(Graph):
-    def __init__(self):
-        self.seq: list[Any] = []
+    def __init__(self, seq=None):
+        self.seq: list[Any] = [] if seq is None else seq

     def __repr__(self):
         return "".join(map(str, self.seq))
@@ -58,7 +58,8 @@ def reverse(self, g: Graph, ga: GraphAction):


 class SeqBatch:
-    def __init__(self, seqs: List[torch.Tensor], pad: int):
+    def __init__(self, seqs: List[torch.Tensor], pad: int, max_len: int = 2048):
+        self.max_len = max_len + 1  # +1 for BOS
         self.seqs = seqs
         self.x = pad_sequence(seqs, batch_first=False, padding_value=pad)
         self.mask = self.x.eq(pad).T
@@ -69,6 +70,14 @@ def __init__(self, seqs: List[torch.Tensor], pad: int):
         # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this
         # is the total number of timesteps.
         self.num_graphs = self.lens.sum().item()
+        self.batch_stop_mask = torch.ones_like(self.logit_idx)[:, None]
+        self.batch_append_mask = (
+            torch.ones_like(self.logit_idx)
+            if self.lens.max() < self.max_len
+            else (self.logit_idx % self.max_len).lt(self.max_len - 1)
+        )[:, None].float()
+        self.tail_stop_mask = torch.ones((len(seqs), 1))
+        self.tail_append_mask = (self.lens[:, None] < self.max_len).float()

     def to(self, device):
         for name in dir(self):
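
For intuition: the tail_* masks gate actions at the final state of each sequence, while the batch_* masks gate every timestep of every trajectory (indexed by logit_idx). A toy sketch of the tail masks, with made-up lengths and cap:

    import torch

    lens = torch.tensor([2, 3])   # per-sequence lengths, BOS included (as in SeqBatch)
    max_len = 3                   # hypothetical cap (SeqBatch stores max_len + 1 for BOS)
    tail_stop_mask = torch.ones((2, 1))                   # stopping is always allowed
    tail_append_mask = (lens[:, None] < max_len).float()  # appending only below the cap
    print(tail_append_mask)  # tensor([[1.], [0.]]) -- the full sequence can only stop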
@@ -84,7 +93,7 @@ class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext):
     This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion.
     """

-    def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
+    def __init__(self, alphabet: Sequence[str], num_cond_dim=0, max_len=None):
         self.alphabet = alphabet
         self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]

@@ -93,6 +102,7 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
         self.pad_token = len(alphabet) + 1
         self.num_actions = len(alphabet) + 1  # Alphabet + Stop
         self.num_cond_dim = num_cond_dim
+        self.max_len = max_len

     def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
         # Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0
@@ -120,7 +130,7 @@ def graph_to_Data(self, g: Graph):
         return torch.tensor([self.bos_token] + s.seq, dtype=torch.long)

     def collate(self, graphs: List[Data]):
-        return SeqBatch(graphs, pad=self.pad_token)
+        return SeqBatch(graphs, pad=self.pad_token, max_len=self.max_len)

     def is_sane(self, g: Graph) -> bool:
         return True
@@ -131,3 +141,6 @@ def graph_to_mol(self, g: Graph):

     def object_to_log_repr(self, g: Graph):
         return self.graph_to_mol(g)
+
+    def mol_to_graph(self, mol) -> Graph:
+        return mol
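
A hedged end-to-end sketch of the new max_len plumbing (the alphabet, max_len value, and token ids below are illustrative). Note that max_len defaults to None here and collate forwards it to SeqBatch, whose max_len + 1 arithmetic requires an int, so in practice a concrete cap should be passed:

    ctx = AutoregressiveSeqBuildingContext(alphabet=["A", "B", "C"], num_cond_dim=0, max_len=10)
    seqs = [ctx.graph_to_Data(Seq([1, 2])), ctx.graph_to_Data(Seq([3]))]
    batch = ctx.collate(seqs)      # builds SeqBatch(pad=ctx.pad_token, max_len=10)
    print(batch.tail_append_mask)  # 1.0 rows mark sequences that may still grow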
13 changes: 12 additions & 1 deletion src/gflownet/models/seq_transformer.py
@@ -8,8 +8,10 @@
 from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext
 from gflownet.envs.seq_building_env import SeqBatch
 from gflownet.models.config import SeqPosEnc
+from gflownet.models.graph_transformer import mlp
+from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType


 class MLPWithDropout(nn.Module):
     def __init__(self, in_dim, out_dim, hidden_layers, dropout_prob, init_drop=False):
         super(MLPWithDropout, self).__init__()
@@ -62,7 +64,7 @@ def __init__(
         self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid)
         encoder_layers = nn.TransformerEncoderLayer(num_hid, mc.seq_transformer.num_heads, num_hid, dropout=mc.dropout)
         self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers)
-        self.logZ = nn.Linear(env_ctx.num_cond_dim, 1)
+        self.logZ = mlp(env_ctx.num_cond_dim, num_hid, 1, 2)  # nn.Linear(env_ctx.num_cond_dim, 1)
         if self.use_cond:
             self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout)
             self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid)
@@ -109,6 +111,7 @@ def forward(self, xs: SeqBatch, cond, batched=False):
             state_preds = out[xs.logit_idx, 0:ns]  # (proper_time, num_state_out)
             stop_logits = out[xs.logit_idx, ns : ns + 1]  # (proper_time, 1)
             add_node_logits = out[xs.logit_idx, ns + 1 :]  # (proper_time, nout - 1)
+            masks = [xs.batch_stop_mask, xs.batch_append_mask]
             # `time` above is really max_time, whereas proper_time = sum(len(traj) for traj in xs),
             # which is what we need to give to GraphActionCategorical
         else:
@@ -119,18 +122,26 @@ def forward(self, xs: SeqBatch, cond, batched=False):
             state_preds = out[:, 0:ns]
             stop_logits = out[:, ns : ns + 1]
             add_node_logits = out[:, ns + 1 :]
+            masks = [xs.tail_stop_mask, xs.tail_append_mask]

+        stop_logits = self._mask(stop_logits, masks[0])
+        add_node_logits = self._mask(add_node_logits, masks[1])
+
         return (
             GraphActionCategorical(
                 xs,
                 logits=[stop_logits, add_node_logits],
                 keys=[None, None],
                 types=self.ctx.action_type_order,
+                masks=masks,
                 slice_dict={},
             ),
             state_preds,
         )

+    def _mask(self, logits, mask):
+        return logits * mask + (1 - mask) * -1e6
+

 def generate_square_subsequent_mask(sz: int):
     """Generates an upper-triangular matrix of -inf, with zeros on diag."""
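The body of generate_square_subsequent_mask is collapsed in this view; a standard PyTorch causal mask matching that docstring would look like the sketch below (an assumption, not necessarily the file's exact code):

    import torch

    def generate_square_subsequent_mask(sz: int):
        """Generates an upper-triangular matrix of -inf, with zeros on diag."""
        return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)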