Skip to content

Commit

Permalink
working LR sequence generation
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Aug 23, 2023
1 parent 4babde7 commit 4eb3e2f
Show file tree
Hide file tree
Showing 10 changed files with 414 additions and 25 deletions.
19 changes: 14 additions & 5 deletions src/gflownet/algo/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
self.reward_normalize_losses = False
self.sample_temp = 1
self.bootstrap_own_reward = self.cfg.bootstrap_own_reward
self.model_is_autoregressive = False

self.graph_sampler = GraphSampler(
ctx,
Expand Down Expand Up @@ -243,10 +244,15 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
if self.model_is_autoregressive:
torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0]) for tj in trajs]
actions = [self.ctx.GraphAction_to_aidx(g, i[1]) for g, tj in zip(torch_graphs, trajs) for i in tj["traj"]]
else:
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
actions = [
self.ctx.GraphAction_to_aidx(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
batch = self.ctx.collate(torch_graphs)
batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
batch.log_p_B = torch.cat([i["bck_logprobs"] for i in trajs], 0)
Expand Down Expand Up @@ -325,7 +331,10 @@ def compute_batch_losses(
if self.cfg.do_parameterize_p_b:
fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx])
else:
fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])
if self.model_is_autoregressive:
fwd_cat, per_graph_out = model(batch, cond_info, batched=True)
else:
fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])
# Retreive the reward predictions for the full graphs,
# i.e. the final graph of each trajectory
log_reward_preds = per_graph_out[final_graph_idx, 0]
Expand Down
13 changes: 3 additions & 10 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def __init__(
self.hindsight_ratio = hindsight_ratio
self.train_it = init_train_iter
self.do_validate_batch = False # Turn this on for debugging
self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag

# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
# then "offline" now refers to cond info and online to x, so no duplication and we don't end
Expand Down Expand Up @@ -232,9 +231,6 @@ def __iter__(self):
# Override the is_valid key in case the task made some mols invalid
for i in range(num_online):
trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item()
if self.log_molecule_smis:
for i, m in zip(valid_idcs, valid_mols):
trajs[i]["smi"] = Chem.MolToSmiles(m)

# Compute scalar rewards from conditional information & flat rewards
flat_rewards = torch.stack(flat_rewards)
Expand Down Expand Up @@ -355,13 +351,10 @@ def validate_batch(self, batch, trajs):
raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep])

def log_generated(self, trajs, rewards, flat_rewards, cond_info):
if self.log_molecule_smis:
mols = [
Chem.MolToSmiles(self.ctx.graph_to_mol(trajs[i]["result"])) if trajs[i]["is_valid"] else ""
for i in range(len(trajs))
]
if hasattr(self.ctx, "object_to_log_repr"):
mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs]
else:
mols = [nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(t["result"], None, "v") for t in trajs]
mols = [""] * len(trajs)

flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist()
rewards = rewards.data.numpy().tolist()
Expand Down
4 changes: 4 additions & 0 deletions src/gflownet/envs/frag_mol_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,7 @@ def is_sane(self, g: Graph) -> bool:
if mol is None:
return False
return True

def object_to_log_repr(self, g: Graph):
"""Convert a Graph to a string representation"""
return Chem.MolToSmiles(self.graph_to_mol(g))
22 changes: 13 additions & 9 deletions src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -423,10 +423,11 @@ def __init__(
self,
graphs: gd.Batch,
logits: List[torch.Tensor],
keys: List[str],
keys: List[Union[str, None]],
types: List[GraphActionType],
deduplicate_edge_index=True,
masks: List[torch.Tensor] = None,
slice_dict: Optional[dict[str, torch.Tensor]] = None,
):
"""A multi-type Categorical compatible with generating structured actions.
Expand Down Expand Up @@ -461,13 +462,16 @@ def __init__(
be graph-level (i.e. if there are `k` graphs in the Batch
object, this logit tensor would have shape `(k, m)`)
types: List[GraphActionType]
The action type each logit corresponds to.
The action type each logit corresponds to.
deduplicate_edge_index: bool, default=True
If true, this means that the 'edge_index' keys have been reduced
by e_i[::2] (presumably because the graphs are undirected)
If true, this means that the 'edge_index' keys have been reduced
by e_i[::2] (presumably because the graphs are undirected)
masks: List[Tensor], default=None
If not None, a list of broadcastable tensors that multiplicatively
mask out logits of invalid actions
If not None, a list of broadcastable tensors that multiplicatively
mask out logits of invalid actions
slice_dist: Optional[dict[str, Tensor]], default=None
If not None, a map of tensors that indicate the start (and end) the graph index
of each object keyed. If None, uses the `_slice_dict` attribute of the graphs.
"""
self.num_graphs = graphs.num_graphs
assert all([i.ndim == 2 for i in logits])
Expand Down Expand Up @@ -504,9 +508,9 @@ def __init__(
for k in keys
]
# This is the cumulative sum (prefixed by 0) of N[i]s
slice_dict = graphs._slice_dict if slice_dict is None else slice_dict
self.slice = [
graphs._slice_dict[k].to(dev) if k is not None else torch.arange(graphs.num_graphs + 1, device=dev)
for k in keys
slice_dict[k].to(dev) if k is not None else torch.arange(graphs.num_graphs + 1, device=dev) for k in keys
]
self.logprobs = None

Expand Down
4 changes: 4 additions & 0 deletions src/gflownet/envs/mol_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,7 @@ def is_sane(self, g: Graph) -> bool:
if mol is None:
return False
return True

def object_to_log_repr(self, g: Graph):
"""Convert a Graph to a string representation"""
return Chem.MolToSmiles(self.graph_to_mol(g))
120 changes: 120 additions & 0 deletions src/gflownet/envs/seq_building_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Any, List, Tuple
from copy import deepcopy

import torch
import torch_geometric.data as gd
from torch_geometric.data import Data
from torch.nn.utils.rnn import pad_sequence

from gflownet.envs.graph_building_env import Graph, GraphAction
from .graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnv, GraphBuildingEnvContext


# For typing's sake, we'll pretend that a sequence is a graph.
class Seq(Graph):
def __init__(self):
self.seq: list[Any] = []

def __repr__(self):
return "".join(self.seq)

@property
def nodes(self):
return self.seq


class SeqBuildingEnv(GraphBuildingEnv):
"""This class masquerades as a GraphBuildingEnv, but actually generates sequences of tokens."""

def __init__(self, variant):
super().__init__()

def new(self):
return Seq()

def step(self, g: Graph, a: GraphAction):
s: Seq = deepcopy(g) # type: ignore
if a.action == GraphActionType.AddNode:
s.seq.append(a.value)
return s

def count_backward_transitions(self, g: Graph, check_idempotent: bool = False):
return 1

def parents(self, g: Graph):
s: Seq = deepcopy(g) # type: ignore
if not len(s.seq):
return []
v = s.seq.pop()
return [(GraphAction(GraphActionType.AddNode, value=v), s)]

def reverse(self, g: Graph, ga: GraphAction):
# TODO: if we implement non-LR variants we'll need to do something here
return GraphAction(GraphActionType.Stop)


class SeqBatch:
def __init__(self, seqs: List[torch.Tensor], pad: int):
self.seqs = seqs
self.x = pad_sequence(seqs, batch_first=False, padding_value=pad)
self.mask = self.x.eq(pad).T
self.lens = torch.tensor([len(i) for i in seqs], dtype=torch.long)
self.logit_idx = self.x.ne(pad).flatten().nonzero().flatten()
self.num_graphs = self.lens.sum().item()
# self.batch = torch.tensor([[i] * len(s) for i, s in enumerate(seqs)])

def to(self, device):
for name in dir(self):
x = getattr(self, name)
if isinstance(x, torch.Tensor):
setattr(self, name, x.to(device))
return self


class GenericSeqBuildingContext(GraphBuildingEnvContext):
def __init__(self, alphabet, num_cond_dim=0):
self.alphabet = alphabet
self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]

self.num_tokens = len(alphabet) + 2 # Alphabet + BOS + PAD
self.bos_token = len(alphabet)
self.pad_token = len(alphabet) + 1
self.num_outputs = len(alphabet) + 1 # Alphabet + Stop
self.num_cond_dim = num_cond_dim

def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
act_type, _, act_col = [int(i) for i in action_idx]
t = self.action_type_order[act_type]
if t is GraphActionType.Stop:
return GraphAction(t)
elif t is GraphActionType.AddNode:
return GraphAction(t, value=act_col)
raise ValueError(action_idx)

def GraphAction_to_aidx(self, g: Data, action: GraphAction) -> Tuple[int, int, int]:
if action.action is GraphActionType.Stop:
col = 0
type_idx = self.action_type_order.index(action.action)
elif action.action is GraphActionType.AddNode:
col = action.value
type_idx = self.action_type_order.index(action.action)
else:
raise ValueError(action)
return (type_idx, 0, int(col))

def graph_to_Data(self, g: Graph):
s: Seq = g # type: ignore
return torch.tensor([self.bos_token] + s.seq, dtype=torch.long)

def collate(self, graphs: List[Data]):
return SeqBatch(graphs, pad=self.pad_token)

def is_sane(self, g: Graph) -> bool:
return True

def graph_to_mol(self, g: Graph):
s: Seq = g # type: ignore
return "".join(self.alphabet[i] for i in s.seq)

def object_to_log_repr(self, g: Graph):
return self.graph_to_mol(g)
1 change: 1 addition & 0 deletions src/gflownet/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ class ModelConfig:

num_layers: int = 3
num_emb: int = 128
dropout: float = 0
graph_transformer: GraphTransformerConfig = GraphTransformerConfig()
120 changes: 120 additions & 0 deletions src/gflownet/models/seq_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext
from gflownet.envs.seq_building_env import SeqBatch
from gflownet.config import Config


class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_layers, dropout_prob, init_drop=False):
super(MLP, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
layers = [nn.Linear(in_dim, hidden_layers[0]), nn.ReLU()]
layers += [nn.Dropout(dropout_prob)] if init_drop else []
for i in range(1, len(hidden_layers)):
layers.extend([nn.Linear(hidden_layers[i - 1], hidden_layers[i]), nn.ReLU(), nn.Dropout(dropout_prob)])
layers.append(nn.Linear(hidden_layers[-1], out_dim))
self.model = nn.Sequential(*layers)

def forward(self, x, with_uncertainty=False):
return self.model(x)


class SeqTransformerGFN(nn.Module):
ctx: GraphBuildingEnvContext

def __init__(
self,
env_ctx,
cfg: Config,
num_state_out=1,
):
super().__init__()
# num_hid, cond_dim, max_len, vocab_size, num_actions, dropout, num_layers, num_head, use_cond, **kwargs
self.ctx = env_ctx
num_hid = cfg.model.num_emb
num_outs = env_ctx.num_outputs + num_state_out
mc = cfg.model
self.pos = PositionalEncoding(num_hid, dropout=cfg.model.dropout, max_len=cfg.algo.max_len + 2)
self.use_cond = env_ctx.num_cond_dim > 0
self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid)
encoder_layers = nn.TransformerEncoderLayer(
num_hid, mc.graph_transformer.num_heads, num_hid, dropout=mc.dropout
)
self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers)
self.logZ = nn.Linear(env_ctx.num_cond_dim, 1)
if self.use_cond:
self.output = MLP(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout)
self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid)
else:
self.output = MLP(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout)
self.num_hid = num_hid

def forward(self, xs: SeqBatch, cond, batched=False):
x = self.embedding(xs.x)
x = self.pos(x) # (time, batch, nemb)
x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device))
pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb)

if self.use_cond:
cond_var = self.cond_embed(cond) # (batch, nemb)
cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var
final_rep = torch.cat((x, cond_var), axis=-1) if batched else torch.cat((pooled_x, cond_var), axis=-1)
else:
final_rep = x if batched else pooled_x

out: torch.Tensor = self.output(final_rep)
if batched:
# out is (time, batch, nout)
out = out.transpose(1, 0).contiguous().reshape((-1, out.shape[2])) # (batch * time, nout)
stop_logits = out[xs.logit_idx, 0:1] # (proper_time, 1)
state_preds = out[xs.logit_idx, 1:2] # (proper_time, 1)
add_node_logits = out[xs.logit_idx, 2:] # (proper_time, nout - 1)
# `time` above is really max_time, whereas proper_time = sum(len(traj) for traj in xs))
# which is what we need to give to GraphActionCategorical
else:
# The default num_graphs is computed for the batched case, so we need to fix it here so that
# GraphActionCategorical knows how many "graphs" (sequence inputs) there are
xs.num_graphs = out.shape[0]
# out is (batch, nout)
stop_logits = out[:, 0:1]
state_preds = out[:, 1:2]
add_node_logits = out[:, 2:]

return (
GraphActionCategorical(
xs,
logits=[stop_logits, add_node_logits],
keys=[None, None],
types=self.ctx.action_type_order,
slice_dict={},
),
state_preds,
)


def generate_square_subsequent_mask(sz: int):
"""Generates an upper-triangular matrix of -inf, with zeros on diag."""
return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)


class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)

def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
Loading

0 comments on commit 4eb3e2f

Please sign in to comment.