Merge branch 'trunk' into feat_faster_mol_Data
bengioe authored Sep 5, 2023
2 parents 60a560e + 4babde7 commit 96dec2b
Showing 6 changed files with 132 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
# Model cache
src/gflownet/models/cache/


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
1 change: 1 addition & 0 deletions src/gflownet/algo/config.py
@@ -37,6 +37,7 @@ class TBConfig:
subtb_max_len: int = 128
Z_learning_rate: float = 1e-4
Z_lr_decay: float = 50_000
cum_subtb: bool = True


@dataclass
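Note that the new flag defaults to True, so the cumulative-sum SubTB path below is used automatically. A hedged sketch of opting out (assuming the remaining TBConfig fields all carry defaults, as the ones shown do):

from gflownet.algo.config import TBConfig

tb_cfg = TBConfig(cum_subtb=False)  # fall back to the original subtb_loss_fast path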
65 changes: 61 additions & 4 deletions src/gflownet/algo/trajectory_balance.py
@@ -22,6 +22,40 @@
from gflownet.trainer import GFNAlgorithm


def shift_right(x: torch.Tensor, z=0):
"Shift x right by 1, and put z in the first position"
x = torch.roll(x, 1, dims=0)
x[0] = z
return x


def cross(x: torch.Tensor):
"""
Calculate $y_{ij} = \sum_{t=i}^j x_t$.
Only the upper triangle ($i \leq j$) holds these segment sums; `subTB` below discards the rest via `torch.triu`.
"""
assert x.ndim == 1
y = torch.cumsum(x, 0)
return y[None] - shift_right(y)[:, None]


def subTB(v: torch.Tensor, x: torch.Tensor):
"""
Compute the SubTB(1):
$\forall i \leq j: D[i,j] =
\log \frac{F(s_i) \prod_{k=i}^{j} P_F(s_{k+1}|s_k)}
{F(s_{j + 1}) \prod_{k=i}^{j} P_B(s_k|s_{k+1})}$
for a single trajectory.
Note that $x_k$ should be $\log P_F(s_{k+1}|s_k) - \log P_B(s_k|s_{k+1})$.
"""
assert v.ndim == x.ndim == 1
# D[i,j] = V[i] - V[j + 1]
D = v[:-1, None] - v[None, 1:]
# cross(x)[i, j] = sum(x[i:j+1])
D = D + cross(x)
return torch.triu(D)


class TrajectoryBalanceModel(nn.Module):
def forward(self, batch: gd.Batch) -> Tuple[GraphActionCategorical, Tensor]:
raise NotImplementedError()
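Aside: a standalone sanity check of the new helpers (not part of the commit; it imports them the same way the new test file below does):

import torch

from gflownet.algo.trajectory_balance import cross, shift_right

x = torch.arange(1.0, 5.0)  # [1, 2, 3, 4]
# shift_right drops the last element and prepends z=0
assert torch.equal(shift_right(x), torch.tensor([0.0, 1.0, 2.0, 3.0]))
# cross(x)[i, j] == x[i : j + 1].sum() on the upper triangle
C = cross(x)
assert C[1, 3] == x[1:4].sum()  # 2 + 3 + 4 = 9
assert C[2, 2] == x[2]          # single-transition segment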
@@ -292,7 +326,6 @@ def compute_batch_losses(
fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx])
else:
fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])

# Retrieve the reward predictions for the full graphs,
# i.e. the final graph of each trajectory
log_reward_preds = per_graph_out[final_graph_idx, 0]
@@ -349,7 +382,7 @@ def compute_batch_losses(
# We also have access to the is_sink attribute, which tells us when P_B must = 1, which
# we'll use to ignore the last padding state(s) of each trajectory. This by the same
# occasion masks out the first P_B of the "next" trajectory that we've shifted.
log_p_B = torch.cat([log_p_B[1:], log_p_B[:1]]) * (1 - batch.is_sink)
log_p_B = torch.roll(log_p_B, -1, 0) * (1 - batch.is_sink)
else:
log_p_B = batch.log_p_B
assert log_p_F.shape == log_p_B.shape
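The torch.roll replacement is a pure equivalence: rolling by -1 along dim 0 yields exactly torch.cat([log_p_B[1:], log_p_B[:1]]), just without materializing the two slices. A one-line check (illustrative, not in the commit):

import torch

t = torch.arange(5.0)
assert torch.equal(torch.roll(t, -1, 0), torch.cat([t[1:], t[:1]]))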
@@ -360,7 +393,11 @@

if self.cfg.do_subtb:
# SubTB interprets the per_graph_out predictions to predict the state flow F(s)
traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
if self.cfg.cum_subtb:
traj_losses = self.subtb_cum(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
else:
traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)

# The position of the first graph of each trajectory
first_graph_idx = torch.zeros_like(batch.traj_lens)
torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
@@ -423,7 +460,6 @@ def compute_batch_losses(
"logZ": log_Z.mean(),
"loss": loss.item(),
}

return loss, info

def _init_subtb(self, dev):
@@ -515,3 +551,24 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
F_end = F_and_R[fidces]
total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
return total_loss

def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
"""
Calculate the SubTB(1) loss (all arguments on log-scale) using dynamic programming.
See also `subTB`
"""
dev = traj_lengths.device
num_trajs = len(traj_lengths)
total_loss = torch.zeros(num_trajs, device=dev)
x = torch.cumsum(traj_lengths, 0)
# P_B is already shifted
pdiff = P_F - P_B
for ep, (s_idx, e_idx) in enumerate(zip(shift_right(x), x)):
if self.cfg.do_parameterize_p_b:
e_idx -= 1
n = e_idx - s_idx
fr = torch.cat([F[s_idx:e_idx], torch.tensor([R[ep]], device=F.device)])
p = pdiff[s_idx:e_idx]
total_loss[ep] = subTB(fr, p).pow(2).sum() / (n * n + n) * 2
return total_loss
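A note on the final line of subtb_cum: a trajectory with n transitions has n * (n + 1) / 2 sub-trajectories (i, j), so dividing the summed squared discrepancies by (n * n + n) / 2 takes their mean. A hedged cross-check against a brute-force double loop (illustrative, not in the commit):

import torch

from gflownet.algo.trajectory_balance import subTB

n = 6
fr = torch.randn(n + 1)  # log F(s_0) .. log F(s_{n-1}) with log R appended, as in subtb_cum
pdiff = torch.randn(n)   # log P_F - log P_B per transition

loss = subTB(fr, pdiff).pow(2).sum() / (n * n + n) * 2

brute = torch.stack(
    [
        (fr[i] - fr[j + 1] + pdiff[i : j + 1].sum()).pow(2)
        for i in range(n)
        for j in range(i, n)
    ]
).mean()
assert torch.isclose(loss, brute)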
38 changes: 33 additions & 5 deletions src/gflownet/models/bengio2021flow.py
@@ -10,6 +10,7 @@
import gzip
import os
import pickle # nosec
from pathlib import Path

import numpy as np
import requests # type: ignore
@@ -153,15 +154,42 @@ def forward(self, data, do_dropout=False):
return per_mol_out


def load_original_model():
num_feat = 14 + 1 + NUM_ATOMIC_NUMBERS
mpnn = MPNNet(num_feat=num_feat, num_vec=0, dim=64, num_out_per_mol=1, num_out_per_stem=105, num_conv_steps=12)
f = requests.get(
def request():
return requests.get(
"https://github.com/GFNOrg/gflownet/raw/master/mols/data/pretrained_proxy/best_params.pkl.gz",
stream=True,
timeout=30,
)
params = pickle.load(gzip.open(f.raw)) # nosec


def download(location):
f = request()
location.parent.mkdir(exist_ok=True)
with open(location, "wb") as fd:
for chunk in f.iter_content(chunk_size=128):
fd.write(chunk)


def load_weights(cache, location):
if not cache:
return pickle.load(gzip.open(request().raw)) # nosec

try:
gz = gzip.open(location)
except gzip.BadGzipFile:
download(location)
gz = gzip.open(location)
except FileNotFoundError:
download(location)
gz = gzip.open(location)
return pickle.load(gz) # nosec


def load_original_model(cache=True, location=Path(__file__).parent / "cache" / "bengio2021flow_proxy.pkl.gz"):
num_feat = 14 + 1 + NUM_ATOMIC_NUMBERS
mpnn = MPNNet(num_feat=num_feat, num_vec=0, dim=64, num_out_per_mol=1, num_out_per_stem=105, num_conv_steps=12)

params = load_weights(cache, location)
param_map = {
"lin0.weight": params[0],
"lin0.bias": params[1],
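Usage sketch of the new cache path (the signature is from the diff above; the rest is assumption): the first call downloads the pretrained proxy weights into models/cache/ (which the .gitignore hunk above now excludes from version control), and a missing or corrupted file triggers a re-download via the FileNotFoundError / gzip.BadGzipFile handlers in load_weights.

from gflownet.models.bengio2021flow import load_original_model

proxy = load_original_model()  # downloads once, then reads the local gzip
proxy_fresh = load_original_model(cache=False)  # always streams from GitHub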
9 changes: 4 additions & 5 deletions src/gflownet/models/graph_transformer.py
@@ -25,11 +25,10 @@ class GraphTransformer(nn.Module):
conditional information, since they condition the output). The graph features are projected to
virtual nodes (one per graph), which are fully connected.
The per node outputs are the concatenation of the final (post graph-convolution) node embeddings
and of the final virtual node embedding of the graph each node corresponds to.
The per node outputs are the final (post graph-convolution) node embeddings.
The per graph outputs are the concatenation of a global mean pooling operation, of the final
virtual node embeddings, and of the conditional information embedding.
node embeddings, and of the final virtual node embeddings.
"""

def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre"):
@@ -134,8 +133,8 @@ def forward(self, g: gd.Batch, cond: torch.Tensor):
o = o + l_h * scale + shift
o = o + ff(norm2(o, aug_batch))

glob = torch.cat([gnn.global_mean_pool(o[: -c.shape[0]], g.batch), o[-c.shape[0] :]], 1)
o_final = torch.cat([o[: -c.shape[0]]], 1)
o_final = o[: -c.shape[0]]
glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1)
return o_final, glob


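This forward() change is behavior-preserving: the old torch.cat([o[: -c.shape[0]]], 1) concatenated a single tensor, i.e. it was already just the node embeddings, so the rewrite names that tensor o_final and reuses it for the pooling. A minimal check of that claim (dimensions are made up):

import torch

o_nodes = torch.randn(5, 64)  # stands in for o[: -c.shape[0]]
assert torch.equal(torch.cat([o_nodes], 1), o_nodes)  # cat over one tensor is a no-op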
29 changes: 29 additions & 0 deletions tests/test_subtb.py
@@ -0,0 +1,29 @@
from functools import reduce

import torch

from gflownet.algo.trajectory_balance import subTB


def subTB_ref(P_F, P_B, F):
h = F.shape[0] - 1
assert P_F.shape == P_B.shape == (h,)
assert F.ndim == 1

dtype = reduce(torch.promote_types, [P_F.dtype, P_B.dtype, F.dtype])
D = torch.zeros(h, h, dtype=dtype)
for i in range(h):
for j in range(i, h):
D[i, j] = F[i] - F[j + 1]
D[i, j] = D[i, j] + P_F[i : j + 1].sum()
D[i, j] = D[i, j] - P_B[i : j + 1].sum()
return D


def test_subTB():
for T in range(5, 20):
P_F = torch.randint(1, 10, (T,))
P_B = torch.randint(1, 10, (T,))
F = torch.randint(1, 10, (T + 1,))
assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all()
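The exact == comparison above is safe because the inputs are integer tensors, so every partial sum in subTB and subTB_ref is exact. Float inputs would need a tolerance instead; a hypothetical companion test (not in the commit), pasted into the same module so it can reuse subTB_ref:

def test_subTB_float():
    P_F, P_B, F = torch.randn(10), torch.randn(10), torch.randn(11)
    assert torch.allclose(subTB(F, P_F - P_B), subTB_ref(P_F, P_B, F))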
