Add support for backwards actions in MolBuildingEnvContext (#100)
This PR:
- adds support for backward actions in `MolBuildingEnvContext`,
- adds tests for `MolBuildingEnvContext` that check that the backward masks are correct,
- adds a toy atom environment where the reward is simply the number of rings in a carbon-only molecule of up to 6 atoms,
- fixes a bug in `GraphBuildingEnv.parent` whereby 1-node graphs with attributes would have their parents misenumerated (this hadn't manifested before because, until now, no context had more than one node attribute).

* add backward mask support + small ring task

* more comments and implementation notes
bengioe authored Aug 4, 2023
1 parent 152b18f commit ec857a5
Showing 5 changed files with 254 additions and 26 deletions.
18 changes: 18 additions & 0 deletions docs/implementation_notes.md
@@ -16,3 +16,21 @@ We separate experiment concerns in four categories:
- The Trainer class is responsible for instantiating everything, and running the training & testing loop

Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.
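
For illustration, here is a minimal sketch of that pattern, condensed from the `make_rings.py` task added in this PR (the names `MyTask`/`MyTrainer` are placeholders, the reward is an arbitrary stand-in, and the conditional-information methods of `GFNTask` are omitted):

```python
import torch
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask


class MyTask(GFNTask):
    """Toy task: reward each molecule by its number of atoms."""

    def compute_flat_rewards(self, mols):
        rs = torch.tensor([m.GetNumAtoms() for m in mols]).float()
        # Return the rewards and a per-molecule validity flag
        return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool()


class MyTrainer(StandardOnlineTrainer):
    def setup_task(self):
        self.task = MyTask()

    def setup_env_context(self):
        # The context defines the MDP: here, carbon-only graphs of up to 6 atoms
        self.ctx = MolBuildingEnvContext(["C"], max_nodes=6, num_cond_dim=1)
```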


## Graphs

This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations.
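
As a toy illustration of the two representations (this is not the library's conversion code, which lives in the `GraphBuildingEnvContext` subclasses, e.g. `MolBuildingEnvContext.graph_to_Data` below):

```python
import networkx as nx
import torch
import torch_geometric.data as gd

g = nx.Graph()
g.add_edge(0, 1)  # one undirected nx edge

# Each undirected nx edge becomes two directed rows in the tensor encoding
edges = [e for u, v in g.edges for e in [(u, v), (v, u)]]
edge_index = torch.tensor(edges).T  # shape (2, 2 * num_edges)
data = gd.Data(x=torch.ones((g.number_of_nodes(), 1)), edge_index=edge_index)
```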

Some notes:
- Graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs.
- When converting from `GraphAction`s (which refer to elements of the nx graph) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding (see the sketch below).
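
A short sketch of the second point, using the duplication convention above (the specific indices are made up):

```python
import torch

# Two undirected edges, (0, 1) and (1, 2), each stored as two directed copies
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

act_row = 1  # aidx row of the second edge in the deduplicated logit tensor
a, b = edge_index[:, act_row * 2]  # multiply by 2 to skip the reversed copies
assert (a.item(), b.item()) == (1, 2)
```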


### Graph policies & graph action categoricals

The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains the logic to sample from concatenated sets of logits across a minibatch.

Consider for example the `AddNode` and `SetEdgeAttr` actions: one applies to nodes, the other to edges. An efficient way to produce logits for these actions is to take the node/edge embeddings and project them (e.g. via an MLP) to `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensors respectively. We thus obtain a list of tensors representing the logits of different action types, but the logits of different graphs in the minibatch are mixed together within each tensor, so one cannot simply apply a `softmax` operator to them.

The `GraphActionCategorical` class handles this; it can be used to sample from the distribution, as well as to compute quantities such as the entropy and the log-probability of actions.
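
A minimal sketch of the underlying computation, assuming a single logits tensor for node actions (the real class handles several such tensors, one per action type, at once):

```python
import torch

node_logits = torch.randn(5, 3)        # 5 nodes across the minibatch, 3 actions per node
batch = torch.tensor([0, 0, 0, 1, 1])  # which graph each node belongs to

# Per-graph log-partition: logsumexp over every (node, action) logit of that graph
logZ = torch.stack([node_logits[batch == i].logsumexp(dim=(0, 1)) for i in range(2)])
log_probs = node_logits - logZ[batch, None]  # broadcast each graph's logZ to its rows
```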
4 changes: 2 additions & 2 deletions src/gflownet/envs/graph_building_env.py
@@ -271,9 +271,9 @@ def add_parent(a, new_g):
GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]),
new_g,
)
if len(g.nodes) == 1:
if len(g.nodes) == 1 and len(g.nodes[i]) == 1:
# The final node is degree 0, need this special case to remove it
# and end up with S0, the empty graph root
# and end up with S0, the empty graph root (but only if it has no attrs except 'v')
add_parent(
GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]),
graph_without_node(g, i),
117 changes: 101 additions & 16 deletions src/gflownet/envs/mol_building_env.py
@@ -8,7 +8,13 @@
from rdkit.Chem import Mol
from rdkit.Chem.rdchem import BondType, ChiralType

from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext
from gflownet.envs.graph_building_env import (
Graph,
GraphAction,
GraphActionType,
GraphBuildingEnvContext,
graph_without_edge,
)
from gflownet.utils.graphs import random_walk_probs

DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW]
@@ -77,19 +83,22 @@ def __init__(
# The size of the input vector for each atom
self.atom_attr_size = sum(len(i) for i in self.atom_attr_values.values())
self.atom_attrs = sorted(self.atom_attr_values.keys())
# 'v' is set separately when creating the node, so there's no point in having a SetNodeAttr logit for it
self.settable_atom_attrs = [i for i in self.atom_attrs if i != "v"]
# The beginning position within the input vector of each attribute
self.atom_attr_slice = [0] + list(np.cumsum([len(self.atom_attr_values[i]) for i in self.atom_attrs]))
# The beginning position within the logit vector of each attribute
num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.atom_attrs]
num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.settable_atom_attrs]
self.atom_attr_logit_slice = {
k: (s, e)
for k, s, e in zip(self.atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits))
for k, s, e in zip(
self.settable_atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits)
)
}
# The attribute and value each logit dimension maps back to
self.atom_attr_logit_map = [
(k, v)
for k in self.atom_attrs
if k != "v"
for k in self.settable_atom_attrs
# index 0 is skipped because it is the default value
for v in self.atom_attr_values[k][1:]
]
@@ -147,12 +156,21 @@ def __init__(
GraphActionType.AddEdge,
GraphActionType.SetEdgeAttr,
]
self.bck_action_type_order = [
GraphActionType.RemoveNode,
GraphActionType.RemoveNodeAttr,
GraphActionType.RemoveEdge,
GraphActionType.RemoveEdgeAttr,
]
self.device = torch.device("cpu")

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction"""
act_type, act_row, act_col = [int(i) for i in action_idx]
t = self.action_type_order[act_type]
if fwd:
t = self.action_type_order[act_type]
else:
t = self.bck_action_type_order[act_type]
if t is GraphActionType.Stop:
return GraphAction(t)
elif t is GraphActionType.AddNode:
@@ -164,12 +182,34 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
a, b = g.non_edge_index[:, act_row]
return GraphAction(t, source=a.item(), target=b.item())
elif t is GraphActionType.SetEdgeAttr:
a, b = g.edge_index[:, act_row * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits
# In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e.
# g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one
# to another we can safely divide or multiply by two.
a, b = g.edge_index[:, act_row * 2]
attr, val = self.bond_attr_logit_map[act_col]
return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val)
elif t is GraphActionType.RemoveNode:
return GraphAction(t, source=act_row)
elif t is GraphActionType.RemoveNodeAttr:
attr = self.settable_atom_attrs[act_col]
return GraphAction(t, source=act_row, attr=attr)
elif t is GraphActionType.RemoveEdge:
a, b = g.edge_index[:, act_row * 2] # see note above about edge_index
return GraphAction(t, source=a.item(), target=b.item())
elif t is GraphActionType.RemoveEdgeAttr:
a, b = g.edge_index[:, act_row * 2] # see note above about edge_index
attr = self.bond_attrs[act_col]
return GraphAction(t, source=a.item(), target=b.item(), attr=attr)

def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]:
"""Translate a GraphAction to an index tuple"""
for u in [self.action_type_order, self.bck_action_type_order]:
if action.action in u:
type_idx = u.index(action.action)
break
else:
raise ValueError(f"Unknown action type {action.action}")

if action.action is GraphActionType.Stop:
row = col = 0
elif action.action is GraphActionType.AddNode:
@@ -191,17 +231,33 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]:
).argmax()
col = 0
elif action.action is GraphActionType.SetEdgeAttr:
# Here the edges are duplicated, both (i,j) and (j,i) are in edge_index
# so no need for a double check.
# row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) +
# (g.edge_index.T == torch.tensor([(action.target, action.source)])).prod(1)).argmax()
# In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e.
# g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one
# to another we can safely divide or multiply by two.
row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax()
# Because edges are duplicated but logits aren't, divide by two
row = row.div(2, rounding_mode="floor") # type: ignore
col = (
self.bond_attr_values[action.attr].index(action.value) - 1 + self.bond_attr_logit_slice[action.attr][0]
)
type_idx = self.action_type_order.index(action.action)
elif action.action is GraphActionType.RemoveNode:
row = action.source
col = 0
elif action.action is GraphActionType.RemoveNodeAttr:
row = action.source
col = self.settable_atom_attrs.index(action.attr)
elif action.action is GraphActionType.RemoveEdge:
row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax()
# In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e.
# g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one
# to another we can safely divide or multiply by two.
row = int(row) // 2
col = 0
elif action.action is GraphActionType.RemoveEdgeAttr:
row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax()
row = row.div(2, rounding_mode="floor") # type: ignore
col = self.bond_attrs.index(action.attr)
else:
raise ValueError(f"Unknown action type {action.action}")
return (type_idx, int(row), int(col))

def graph_to_Data(self, g: Graph) -> gd.Data:
@@ -211,25 +267,43 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
add_node_mask = torch.ones((x.shape[0], self.num_new_node_values))
if self.max_nodes is not None and len(g.nodes) >= self.max_nodes:
add_node_mask *= 0
remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
remove_node_attr_mask = torch.zeros((x.shape[0], len(self.settable_atom_attrs)))

explicit_valence = {}
max_valence = {}
set_node_attr_mask = torch.ones((x.shape[0], self.num_node_attr_logits))
if not len(g.nodes):
set_node_attr_mask *= 0
for i, n in enumerate(g.nodes):
ad = g.nodes[n]
if g.degree(n) <= 1 and len(ad) == 1 and all([len(g[n][neigh]) == 0 for neigh in g.neighbors(n)]):
# If there's only the 'v' key left, the node is a leaf, and the edges connected to the node have
# no attributes set, we can remove it
remove_node_mask[i] = 1
for k, sl in zip(self.atom_attrs, self.atom_attr_slice):
# idx > 0 means that the attribute is not the default value
idx = self.atom_attr_values[k].index(ad[k]) if k in ad else 0
x[i, sl + idx] = 1
# If the attribute is already there, mask out logits
# (or if the attribute is a negative attribute and has been filled)
if k == "v":
continue
# If the attribute
# - is already set (idx > 0),
# - is a negative attribute and has been filled,
# - or is a negative attribute and is not fillable (i.e. not a key of ad),
# then mask the forward logits.
# For backward logits, positively mask if the attribute is there (idx > 0).
if k in self.negative_attrs:
if k in ad and idx > 0 or k not in ad:
s, e = self.atom_attr_logit_slice[k]
set_node_attr_mask[i, s:e] = 0
# We don't want to make the attribute removable if it's not fillable (i.e. not a key of ad)
if k in ad:
remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1
elif k in ad:
s, e = self.atom_attr_logit_slice[k]
set_node_attr_mask[i, s:e] = 0
remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1
# Account for charge and explicit Hs in atom as limiting the total valence
max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]]
# Special rule for Nitrogen
@@ -256,8 +330,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
s, e = self.atom_attr_logit_slice["expl_H"]
set_node_attr_mask[i, s:e] = 0

remove_edge_mask = torch.zeros((len(g.edges), 1))
for i, (u, v) in enumerate(g.edges):
if g.degree(u) > 1 and g.degree(v) > 1:
if nx.algorithms.is_connected(graph_without_edge(g, (u, v))):
remove_edge_mask[i] = 1
edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))
set_edge_attr_mask = torch.zeros((len(g.edges), self.num_edge_attr_logits))
remove_edge_attr_mask = torch.zeros((len(g.edges), len(self.bond_attrs)))
for i, e in enumerate(g.edges):
ad = g.edges[e]
for k, sl in zip(self.bond_attrs, self.bond_attr_slice):
@@ -267,6 +347,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
if k in ad: # If the attribute is already there, mask out logits
s, e = self.bond_attr_logit_slice[k]
set_edge_attr_mask[i, s:e] = 0
remove_edge_attr_mask[i, self.bond_attrs.index(k)] = 1
# Check which bonds don't bust the valence of their atoms
if "type" not in ad: # Only if type isn't already set
sl, _ = self.bond_attr_logit_slice["type"]
@@ -293,11 +374,15 @@ def is_ok_non_edge(e):
edge_index,
edge_attr,
non_edge_index=non_edge_index,
stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1),
stop_mask=torch.ones((1, 1)) * (len(g.nodes) > 0), # Can only stop if there's at least a node
add_node_mask=add_node_mask,
set_node_attr_mask=set_node_attr_mask,
add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge
set_edge_attr_mask=set_edge_attr_mask,
remove_node_mask=remove_node_mask,
remove_node_attr_mask=remove_node_attr_mask,
remove_edge_mask=remove_edge_mask,
remove_edge_attr_mask=remove_edge_attr_mask,
)
if self.num_rw_feat > 0:
data.x = torch.cat([data.x, random_walk_probs(data, self.num_rw_feat, skip_odd=True)], 1)
90 changes: 90 additions & 0 deletions src/gflownet/tasks/make_rings.py
@@ -0,0 +1,90 @@
import os
import socket
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor

from gflownet.config import Config
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar


class MakeRingsTask(GFNTask):
"""A toy task where the reward is the number of rings in the molecule."""

def __init__(
self,
rng: np.random.Generator,
):
self.rng = rng

def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
return FlatRewards(y)

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log()
return RewardScalar(scalar_logreward.flatten())

def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float()
return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool()


class MakeRingsTrainer(StandardOnlineTrainer):
def set_default_hps(self, cfg: Config):
cfg.hostname = socket.gethostname()
cfg.num_workers = 8
cfg.algo.global_batch_size = 64
cfg.algo.offline_ratio = 0
cfg.model.num_emb = 128
cfg.model.num_layers = 4

cfg.algo.method = "TB"
cfg.algo.max_nodes = 6
cfg.algo.sampling_tau = 0.9
cfg.algo.illegal_action_logreward = -75
cfg.algo.train_random_action_prob = 0.0
cfg.algo.valid_random_action_prob = 0.0
cfg.algo.tb.do_parameterize_p_b = True

cfg.replay.use = False

def setup_task(self):
self.task = MakeRingsTask(rng=self.rng)

def setup_env_context(self):
self.ctx = MolBuildingEnvContext(
["C"],
charges=[0], # disable charge
chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED], # disable chirality
num_rw_feat=0,
max_nodes=self.cfg.algo.max_nodes,
num_cond_dim=1,
)


def main():
hps = {
"log_dir": "./logs/debug_run_mr4",
"device": "cuda",
"num_training_steps": 10_000,
"num_workers": 8,
"algo": {"tb": {"do_parameterize_p_b": True}},
}
os.makedirs(hps["log_dir"], exist_ok=True)

trial = MakeRingsTrainer(hps)
trial.print_every = 1
trial.run()


if __name__ == "__main__":
main()