diff --git a/.gitignore b/.gitignore
index b6e47617..e0bd603b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+# Model cache
+src/gflownet/models/cache/
+
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py
index bd0ce3de..e40303a3 100644
--- a/src/gflownet/algo/config.py
+++ b/src/gflownet/algo/config.py
@@ -37,6 +37,7 @@ class TBConfig:
     subtb_max_len: int = 128
     Z_learning_rate: float = 1e-4
     Z_lr_decay: float = 50_000
+    cum_subtb: bool = True
 
 
 @dataclass
diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py
index 22fe655e..2915336c 100644
--- a/src/gflownet/algo/trajectory_balance.py
+++ b/src/gflownet/algo/trajectory_balance.py
@@ -22,6 +22,40 @@ from gflownet.trainer import GFNAlgorithm
 
 
+def shift_right(x: torch.Tensor, z=0):
+    "Shift x right by 1, and put z in the first position"
+    x = torch.roll(x, 1, dims=0)
+    x[0] = z
+    return x
+
+
+def cross(x: torch.Tensor):
+    r"""
+    Calculate $y_{ij} = \sum_{t=i}^j x_t$.
+    Only the upper triangular part (i <= j) is used downstream; the lower triangle holds negated partial sums.
+    """
+    assert x.ndim == 1
+    y = torch.cumsum(x, 0)
+    return y[None] - shift_right(y)[:, None]
+
+
+def subTB(v: torch.Tensor, x: torch.Tensor):
+    r"""
+    Compute SubTB(1):
+    $\forall i \leq j: D[i,j] =
+        \log \frac{F(s_i) \prod_{k=i}^{j} P_F(s_{k+1}|s_k)}
+                  {F(s_{j + 1}) \prod_{k=i}^{j} P_B(s_k|s_{k+1})}$
+    for a single trajectory.
+    Note that x_k should be $\log P_F(s_{k+1}|s_k) - \log P_B(s_k|s_{k+1})$.
+    """
+    assert v.ndim == x.ndim == 1
+    # D[i, j] = v[i] - v[j + 1]
+    D = v[:-1, None] - v[None, 1:]
+    # cross(x)[i, j] = sum(x[i:j+1])
+    D = D + cross(x)
+    return torch.triu(D)
+
+
 class TrajectoryBalanceModel(nn.Module):
     def forward(self, batch: gd.Batch) -> Tuple[GraphActionCategorical, Tensor]:
         raise NotImplementedError()
@@ -292,7 +326,6 @@ def compute_batch_losses(
             fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx])
         else:
             fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])
-
         # Retrieve the reward predictions for the full graphs,
         # i.e. the final graph of each trajectory
         log_reward_preds = per_graph_out[final_graph_idx, 0]
@@ -349,7 +382,7 @@
             # We also have access to the is_sink attribute, which tells us when P_B must = 1, which
             # we'll use to ignore the last padding state(s) of each trajectory. This also
             # masks out the first P_B of the "next" trajectory that we've shifted.
-            log_p_B = torch.cat([log_p_B[1:], log_p_B[:1]]) * (1 - batch.is_sink)
+            log_p_B = torch.roll(log_p_B, -1, 0) * (1 - batch.is_sink)
         else:
             log_p_B = batch.log_p_B
         assert log_p_F.shape == log_p_B.shape
@@ -360,7 +393,11 @@ def compute_batch_losses(
         if self.cfg.do_subtb:
             # SubTB interprets the per_graph_out predictions to predict the state flow F(s)
-            traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
+            if self.cfg.cum_subtb:
+                traj_losses = self.subtb_cum(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
+            else:
+                traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
+
             # The position of the first graph of each trajectory
             first_graph_idx = torch.zeros_like(batch.traj_lens)
             torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
@@ -423,7 +460,6 @@ def compute_batch_losses(
             "logZ": log_Z.mean(),
             "loss": loss.item(),
         }
-
         return loss, info
 
     def _init_subtb(self, dev):
@@ -515,3 +551,24 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
             F_end = F_and_R[fidces]
             total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
         return total_loss
+
+    def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
+        """
+        Calculate the SubTB(1) loss (all arguments on log-scale) using dynamic programming.
+
+        See also `subTB`.
+        """
+        dev = traj_lengths.device
+        num_trajs = len(traj_lengths)
+        total_loss = torch.zeros(num_trajs, device=dev)
+        x = torch.cumsum(traj_lengths, 0)
+        # P_B is already shifted
+        pdiff = P_F - P_B
+        for ep, (s_idx, e_idx) in enumerate(zip(shift_right(x), x)):
+            if self.cfg.do_parameterize_p_b:
+                e_idx -= 1
+            n = e_idx - s_idx
+            fr = torch.cat([F[s_idx:e_idx], torch.tensor([R[ep]], device=F.device)])
+            p = pdiff[s_idx:e_idx]
+            total_loss[ep] = subTB(fr, p).pow(2).sum() / (n * n + n) * 2
+        return total_loss
diff --git a/tests/test_subtb.py b/tests/test_subtb.py
new file mode 100644
index 00000000..c4841689
--- /dev/null
+++ b/tests/test_subtb.py
@@ -0,0 +1,29 @@
+from functools import reduce
+
+import torch
+
+from gflownet.algo.trajectory_balance import subTB
+
+
+def subTB_ref(P_F, P_B, F):
+    h = F.shape[0] - 1
+    assert P_F.shape == P_B.shape == (h,)
+    assert F.ndim == 1
+
+    dtype = reduce(torch.promote_types, [P_F.dtype, P_B.dtype, F.dtype])
+    D = torch.zeros(h, h, dtype=dtype)
+    for i in range(h):
+        for j in range(i, h):
+            D[i, j] = F[i] - F[j + 1]
+            D[i, j] = D[i, j] + P_F[i : j + 1].sum()
+            D[i, j] = D[i, j] - P_B[i : j + 1].sum()
+    return D
+
+
+def test_subTB():
+    # Check subTB against the brute-force reference for a range of lengths
+    for T in range(5, 20):
+        P_F = torch.randint(1, 10, (T,))
+        P_B = torch.randint(1, 10, (T,))
+        F = torch.randint(1, 10, (T + 1,))
+        assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all()
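
Note on the `log_p_B` change above: `torch.roll(log_p_B, -1, 0)` is a drop-in replacement for the `torch.cat` expression; both rotate the tensor one step to the left with wrap-around, and the wrapped-around entry is then zeroed by the `(1 - batch.is_sink)` mask. A quick standalone sanity check (not part of the patch):

```python
import torch

# torch.roll(x, -1, 0) and torch.cat([x[1:], x[:1]]) both rotate x one
# position to the left, wrapping the first element around to the end.
x = torch.arange(5.0)
assert torch.equal(torch.roll(x, -1, 0), torch.cat([x[1:], x[:1]]))
```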
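To make the `cross` helper concrete, here is a small worked example (illustrative values only): the upper triangle of `cross(x)` holds the inclusive segment sums `x[i:j+1].sum()`.

```python
import torch

from gflownet.algo.trajectory_balance import cross

x = torch.tensor([1.0, 2.0, 3.0])
print(cross(x))
# tensor([[ 1.,  3.,  6.],
#         [ 0.,  2.,  5.],
#         [-2.,  0.,  3.]])
# Upper triangle: cross(x)[i, j] == x[i:j+1].sum(), e.g. (i, j) = (1, 2) -> 2 + 3 = 5.
```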
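Finally, the `/ (n * n + n) * 2` factor in `subtb_cum` averages the squared log-ratios over the n(n+1)/2 sub-trajectories of a length-n trajectory. A minimal sketch with made-up values (the names `fr` and `p` mirror the locals in `subtb_cum`):

```python
import torch

from gflownet.algo.trajectory_balance import subTB

n = 4                    # trajectory length (made up)
fr = torch.randn(n + 1)  # log F(s_0..s_{n-1}) with log R appended
p = torch.randn(n)       # per-step log P_F - log P_B (P_B already shifted)

D = subTB(fr, p)  # n x n upper-triangular matrix of sub-trajectory log-ratios
# Only the n(n+1)/2 upper-triangular entries are nonzero, so this is the
# mean squared discrepancy over all sub-trajectories:
loss = D.pow(2).sum() / (n * n + n) * 2
assert torch.isclose(loss, D.pow(2).sum() / (n * (n + 1) / 2))
```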