diff --git a/.gitignore b/.gitignore
index b6e47617..e0bd603b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+# Model cache
+src/gflownet/models/cache/
+
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py
index bd0ce3de..e40303a3 100644
--- a/src/gflownet/algo/config.py
+++ b/src/gflownet/algo/config.py
@@ -37,6 +37,7 @@ class TBConfig:
     subtb_max_len: int = 128
     Z_learning_rate: float = 1e-4
     Z_lr_decay: float = 50_000
+    cum_subtb: bool = True
 
 
 @dataclass
diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py
index 22fe655e..2915336c 100644
--- a/src/gflownet/algo/trajectory_balance.py
+++ b/src/gflownet/algo/trajectory_balance.py
@@ -22,6 +22,40 @@ from gflownet.trainer import GFNAlgorithm
 
 
+def shift_right(x: torch.Tensor, z=0):
+    "Shift x right by 1, and put z in the first position"
+    x = torch.roll(x, 1, dims=0)
+    x[0] = z
+    return x
+
+
+def cross(x: torch.Tensor):
+    """
+    Calculate $y_{ij} = \sum_{t=i}^j x_t$.
+    The lower triangular portion is the inverse of the upper triangular one.
+    """
+    assert x.ndim == 1
+    y = torch.cumsum(x, 0)
+    return y[None] - shift_right(y)[:, None]
+
+
+def subTB(v: torch.Tensor, x: torch.Tensor):
+    """
+    Compute the SubTB(1) values
+    $\forall i \leq j: D[i,j] =
+        \log \frac{F(s_i) \prod_{k=i}^{j} P_F(s_{k+1}|s_k)}
+                  {F(s_{j+1}) \prod_{k=i}^{j} P_B(s_k|s_{k+1})}$
+    for a single trajectory.
+    Note that x_k should be \log P_F(s_{k+1}|s_k) - \log P_B(s_k|s_{k+1}).
+    """
+    assert v.ndim == x.ndim == 1
+    # D[i, j] = v[i] - v[j + 1]
+    D = v[:-1, None] - v[None, 1:]
+    # cross(x)[i, j] = sum(x[i : j + 1])
+    D = D + cross(x)
+    return torch.triu(D)
+
+
 class TrajectoryBalanceModel(nn.Module):
     def forward(self, batch: gd.Batch) -> Tuple[GraphActionCategorical, Tensor]:
         raise NotImplementedError()
@@ -292,7 +326,6 @@ def compute_batch_losses(
             fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx])
         else:
             fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])
-
         # Retrieve the reward predictions for the full graphs,
         # i.e. the final graph of each trajectory
         log_reward_preds = per_graph_out[final_graph_idx, 0]
@@ -349,7 +382,7 @@ def compute_batch_losses(
            # We also have access to the is_sink attribute, which tells us when P_B must = 1, which
            # we'll use to ignore the last padding state(s) of each trajectory. This also masks out
            # the first P_B of the "next" trajectory that we've shifted.
-            log_p_B = torch.cat([log_p_B[1:], log_p_B[:1]]) * (1 - batch.is_sink)
+            log_p_B = torch.roll(log_p_B, -1, 0) * (1 - batch.is_sink)
         else:
             log_p_B = batch.log_p_B
         assert log_p_F.shape == log_p_B.shape
@@ -360,7 +393,11 @@ def compute_batch_losses(
 
         if self.cfg.do_subtb:
             # SubTB interprets the per_graph_out predictions to predict the state flow F(s)
-            traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
+            if self.cfg.cum_subtb:
+                traj_losses = self.subtb_cum(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
+            else:
+                traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
+
             # The position of the first graph of each trajectory
             first_graph_idx = torch.zeros_like(batch.traj_lens)
             torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
@@ -423,7 +460,6 @@ def compute_batch_losses(
             "logZ": log_Z.mean(),
             "loss": loss.item(),
         }
-
         return loss, info
 
     def _init_subtb(self, dev):
@@ -515,3 +551,25 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
             F_end = F_and_R[fidces]
             total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
         return total_loss
+
+    def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
+        """
+        Calculate the subTB(1) loss (all arguments on log-scale) using dynamic programming.
+
+        See also `subTB`.
+        """
+        dev = traj_lengths.device
+        num_trajs = len(traj_lengths)
+        total_loss = torch.zeros(num_trajs, device=dev)
+        x = torch.cumsum(traj_lengths, 0)
+        # P_B is already shifted
+        pdiff = P_F - P_B
+        for ep, (s_idx, e_idx) in enumerate(zip(shift_right(x), x)):
+            if self.cfg.do_parameterize_p_b:
+                e_idx -= 1
+            n = e_idx - s_idx
+            fr = torch.cat([F[s_idx:e_idx], torch.tensor([R[ep]], device=F.device)])
+            p = pdiff[s_idx:e_idx]
+            # Average the n * (n + 1) / 2 squared sub-trajectory residuals
+            total_loss[ep] = subTB(fr, p).pow(2).sum() / (n * n + n) * 2
+        return total_loss
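A quick sanity check of what the three new helpers compute, on made-up numbers for a trajectory with 3 transitions (this sketch is editorial, not part of the patch):

    import torch

    from gflownet.algo.trajectory_balance import cross, shift_right, subTB

    x = torch.tensor([1.0, 2.0, 3.0])  # per-step log P_F - log P_B (made-up values)
    print(shift_right(x))  # tensor([0., 1., 2.])

    # Upper-triangular entry [i, j] of cross(x) is x[i] + ... + x[j]
    print(cross(x))
    # tensor([[ 1.,  3.,  6.],
    #         [ 0.,  2.,  5.],
    #         [-2.,  0.,  3.]])

    # v holds the log state flows, with log R in the last slot
    v = torch.tensor([4.0, 3.0, 2.0, 1.0])
    # Entry [i, j] is the SubTB residual of the sub-trajectory from s_i to s_{j+1}
    print(subTB(v, x))
    # tensor([[2., 5., 9.],
    #         [0., 3., 7.],
    #         [0., 0., 4.]])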
diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py
index 6352f7e3..9797e15f 100644
--- a/src/gflownet/models/bengio2021flow.py
+++ b/src/gflownet/models/bengio2021flow.py
@@ -10,6 +10,7 @@
 import gzip
 import os
 import pickle  # nosec
+from pathlib import Path
 
 import numpy as np
 import requests  # type: ignore
@@ -153,15 +154,40 @@ def forward(self, data, do_dropout=False):
         return per_mol_out
 
 
-def load_original_model():
-    num_feat = 14 + 1 + NUM_ATOMIC_NUMBERS
-    mpnn = MPNNet(num_feat=num_feat, num_vec=0, dim=64, num_out_per_mol=1, num_out_per_stem=105, num_conv_steps=12)
-    f = requests.get(
+def request():
+    return requests.get(
         "https://github.com/GFNOrg/gflownet/raw/master/mols/data/pretrained_proxy/best_params.pkl.gz",
         stream=True,
         timeout=30,
     )
-    params = pickle.load(gzip.open(f.raw))  # nosec
+
+
+def download(location):
+    f = request()
+    location.parent.mkdir(exist_ok=True)
+    with open(location, "wb") as fd:
+        for chunk in f.iter_content(chunk_size=128):
+            fd.write(chunk)
+
+
+def load_weights(cache, location):
+    if not cache:
+        return pickle.load(gzip.open(request().raw))  # nosec
+
+    try:
+        # gzip reads lazily, so a corrupt cache file only raises BadGzipFile
+        # once pickle starts consuming it; keep the load inside the try block
+        return pickle.load(gzip.open(location))  # nosec
+    except (gzip.BadGzipFile, FileNotFoundError):
+        download(location)
+        return pickle.load(gzip.open(location))  # nosec
+
+
+def load_original_model(cache=True, location=Path(__file__).parent / "cache" / "bengio2021flow_proxy.pkl.gz"):
+    num_feat = 14 + 1 + NUM_ATOMIC_NUMBERS
+    mpnn = MPNNet(num_feat=num_feat, num_vec=0, dim=64, num_out_per_mol=1, num_out_per_stem=105, num_conv_steps=12)
+
+    params = load_weights(cache, location)
     param_map = {
         "lin0.weight": params[0],
         "lin0.bias": params[1],
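The caching added above falls back to a fresh download whenever the cached file is missing or unreadable, and `cache=False` preserves the old streaming behavior. A minimal usage sketch (editorial, not part of the patch; the custom path is hypothetical, and `load_original_model` is assumed to still return the MPNNet proxy as before):

    from pathlib import Path

    from gflownet.models.bengio2021flow import load_original_model

    # Default: download best_params.pkl.gz once into src/gflownet/models/cache/
    # (the directory now ignored by git) and reuse it afterwards
    proxy = load_original_model()

    # Old behavior: stream the weights from GitHub on every call
    proxy = load_original_model(cache=False)

    # Hypothetical custom cache location
    proxy = load_original_model(location=Path("/tmp/bengio2021flow_proxy.pkl.gz"))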
diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py
index 05f9b0e4..8c3993f0 100644
--- a/src/gflownet/models/graph_transformer.py
+++ b/src/gflownet/models/graph_transformer.py
@@ -25,11 +25,10 @@ class GraphTransformer(nn.Module):
     conditional information, since they condition the output). The graph features are projected to
     virtual nodes (one per graph), which are fully connected.
 
-    The per node outputs are the concatenation of the final (post graph-convolution) node embeddings
-    and of the final virtual node embedding of the graph each node corresponds to.
+    The per node outputs are the final (post graph-convolution) node embeddings.
 
     The per graph outputs are the concatenation of a global mean pooling operation, of the final
-    virtual node embeddings, and of the conditional information embedding.
+    node embeddings, and of the final virtual node embeddings.
     """
 
     def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre"):
@@ -134,8 +133,8 @@ def forward(self, g: gd.Batch, cond: torch.Tensor):
             o = o + l_h * scale + shift
             o = o + ff(norm2(o, aug_batch))
 
-        glob = torch.cat([gnn.global_mean_pool(o[: -c.shape[0]], g.batch), o[-c.shape[0] :]], 1)
-        o_final = torch.cat([o[: -c.shape[0]]], 1)
+        o_final = o[: -c.shape[0]]
+        glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1)
 
         return o_final, glob
 
diff --git a/tests/test_subtb.py b/tests/test_subtb.py
new file mode 100644
index 00000000..c4841689
--- /dev/null
+++ b/tests/test_subtb.py
@@ -0,0 +1,28 @@
+from functools import reduce
+
+import torch
+
+from gflownet.algo.trajectory_balance import subTB
+
+
+def subTB_ref(P_F, P_B, F):
+    h = F.shape[0] - 1
+    assert P_F.shape == P_B.shape == (h,)
+    assert F.ndim == 1
+
+    dtype = reduce(torch.promote_types, [P_F.dtype, P_B.dtype, F.dtype])
+    D = torch.zeros(h, h, dtype=dtype)
+    for i in range(h):
+        for j in range(i, h):
+            D[i, j] = F[i] - F[j + 1]
+            D[i, j] = D[i, j] + P_F[i : j + 1].sum()
+            D[i, j] = D[i, j] - P_B[i : j + 1].sum()
+    return D
+
+
+def test_subTB():
+    for T in range(5, 20):
+        P_F = torch.randint(1, 10, (T,))
+        P_B = torch.randint(1, 10, (T,))
+        F = torch.randint(1, 10, (T + 1,))
+        assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all()
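A note on the normalization in `subtb_cum`: a trajectory with n transitions yields n(n + 1)/2 sub-trajectories, exactly the upper-triangular entries that `subTB` keeps, so dividing the summed squares by `n * n + n` and doubling averages the residuals. A small check of that identity (editorial, with made-up shapes):

    import torch

    from gflownet.algo.trajectory_balance import subTB

    n = 4  # number of transitions (made-up)
    v = torch.randn(n + 1)  # log flows, with log R in the last slot
    x = torch.randn(n)  # per-step log P_F - log P_B

    D = subTB(v, x)
    loss = D.pow(2).sum() / (n * n + n) * 2
    assert torch.isclose(loss, D.pow(2).sum() / (n * (n + 1) // 2))

The new test runs with plain pytest, e.g. `pytest tests/test_subtb.py`.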