Commit: add faster subTB

SobhanMP committed Aug 16, 2023
1 parent 4b3f41f commit dfb2db1
Showing 4 changed files with 95 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
# Model cache
src/gflownet/models/cache/bengio2021flow_proxy.pkl.gz


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
1 change: 1 addition & 0 deletions src/gflownet/algo/config.py
@@ -37,6 +37,7 @@ class TBConfig:
subtb_max_len: int = 128
Z_learning_rate: float = 1e-4
Z_lr_decay: float = 50_000
cum_subtb: bool = True


@dataclass
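The new flag defaults to the cumulative path. A usage sketch for switching back to the original kernel (the `cfg.algo.tb` attribute path is an assumption about how `TBConfig` is mounted in this repo's config; not part of the diff):

    cfg.algo.tb.cum_subtb = False  # fall back to subtb_loss_fast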
66 changes: 62 additions & 4 deletions src/gflownet/algo/trajectory_balance.py
Expand Up @@ -6,6 +6,7 @@
import torch.nn as nn
import torch_geometric.data as gd
from torch import Tensor
from torch.utils import benchmark
from torch_scatter import scatter, scatter_sum

from gflownet.algo.graph_sampling import GraphSampler
@@ -22,6 +23,38 @@
from gflownet.trainer import GFNAlgorithm


def shift_right(x: torch.Tensor, z=0):
"Shift x right by 1, and put z in the first position"
x = torch.roll(x, 1, dims=0)
x[0] = z
return x
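For intuition, a quick check of shift_right (illustrative snippet, not part of the diff):

    x = torch.tensor([3.0, 1.0, 4.0])
    shift_right(x)        # tensor([0., 3., 1.])
    shift_right(x, z=7)   # tensor([7., 3., 1.])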


def cross(x: torch.Tensor):
"""
Calculate $y_{ij} = \sum_{t=i}^j x_t$.
The lower triangular portion is the inverse of the upper triangular one.
"""
assert x.ndim == 1
y = torch.cumsum(x, 0)
return y[None] - shift_right(y)[:, None]
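Concretely (illustrative, not part of the diff): for x = [1, 2, 3] the cumulative sums are [1, 3, 6], so

    x = torch.tensor([1.0, 2.0, 3.0])
    cross(x)
    # tensor([[ 1.,  3.,  6.],
    #         [ 0.,  2.,  5.],
    #         [-2.,  0.,  3.]])
    # upper triangle: cross(x)[i, j] == x[i : j + 1].sum() for i <= j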


def subTB(v: torch.Tensor, x: torch.Tensor):
r"""
Compute the SubTB(1) discrepancies
$\forall i \leq j: D[i,j] = \log \frac{F(s_i) \prod_{k=i}^{j} P_F(s_{k+1}|s_k)}{F(s_{j+1}) \prod_{k=i}^{j} P_B(s_k|s_{k+1})}$
for a single trajectory.
Here $v$ holds the log state flows (with the terminal log-reward as the last entry),
and $x_k$ should be $\log P_F(s_{k+1}|s_k) - \log P_B(s_k|s_{k+1})$.
"""
assert v.ndim == x.ndim == 1
# D[i,j] = V[i] - V[j + 1]
D = v[:-1, None] - v[None, 1:]
# cross(x)[i, j] = sum(x[i:j+1])
D = D + cross(x)
return torch.triu(D)
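A usage sketch mirroring how subtb_cum calls this below (made-up values): for a length-2 trajectory, v stacks the log state flows with the terminal log-reward, and the loss averages the squared discrepancies over all n(n + 1)/2 sub-trajectories:

    fr = torch.tensor([0.5, 0.2, 0.1])  # log F(s_0), log F(s_1), log R
    p = torch.tensor([-0.3, -0.4])      # log P_F - log P_B at each step
    D = subTB(fr, p)                    # 2x2 upper-triangular discrepancy matrix
    n = 2
    loss = D.pow(2).sum() / (n * (n + 1) / 2)  # mean over the 3 sub-trajectories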


class TrajectoryBalanceModel(nn.Module):
def forward(self, batch: gd.Batch) -> Tuple[GraphActionCategorical, Tensor]:
raise NotImplementedError()
@@ -292,7 +325,6 @@ def compute_batch_losses(
fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx])
else:
fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])

# Retrieve the reward predictions for the full graphs,
# i.e. the final graph of each trajectory
log_reward_preds = per_graph_out[final_graph_idx, 0]
@@ -349,7 +381,7 @@ def compute_batch_losses(
# We also have access to the is_sink attribute, which tells us when P_B must = 1, which
# we'll use to ignore the last padding state(s) of each trajectory. This by the same
# occasion masks out the first P_B of the "next" trajectory that we've shifted.
- log_p_B = torch.cat([log_p_B[1:], log_p_B[:1]]) * (1 - batch.is_sink)
log_p_B = torch.roll(log_p_B, -1, 0) * (1 - batch.is_sink)
else:
log_p_B = batch.log_p_B
assert log_p_F.shape == log_p_B.shape
@@ -360,7 +392,11 @@

if self.cfg.do_subtb:
# SubTB interprets the per_graph_out predictions to predict the state flow F(s)
- traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
if self.cfg.cum_subtb:
    traj_losses = self.subtb_cum(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
else:
    traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)

# The position of the first graph of each trajectory
first_graph_idx = torch.zeros_like(batch.traj_lens)
torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
@@ -423,7 +459,8 @@ def compute_batch_losses(
"logZ": log_Z.mean(),
"loss": loss.item(),
}

if self.cfg.do_subtb and self.cfg.cum_subtb:
    info["subtb_diff"] = subtb_diff
return loss, info

def _init_subtb(self, dev):
@@ -515,3 +552,24 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
F_end = F_and_R[fidces]
total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
return total_loss

def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
"""
Calculate the subTB(1) loss (all arguments on log-scale) using dynamic programming.
See also `subTB`.
"""
dev = traj_lengths.device
num_trajs = len(traj_lengths)
total_loss = torch.zeros(num_trajs, device=dev)
x = torch.cumsum(traj_lengths, 0)
# P_B is already shifted
pdiff = P_F - P_B
for ep, (s_idx, e_idx) in enumerate(zip(shift_right(x), x)):
if self.cfg.do_parameterize_p_b:
e_idx -= 1
n = e_idx - s_idx
fr = torch.cat([F[s_idx:e_idx], torch.tensor([R[ep]], device=F.device)])
p = pdiff[s_idx:e_idx]
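# a length-n trajectory has n * (n + 1) / 2 sub-trajectories, hence the normalization below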
total_loss[ep] = subTB(fr, p).pow(2).sum() / (n * n + n) * 2
return total_loss
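The file-level `from torch.utils import benchmark` import suggests the speedup was measured; a self-contained timing sketch for the subTB kernel (illustrative, not part of the diff):

    import torch
    from torch.utils import benchmark

    from gflownet.algo.trajectory_balance import subTB

    v, x = torch.randn(129), torch.randn(128)  # log flows (incl. log R) and log P_F - log P_B
    timer = benchmark.Timer(stmt="subTB(v, x)", globals={"subTB": subTB, "v": v, "x": x})
    print(timer.timeit(100))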
28 changes: 28 additions & 0 deletions tests/test_subtb.py
@@ -0,0 +1,28 @@
from functools import reduce
from gflownet.algo.trajectory_balance import subTB

import torch


def subTB_ref(P_F, P_B, F):
h = F.shape[0] - 1
assert P_F.shape == P_B.shape == (h,)
assert F.ndim == 1

dtype = reduce(torch.promote_types, [P_F.dtype, P_B.dtype, F.dtype])
D = torch.zeros(h, h, dtype=dtype)
for i in range(h):
for j in range(i, h):
D[i, j] = F[i] - F[j + 1]
D[i, j] = D[i, j] + P_F[i : j + 1].sum()
D[i, j] = D[i, j] - P_B[i : j + 1].sum()
return D


def test_subTB():
for T in range(5, 20):
P_F = torch.randint(1, 10, (T,))
P_B = torch.randint(1, 10, (T,))
F = torch.randint(1, 10, (T + 1,))
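# integer tensors keep every operation exact, so strict equality is a valid check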
assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all()
