diff --git a/README.md b/README.md index 00791982..3b7a143d 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera ## Repo overview -- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations (only [Trajectory Balance](https://arxiv.org/abs/2201.13259) for now), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. +- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. - [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities. - [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data. - [examples](docs/examples), contains simple example implementations of GFlowNet. @@ -30,8 +30,11 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera - [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward. - [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein - [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives). -- [utils](src/gflownet/utils), contains utilities (multiprocessing). -- [`train.py`](src/gflownet/train.py), defines a general harness for training GFlowNet models. +- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning). +- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models. +- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop. + +See [implementation notes](docs/implementation_notes.md) for more. ## Getting started @@ -44,12 +47,12 @@ A good place to get started is with the [sEH fragment-based MOO task](src/gflown This package is installable as a PIP package, but since it depends on some torch-geometric package wheels, the `--find-links` arguments must be specified as well: ```bash -pip install -e . --find-links https://data.pyg.org/whl/torch-1.10.0+cu113.html +pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cu117.html ``` Or for CPU use: ```bash -pip install -e . --find-links https://data.pyg.org/whl/torch-1.10.0+cpu.html +pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html ``` To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository-pip-dependency/) a specific tag, for example here `v0.0.10`, use the following scheme: @@ -57,6 +60,8 @@ To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository- pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-links ... ``` +If package dependencies seem not to work, you may need to install the exact frozen versions listed `requirements/`, i.e. `pip install -r requirements/main_3.9.txt`. + ## Developing & Contributing TODO: Write Contributing.md. diff --git a/VERSION b/VERSION index bd263dfa..6a9e0389 100644 --- a/VERSION +++ b/VERSION @@ -1,2 +1,2 @@ MAJOR="0" -MINOR="0" +MINOR="1" diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md new file mode 100644 index 00000000..6930728d --- /dev/null +++ b/docs/implementation_notes.md @@ -0,0 +1,36 @@ +# Implementation notes + +This repo is centered around training GFlowNets that produce graphs. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments. + +## Environment, Context, Task, Trainers + +We separate experiment concerns in four categories: +- The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP. +- The Context provides an interface between the agent and the environment, it + - maps graphs to torch_geometric `Data` + instances + - maps GraphActions to action indices + - produces action masks + - communicates to the model what inputs it should expect +- The Task class is responsible for computing the reward of a state, and for sampling conditioning information +- The Trainer class is responsible for instanciating everything, and running the training & testing loop + +Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`. + + +## Graphs + +This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations. + +Some notes: +- graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs. +- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding. + + +### Graph policies & graph action categoricals + +The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits accross a minibatch. + +Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. + +The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 25209cff..7afb21b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,14 +56,14 @@ keywords = ["gflownet"] requires-python = ">=3.8,<3.10" dynamic = ["version"] dependencies = [ - "torch==1.10.2", + "torch==1.13.1", # These pins are specific on purpose, some of these packages have # unstable APIs since they are fairly new. We could instead pin # them as >= in dev until something breaks? "torch-geometric==2.0.3", - "torch-scatter==2.0.9", - "torch-sparse==0.6.13", - "torch-cluster==1.6.0", + "torch-scatter", + "torch-sparse", + "torch-cluster", "rdkit", "tables", "scipy", @@ -71,10 +71,11 @@ dependencies = [ "tensorboard", "cvxopt", "pyarrow", - "botorch==0.6.6", # pin because of the torch==1.10.2 dependency, botorch>=0.7 requires torch>=1.11 - # pins to help depencency resolution, because of the above pin - "pyro-ppl==1.8.0", - "gpytorch==1.8.1", + "gitpython", + "botorch", + "pyro-ppl", + "gpytorch", + "omegaconf>=2.3", ] [project.optional-dependencies] diff --git a/requirements/dev_3.8.txt b/requirements/dev_3.8.txt deleted file mode 100644 index 4b45ebb0..00000000 --- a/requirements/dev_3.8.txt +++ /dev/null @@ -1,297 +0,0 @@ -absl-py==1.4.0 - # via tensorboard -bandit[toml]==1.7.5 - # via gflownet (pyproject.toml) -black==23.3.0 - # via gflownet (pyproject.toml) -blosc2==2.0.0 - # via tables -botorch==0.6.6 - # via gflownet (pyproject.toml) -build==0.10.0 - # via pip-tools -cachetools==5.3.0 - # via google-auth -certifi==2022.12.7 - # via requests -cfgv==3.3.1 - # via pre-commit -charset-normalizer==3.1.0 - # via requests -click==8.1.3 - # via - # black - # pip-compile-multi - # pip-tools -coverage[toml]==7.2.5 - # via pytest-cov -cvxopt==1.3.0 - # via gflownet (pyproject.toml) -cython==0.29.34 - # via tables -distlib==0.3.6 - # via virtualenv -exceptiongroup==1.1.1 - # via pytest -filelock==3.12.0 - # via virtualenv -gitdb==4.0.10 - # via gitpython -gitpython==3.1.31 - # via - # bandit - # gflownet (pyproject.toml) -google-auth==2.17.3 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -googledrivedownloader==0.4 - # via torch-geometric -gpytorch==1.8.1 - # via - # botorch - # gflownet (pyproject.toml) -grpcio==1.54.0 - # via tensorboard -identify==2.5.23 - # via pre-commit -idna==3.4 - # via requests -importlib-metadata==6.6.0 - # via - # markdown - # typeguard -iniconfig==2.0.0 - # via pytest -isodate==0.6.1 - # via rdflib -isort==5.12.0 - # via gflownet (pyproject.toml) -jinja2==3.1.2 - # via torch-geometric -joblib==1.2.0 - # via scikit-learn -markdown==3.4.3 - # via tensorboard -markdown-it-py==2.2.0 - # via rich -markupsafe==2.1.2 - # via - # jinja2 - # werkzeug -mdurl==0.1.2 - # via markdown-it-py -msgpack==1.0.5 - # via blosc2 -multipledispatch==0.6.0 - # via botorch -mypy==1.2.0 - # via gflownet (pyproject.toml) -mypy-extensions==1.0.0 - # via - # black - # mypy -networkx==3.1 - # via - # gflownet (pyproject.toml) - # torch-geometric -nodeenv==1.7.0 - # via pre-commit -numexpr==2.8.4 - # via tables -numpy==1.24.3 - # via - # gpytorch - # numexpr - # opt-einsum - # pandas - # pyarrow - # pyro-ppl - # rdkit - # scikit-learn - # scipy - # tables - # tensorboard - # torch-geometric -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via pyro-ppl -packaging==23.1 - # via - # black - # build - # pytest - # tables -pandas==2.0.1 - # via torch-geometric -pathspec==0.11.1 - # via black -pbr==5.11.1 - # via stevedore -pillow==9.5.0 - # via rdkit -pip-compile-multi==2.6.2 - # via gflownet (pyproject.toml) -pip-tools==6.13.0 - # via pip-compile-multi -platformdirs==3.5.0 - # via - # black - # virtualenv -pluggy==1.0.0 - # via pytest -pre-commit==3.2.2 - # via gflownet (pyproject.toml) -protobuf==4.22.3 - # via tensorboard -py-cpuinfo==9.0.0 - # via tables -pyarrow==11.0.0 - # via gflownet (pyproject.toml) -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pygments==2.15.1 - # via rich -pyparsing==3.0.9 - # via - # rdflib - # torch-geometric -pyproject-hooks==1.0.0 - # via build -pyro-api==0.1.2 - # via pyro-ppl -pyro-ppl==1.8.0 - # via - # botorch - # gflownet (pyproject.toml) -pytest==7.3.1 - # via - # gflownet (pyproject.toml) - # pytest-cov -pytest-cov==4.0.0 - # via gflownet (pyproject.toml) -python-dateutil==2.8.2 - # via pandas -pytz==2023.3 - # via pandas -pyyaml==6.0 - # via - # bandit - # pre-commit - # torch-geometric - # yacs -rdflib==6.3.2 - # via torch-geometric -rdkit==2022.9.5 - # via gflownet (pyproject.toml) -requests==2.29.0 - # via - # requests-oauthlib - # tensorboard - # torch-geometric -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rich==13.3.5 - # via bandit -rsa==4.9 - # via google-auth -ruff==0.0.263 - # via gflownet (pyproject.toml) -scikit-learn==1.2.2 - # via - # gpytorch - # torch-geometric -scipy==1.10.1 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # scikit-learn - # torch-geometric - # torch-sparse -six==1.16.0 - # via - # google-auth - # isodate - # multipledispatch - # python-dateutil -smmap==5.0.0 - # via gitdb -stevedore==5.0.0 - # via bandit -tables==3.8.0 - # via gflownet (pyproject.toml) -tensorboard==2.12.2 - # via gflownet (pyproject.toml) -tensorboard-data-server==0.7.0 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.1.0 - # via scikit-learn -tomli==2.0.1 - # via - # bandit - # black - # build - # coverage - # mypy - # pytest -toposort==1.10 - # via pip-compile-multi -torch==1.10.2 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # pyro-ppl -torch-cluster==1.6.0 - # via gflownet (pyproject.toml) -torch-geometric==2.0.3 - # via gflownet (pyproject.toml) -torch-scatter==2.0.9 - # via gflownet (pyproject.toml) -torch-sparse==0.6.13 - # via gflownet (pyproject.toml) -tqdm==4.65.0 - # via - # pyro-ppl - # torch-geometric -typeguard==3.0.2 - # via gflownet (pyproject.toml) -types-pkg-resources==0.1.3 - # via gflownet (pyproject.toml) -typing-extensions==4.5.0 - # via - # black - # mypy - # rich - # torch - # typeguard -tzdata==2023.3 - # via pandas -urllib3==1.26.15 - # via requests -virtualenv==20.23.0 - # via pre-commit -werkzeug==2.3.3 - # via tensorboard -wheel==0.40.0 - # via - # pip-tools - # tensorboard -yacs==0.1.8 - # via torch-geometric -zipp==3.15.0 - # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools diff --git a/requirements/dev_3.9.txt b/requirements/dev_3.9.txt index 34784306..4d755596 100644 --- a/requirements/dev_3.9.txt +++ b/requirements/dev_3.9.txt @@ -117,6 +117,8 @@ numpy==1.24.2 # torch-geometric oauthlib==3.2.2 # via requests-oauthlib +omegaconf==2.3.0 + # via gflownet opt-einsum==3.3.0 # via pyro-ppl packaging==23.1 @@ -246,7 +248,7 @@ tomli==2.0.1 # pytest toposort==1.10 # via pip-compile-multi -torch==1.10.2 +torch==1.13.1 # via # botorch # gflownet (pyproject.toml) @@ -256,9 +258,9 @@ torch-cluster==1.6.0 # via gflownet (pyproject.toml) torch-geometric==2.0.3 # via gflownet (pyproject.toml) -torch-scatter==2.0.9 +torch-scatter==2.1.1 # via gflownet (pyproject.toml) -torch-sparse==0.6.13 +torch-sparse==0.6.17 # via gflownet (pyproject.toml) tqdm==4.65.0 # via diff --git a/requirements/main_3.8.txt b/requirements/main_3.8.txt deleted file mode 100644 index e9cff706..00000000 --- a/requirements/main_3.8.txt +++ /dev/null @@ -1,186 +0,0 @@ -absl-py==1.4.0 - # via tensorboard -blosc2==2.0.0 - # via tables -botorch==0.6.6 - # via gflownet (pyproject.toml) -cachetools==5.3.0 - # via google-auth -certifi==2022.12.7 - # via requests -charset-normalizer==3.1.0 - # via requests -cvxopt==1.3.0 - # via gflownet (pyproject.toml) -cython==0.29.34 - # via tables -google-auth==2.17.3 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -googledrivedownloader==0.4 - # via torch-geometric -gpytorch==1.8.1 - # via - # botorch - # gflownet (pyproject.toml) -grpcio==1.54.0 - # via tensorboard -idna==3.4 - # via requests -importlib-metadata==6.6.0 - # via markdown -isodate==0.6.1 - # via rdflib -jinja2==3.1.2 - # via torch-geometric -joblib==1.2.0 - # via scikit-learn -markdown==3.4.3 - # via tensorboard -markupsafe==2.1.2 - # via - # jinja2 - # werkzeug -msgpack==1.0.5 - # via blosc2 -multipledispatch==0.6.0 - # via botorch -networkx==3.1 - # via - # gflownet (pyproject.toml) - # torch-geometric -numexpr==2.8.4 - # via tables -numpy==1.24.3 - # via - # gpytorch - # numexpr - # opt-einsum - # pandas - # pyarrow - # pyro-ppl - # rdkit - # scikit-learn - # scipy - # tables - # tensorboard - # torch-geometric -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via pyro-ppl -packaging==23.1 - # via tables -pandas==2.0.1 - # via torch-geometric -pillow==9.5.0 - # via rdkit -protobuf==4.22.3 - # via tensorboard -py-cpuinfo==9.0.0 - # via tables -pyarrow==11.0.0 - # via gflownet (pyproject.toml) -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pyparsing==3.0.9 - # via - # rdflib - # torch-geometric -pyro-api==0.1.2 - # via pyro-ppl -pyro-ppl==1.8.0 - # via - # botorch - # gflownet (pyproject.toml) -python-dateutil==2.8.2 - # via pandas -pytz==2023.3 - # via pandas -pyyaml==6.0 - # via - # torch-geometric - # yacs -rdflib==6.3.2 - # via torch-geometric -rdkit==2022.9.5 - # via gflownet (pyproject.toml) -requests==2.29.0 - # via - # requests-oauthlib - # tensorboard - # torch-geometric -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rsa==4.9 - # via google-auth -scikit-learn==1.2.2 - # via - # gpytorch - # torch-geometric -scipy==1.10.1 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # scikit-learn - # torch-geometric - # torch-sparse -six==1.16.0 - # via - # google-auth - # isodate - # multipledispatch - # python-dateutil -tables==3.8.0 - # via gflownet (pyproject.toml) -tensorboard==2.12.2 - # via gflownet (pyproject.toml) -tensorboard-data-server==0.7.0 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.1.0 - # via scikit-learn -torch==1.10.2 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # pyro-ppl -torch-cluster==1.6.0 - # via gflownet (pyproject.toml) -torch-geometric==2.0.3 - # via gflownet (pyproject.toml) -torch-scatter==2.0.9 - # via gflownet (pyproject.toml) -torch-sparse==0.6.13 - # via gflownet (pyproject.toml) -tqdm==4.65.0 - # via - # pyro-ppl - # torch-geometric -typing-extensions==4.5.0 - # via torch -tzdata==2023.3 - # via pandas -urllib3==1.26.15 - # via requests -werkzeug==2.3.3 - # via tensorboard -wheel==0.40.0 - # via tensorboard -yacs==0.1.8 - # via torch-geometric -zipp==3.15.0 - # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/requirements/main_3.9.txt b/requirements/main_3.9.txt index cb0cbaa0..66be868a 100644 --- a/requirements/main_3.9.txt +++ b/requirements/main_3.9.txt @@ -14,6 +14,10 @@ cvxopt==1.3.0 # via gflownet (pyproject.toml) cython==0.29.33 # via tables +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via gflownet (pyproject.toml) google-auth==2.16.0 # via # google-auth-oauthlib @@ -70,6 +74,8 @@ numpy==1.24.1 # torch-geometric oauthlib==3.2.2 # via requests-oauthlib +omegaconf==2.3.0 + # via gflownet opt-einsum==3.3.0 # via pyro-ppl packaging==23.0 @@ -139,6 +145,8 @@ six==1.16.0 # isodate # multipledispatch # python-dateutil +smmap==5.0.0 + # via gitdb tables==3.8.0 # via gflownet (pyproject.toml) tensorboard==2.11.2 @@ -149,7 +157,7 @@ tensorboard-plugin-wit==1.8.1 # via tensorboard threadpoolctl==3.1.0 # via scikit-learn -torch==1.10.2 +torch==1.13.1 # via # botorch # gflownet (pyproject.toml) @@ -159,9 +167,9 @@ torch-cluster==1.6.0 # via gflownet (pyproject.toml) torch-geometric==2.0.3 # via gflownet (pyproject.toml) -torch-scatter==2.0.9 +torch-scatter==2.1.1 # via gflownet (pyproject.toml) -torch-sparse==0.6.13 +torch-sparse==0.6.17 # via gflownet (pyproject.toml) tqdm==4.64.1 # via @@ -174,7 +182,8 @@ urllib3==1.26.14 werkzeug==2.2.3 # via tensorboard wheel==0.38.4 - # via tensorboard + # via + # tensorboard yacs==0.1.8 # via torch-geometric zipp==3.12.0 diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 43245286..001e19d0 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -1,11 +1,10 @@ -from typing import Any, Dict - import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd from torch import Tensor +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from .graph_sampling import GraphSampler @@ -17,9 +16,7 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """Advantage Actor-Critic implementation, see Asynchronous Methods for Deep Reinforcement Learning, @@ -38,29 +35,25 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + The experiment configuration """ self.ctx = ctx self.env = env self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.entropy_coef = hps.get("a2c_entropy", 0.01) - self.gamma = hps.get("a2c_gamma", 1) - self.invalid_penalty = hps.get("a2c_penalty", -10) + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.entropy_coef = cfg.algo.a2c.entropy + self.gamma = cfg.algo.a2c.gamma + self.invalid_penalty = cfg.algo.a2c.penalty assert self.gamma == 1 self.bootstrap_own_reward = False # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py new file mode 100644 index 00000000..bd0ce3de --- /dev/null +++ b/src/gflownet/algo/config.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class TBConfig: + """Trajectory Balance config. + + Attributes + ---------- + bootstrap_own_reward : bool + Whether to bootstrap the reward with the own reward. (deprecated) + epsilon : Optional[float] + The epsilon parameter in log-flow smoothing (see paper) + reward_loss_multiplier : float + The multiplier for the reward loss when bootstrapping the reward. (deprecated) + do_subtb : bool + Whether to use the full N^2 subTB loss + do_correct_idempotent : bool + Whether to correct for idempotent actions + do_parameterize_p_b : bool + Whether to parameterize the P_B distribution (otherwise it is uniform) + subtb_max_len : int + The maximum length trajectories, used to cache subTB computation indices + Z_learning_rate : float + The learning rate for the logZ parameter (only relevant when do_subtb is False) + Z_lr_decay : float + The learning rate decay for the logZ parameter (only relevant when do_subtb is False) + """ + + bootstrap_own_reward: bool = False + epsilon: Optional[float] = None + reward_loss_multiplier: float = 1.0 + do_subtb: bool = False + do_correct_idempotent: bool = False + do_parameterize_p_b: bool = False + subtb_max_len: int = 128 + Z_learning_rate: float = 1e-4 + Z_lr_decay: float = 50_000 + + +@dataclass +class MOQLConfig: + gamma: float = 1 + num_omega_samples: int = 32 + num_objectives: int = 2 + lambda_decay: int = 10_000 + penalty: float = -10 + + +@dataclass +class A2CConfig: + entropy: float = 0.01 + gamma: float = 1 + penalty: float = -10 + + +@dataclass +class FMConfig: + epsilon: float = 1e-38 + balanced_loss: bool = False + leaf_coef: float = 10 + correct_idempotent: bool = False + + +@dataclass +class SQLConfig: + alpha: float = 0.01 + gamma: float = 1 + penalty: float = -10 + + +@dataclass +class AlgoConfig: + """Generic configuration for algorithms + + Attributes + ---------- + method : str + The name of the algorithm to use (e.g. "TB") + global_batch_size : int + The batch size for training + max_len : int + The maximum length of a trajectory + max_nodes : int + The maximum number of nodes in a generated graph + max_edges : int + The maximum number of edges in a generated graph + illegal_action_logreward : float + The log reward an agent gets for illegal actions + offline_ratio: float + The ratio of samples drawn from `self.training_data` during training. The rest is drawn from + `self.sampling_model` + valid_offline_ratio: float + Idem but for validation, and `self.test_data`. + train_random_action_prob : float + The probability of taking a random action during training + valid_random_action_prob : float + The probability of taking a random action during validation + valid_sample_cond_info : bool + Whether to sample conditioning information during validation (if False, expects a validation set of cond_info) + sampling_tau : float + The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta) + """ + + method: str = "TB" + global_batch_size: int = 64 + max_len: int = 128 + max_nodes: int = 128 + max_edges: int = 128 + illegal_action_logreward: float = -100 + offline_ratio: float = 0.5 + valid_offline_ratio: float = 1 + train_random_action_prob: float = 0.0 + valid_random_action_prob: float = 0.0 + valid_sample_cond_info: bool = True + sampling_tau: float = 0.0 + tb: TBConfig = TBConfig() + moql: MOQLConfig = MOQLConfig() + a2c: A2CConfig = A2CConfig() + fm: FMConfig = FMConfig() + sql: SQLConfig = SQLConfig() diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 0f563d60..4d694ae2 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - import numpy as np import torch import torch.nn as nn @@ -8,6 +6,7 @@ from torch import Tensor from torch_scatter import scatter +from gflownet.config import Config from gflownet.envs.graph_building_env import ( GraphActionCategorical, GraphBuildingEnv, @@ -15,6 +14,7 @@ generate_forward_trajectory, ) from gflownet.models.graph_transformer import GraphTransformer, mlp +from gflownet.trainer import GFNTask from .graph_sampling import GraphSampler @@ -165,10 +165,9 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, + task: GFNTask, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """Envelope Q-Learning implementation, see A Generalized Algorithm for Multi-Objective Reinforcement Learning and Policy Adaptation, @@ -187,31 +186,28 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + The experiment configuration """ self.ctx = ctx self.env = env + self.task = task self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.gamma = hps.get("moql_gamma", 1) - self.num_objectives = len(hps["objectives"]) - self.num_omega_samples = hps.get("moql_num_omega_samples", 32) - self.Lambda_decay = hps.get("moql_lambda_decay", 10_000) - self.invalid_penalty = hps.get("moql_penalty", -10) + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.gamma = cfg.algo.moql.gamma + self.num_objectives = cfg.algo.moql.num_objectives + self.num_omega_samples = cfg.algo.moql.num_omega_samples + self.lambda_decay = cfg.algo.moql.lambda_decay + self.invalid_penalty = cfg.algo.moql.penalty self._num_updates = 0 assert self.gamma == 1 self.bootstrap_own_reward = False # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float @@ -396,7 +392,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # and L_B loss_B = abs((w * y).sum(1) - (w * Q_saw).sum(1)) - Lambda = 1 - self.Lambda_decay / (self.Lambda_decay + self._num_updates) + Lambda = 1 - self.lambda_decay / (self.lambda_decay + self._num_updates) losses = (1 - Lambda) * loss_A + Lambda * loss_B self._num_updates += 1 diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py new file mode 100644 index 00000000..a1e9a393 --- /dev/null +++ b/src/gflownet/algo/flow_matching.py @@ -0,0 +1,190 @@ +import networkx as nx +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from torch_scatter import scatter + +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionType, + GraphBuildingEnv, + GraphBuildingEnvContext, +) + + +def relabel(ga: GraphAction, g: Graph): + """Relabel the nodes for g to 0-N, and the graph action ga applied to g. + + This is necessary because torch_geometric and EnvironmentContext classes expect nodes to be + labeled 0-N, whereas GraphBuildingEnv.parent can return parents with e.g. a removed node that + creates a gap in 0-N, leading to a faulty encoding of the graph. + """ + rmap = dict(zip(g.nodes, range(len(g.nodes)))) + if not len(g) and ga.action == GraphActionType.AddNode: + rmap[0] = 0 # AddNode can add to the empty graph, the source is still 0 + g = nx.relabel_nodes(g, rmap) + if ga.source is not None: + ga.source = rmap[ga.source] + if ga.target is not None: + ga.target = rmap[ga.target] + return ga, g + + +class FlowMatching(TrajectoryBalance): # TODO: FM inherits from TB but we could have a generic GFNAlgorithm class + def __init__( + self, + env: GraphBuildingEnv, + ctx: GraphBuildingEnvContext, + rng: np.random.RandomState, + cfg: Config, + ): + super().__init__(env, ctx, rng, cfg) + self.fm_epsilon = torch.as_tensor(cfg.algo.fm.epsilon).log() + # We include the "balanced loss" as a possibility to reproduce results from the FM paper, but + # in a number of settings the regular loss is more stable. + self.fm_balanced_loss = cfg.algo.fm.balanced_loss + self.fm_leaf_coef = cfg.algo.fm.leaf_coef + self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent + + def construct_batch(self, trajs, cond_info, log_rewards): + """Construct a batch from a list of trajectories and their information + + Parameters + ---------- + trajs: List[List[tuple[Graph, GraphAction]]] + A list of N trajectories. + cond_info: Tensor + The conditional info that is considered for each trajectory. Shape (N, n_info) + log_rewards: Tensor + The transformed reward (e.g. log(R(x) ** beta)) for each trajectory. Shape (N,) + Returns + ------- + batch: gd.Batch + A (CPU) Batch object with relevant attributes added + """ + if not self.correct_idempotent: + # For every s' (i.e. every state except the first of each trajectory), enumerate parents + parents = [[relabel(*i) for i in self.env.parents(i[0])] for tj in trajs for i in tj["traj"][1:]] + # convert parents to Data + parent_graphs = [self.ctx.graph_to_Data(pstate) for parent in parents for pact, pstate in parent] + else: + # Here we again enumerate parents + states = [i[0] for tj in trajs for i in tj["traj"][1:]] + base_parents = [[relabel(*i) for i in self.env.parents(i)] for i in states] + base_parent_graphs = [ + [self.ctx.graph_to_Data(pstate) for pact, pstate in parent_set] for parent_set in base_parents + ] + parents = [] + parent_graphs = [] + for state, parent_set, parent_set_graphs in zip(states, base_parents, base_parent_graphs): + new_parent_set = [] + new_parent_graphs = [] + # But for each parent we add all the possible (action, parent) pairs to the sets of parents + for (ga, p), pd in zip(parent_set, parent_set_graphs): + ipa = self.get_idempotent_actions(p, pd, state, ga, return_aidx=False) + new_parent_set += [(a, p) for a in ipa] + new_parent_graphs += [pd] * len(ipa) + parents.append(new_parent_set) + parent_graphs += new_parent_graphs + # Implementation Note: no further correction is required for environments where episodes + # always end in a Stop action. If this is not the case, then this implementation is + # incorrect in that it doesn't account for the multiple ways that one could reach the + # terminal state (because it assumes that a terminal state has only one parent and gives + # 100% of the reward-flow to the edge between that parent and the terminal state, which + # for stop actions is correct). Notably, this error will happen in environments where + # there are invalid states that make episodes end prematurely (when those invalid states + # have multiple possible parents). + + # convert actions to aidx + parent_actions = [pact for parent in parents for pact, pstate in parent] + parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)] + # convert state to Data + state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]] + terminal_actions = [ + self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs + ] + + # Create a batch from [*parents, *states]. This order will make it easier when computing the loss + batch = self.ctx.collate(parent_graphs + state_graphs) + batch.num_parents = torch.tensor([len(i) for i in parents]) + batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) + batch.parent_acts = torch.tensor(parent_actionidcs) + batch.terminal_acts = torch.tensor(terminal_actions) + batch.log_rewards = log_rewards + batch.cond_info = cond_info + batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() + if self.correct_idempotent: + raise ValueError("Not implemented") + return batch + + def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): + dev = batch.x.device + eps = self.fm_epsilon.to(dev) + # Compute relevant quantities + num_trajs = len(batch.log_rewards) + num_states = int(batch.num_parents.shape[0]) + total_num_parents = batch.num_parents.sum() + # Compute, for every parent, the index of the state it corresponds to (all states are + # considered numbered 0..N regardless of which trajectory they correspond to) + parents_state_idx = torch.arange(num_states, device=dev).repeat_interleave(batch.num_parents) + # Compute, for every state, the index of the trajectory it corresponds to + states_traj_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens - 1) + # Idem for parents + parents_traj_idx = states_traj_idx.repeat_interleave(batch.num_parents) + # Compute the index of the first graph of every trajectory via a cumsum of the trajectory + # lengths. This works because by design the first parent of every trajectory is s0 (i.e. s1 + # only has one parent that is s0) + num_parents_per_traj = scatter(batch.num_parents, states_traj_idx, 0, reduce="sum") + first_graph_idx = torch.cumsum( + torch.cat([torch.zeros_like(num_parents_per_traj[0])[None], num_parents_per_traj]), 0 + ) + # Similarly we want the index of the last graph of each trajectory + final_graph_idx = torch.cumsum(batch.traj_lens - 1, 0) + total_num_parents - 1 + + # Query the model for Fsa. The model will output a GraphActionCategorical, but we will + # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of + # log F(s,a) so we don't have to change anything in the sampling routines. + cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) + # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the + # parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs + # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We + # only need the parent's outputs so we specify those batch indices. + parent_log_F_sa = cat.log_prob( + batch.parent_acts, logprobs=cat.logits, batch=torch.arange(total_num_parents, device=dev) + ) + # The inflows is then simply the sum reduction of exponentiating the log edge flows. The + # indices are the state index that each parent belongs to. + log_inflows = scatter(parent_log_F_sa.exp(), parents_state_idx, 0, reduce="sum").log() + # To compute the outflows we can just logsumexp the log F(s,a) predictions. We do so for the + # entire batch, which is slightly wasteful (TODO). We only take the last outflows here, and + # later take the log outflows of s0 to estimate logZ. + all_log_outflows = cat.logsumexp() + log_outflows = all_log_outflows[total_num_parents:] + + # The loss of intermediary states is inflow - outflow. We use the log-epsilon variant (see FM paper) + intermediate_loss = (torch.logaddexp(log_inflows, eps) - torch.logaddexp(log_outflows, eps)).pow(2) + # To compute the loss of the terminal states we match F(s, a'), where a' is the action that + # terminated the trajectory, to R(s). We again use the mechanism of log_prob + log_F_s_stop = cat.log_prob(batch.terminal_acts, cat.logits, final_graph_idx) + terminal_loss = (torch.logaddexp(log_F_s_stop, eps) - torch.logaddexp(batch.log_rewards, eps)).pow(2) + + if self.fm_balanced_loss: + loss = intermediate_loss.mean() + terminal_loss.mean() * self.fm_leaf_coef + else: + loss = (intermediate_loss.sum() + terminal_loss.sum()) / ( + intermediate_loss.shape[0] + terminal_loss.shape[0] + ) + + # logZ is simply the outflow of s0, the first graph of each parent set. + logZ = all_log_outflows[first_graph_idx] + info = { + "intermediate_loss": intermediate_loss.mean().item(), + "terminal_loss": terminal_loss.mean().item(), + "loss": loss.item(), + "logZ": logZ.mean().item(), + } + return loss, info diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index 7d014124..aa1feef8 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -1,11 +1,10 @@ -from typing import Any, Dict - import numpy as np import torch import torch_geometric.data as gd from torch_scatter import scatter from gflownet.algo.trajectory_balance import TrajectoryBalance, TrajectoryBalanceModel +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext @@ -19,11 +18,9 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): - super().__init__(env, ctx, rng, hps, max_len, max_nodes) + super().__init__(env, ctx, rng, cfg) def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0): """Compute multi objective REINFORCE loss over trajectories contained in the batch""" diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index 0d49a14b..1e3f1146 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - import numpy as np import torch import torch.nn as nn @@ -8,6 +6,7 @@ from torch_scatter import scatter from gflownet.algo.graph_sampling import GraphSampler +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory @@ -17,16 +16,14 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """Soft Q-Learning implementation, see - xxxxx + Haarnoja, Tuomas, Haoran Tang, Pieter Abbeel, and Sergey Levine. "Reinforcement learning with deep + energy-based policies." In International conference on machine learning, pp. 1352-1361. PMLR, 2017. Hyperparameters used: illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions - sql_alpha: float, the entropy coefficient Parameters ---------- @@ -36,27 +33,23 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + The experiment configuration """ self.ctx = ctx self.env = env self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.alpha = hps.get("sql_alpha", 0.01) - self.gamma = hps.get("sql_gamma", 1) - self.invalid_penalty = hps.get("sql_penalty", -10) + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.alpha = cfg.algo.sql.alpha + self.gamma = cfg.algo.sql.gamma + self.invalid_penalty = cfg.algo.sql.penalty self.bootstrap_own_reward = False # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index cab29cb3..22fe655e 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Tuple import networkx as nx import numpy as np @@ -9,6 +9,7 @@ from torch_scatter import scatter, scatter_sum from gflownet.algo.graph_sampling import GraphSampler +from gflownet.config import Config from gflownet.envs.graph_building_env import ( Graph, GraphAction, @@ -18,6 +19,7 @@ GraphBuildingEnvContext, generate_forward_trajectory, ) +from gflownet.trainer import GFNAlgorithm class TrajectoryBalanceModel(nn.Module): @@ -28,29 +30,24 @@ def logZ(self, cond_info: Tensor) -> Tensor: raise NotImplementedError() -class TrajectoryBalance: - """ """ +class TrajectoryBalance(GFNAlgorithm): + """TB implementation, see + "Trajectory Balance: Improved Credit Assignment in GFlowNets Nikolay Malkin, Moksh Jain, + Emmanuel Bengio, Chen Sun, Yoshua Bengio" + https://arxiv.org/abs/2201.13259""" def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """TB implementation, see "Trajectory Balance: Improved Credit Assignment in GFlowNets Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio" https://arxiv.org/abs/2201.13259 - Hyperparameters used: - illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions - bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data - tb_epsilon: float, if not None, adds this epsilon in the numerator and denominator of the log-ratio - reward_loss_multiplier: float, multiplying constant for the bootstrap loss. - Parameters ---------- env: GraphBuildingEnv @@ -59,22 +56,16 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + Hyperparameters """ self.ctx = ctx self.env = env self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.bootstrap_own_reward = hps["bootstrap_own_reward"] - self.epsilon = hps["tb_epsilon"] - self.reward_loss_multiplier = hps.get("reward_loss_multiplier", 1) + self.global_cfg = cfg + self.cfg = cfg.algo.tb + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes # Experimental flags self.reward_loss_is_mae = True self.tb_loss_is_mae = False @@ -83,22 +74,20 @@ def __init__( self.length_normalize_losses = False self.reward_normalize_losses = False self.sample_temp = 1 - self.is_doing_subTB = hps.get("tb_do_subtb", False) - self.correct_idempotent = hps.get("tb_correct_idempotent", False) - self.p_b_is_parameterized = hps.get("tb_p_b_is_parameterized", False) + self.bootstrap_own_reward = self.cfg.bootstrap_own_reward self.graph_sampler = GraphSampler( ctx, env, - max_len, - max_nodes, + cfg.algo.max_len, + cfg.algo.max_nodes, rng, self.sample_temp, - correct_idempotent=self.correct_idempotent, - pad_with_terminal_state=self.p_b_is_parameterized, + correct_idempotent=self.cfg.do_correct_idempotent, + pad_with_terminal_state=self.cfg.do_parameterize_p_b, ) - if self.is_doing_subTB: - self._subtb_max_len = hps.get("tb_subtb_max_len", max_len + 2 if max_len is not None else 128) + if self.cfg.do_subtb: + self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? def create_training_data_from_own_samples( @@ -149,9 +138,17 @@ def create_training_data_from_graphs(self, graphs): trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}] A list of trajectories. """ - return [{"traj": generate_forward_trajectory(i)} for i in graphs] - - def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction): + trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs] + for traj in trajs: + n_back = [ + self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) + for gp, _ in traj["traj"][1:] + ] + [1] + traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) + traj["result"] = traj["traj"][-1][0] + return trajs + + def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True): """Returns the list of idempotent actions for a given transition. Note, this is slow! Correcting for idempotency is needed to estimate p(x) correctly, but @@ -168,22 +165,24 @@ def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: Graph The next state's graph action: GraphAction Action leading from g to gp + return_aidx: bool + If true returns of list of action indices, else a list of GraphAction Returns ------- - actions: List[Tuple[int,int,int]] + actions: Union[List[Tuple[int,int,int]], List[GraphAction]] The list of idempotent actions that all lead from g to gp. """ iaction = self.ctx.GraphAction_to_aidx(gd, action) if action.action == GraphActionType.Stop: - return [iaction] + return [iaction if return_aidx else action] # Here we're looking for potential idempotent actions by looking at legal actions of the # same type. This assumes that this is the only way to get to a similar parent. Perhaps # there are edges cases where this is not true...? lmask = getattr(gd, action.action.mask_name) nz = lmask.nonzero() # Legal actions are those with a nonzero mask value - actions = [iaction] + actions = [iaction if return_aidx else action] for i in nz: aidx = (iaction[0], i[0].item(), i[1].item()) if aidx == iaction: @@ -191,7 +190,7 @@ def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: Graph ga = self.ctx.aidx_to_GraphAction(gd, aidx, fwd=not action.action.is_backward) child = self.env.step(g, ga) if nx.algorithms.is_isomorphic(child, gp, lambda a, b: a == b, lambda a, b: a == b): - actions.append(aidx) + actions.append(aidx if return_aidx else ga) return actions def construct_batch(self, trajs, cond_info, log_rewards): @@ -218,7 +217,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) batch.log_p_B = torch.cat([i["bck_logprobs"] for i in trajs], 0) batch.actions = torch.tensor(actions) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: batch.bck_actions = torch.tensor( [ self.ctx.GraphAction_to_aidx(g, a) @@ -229,7 +228,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.log_rewards = log_rewards batch.cond_info = cond_info batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() - if self.correct_idempotent: + if self.cfg.do_correct_idempotent: # Every timestep is a (graph_a, action, graph_b) triple agraphs = [i[0] for tj in trajs for i in tj["traj"]] # Here we start at the 1th timestep and append the result @@ -241,7 +240,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): ] batch.ip_actions = torch.tensor(sum(ipa, [])) batch.ip_lens = torch.tensor([len(i) for i in ipa]) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # Here we start at the 0th timestep and prepend None (it will be unused) bgraphs = sum([[None] + [i[0] for i in tj["traj"][:-1]] for tj in trajs], []) gactions = [i for tj in trajs for i in tj["bck_a"]] @@ -254,7 +253,9 @@ def construct_batch(self, trajs, cond_info, log_rewards): return batch - def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0): + def compute_batch_losses( + self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0 # type: ignore[override] + ): """Compute the losses over trajectories contained in the batch Parameters @@ -272,7 +273,9 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n log_rewards = batch.log_rewards # Clip rewards assert log_rewards.ndim == 1 - clip_log_R = torch.maximum(log_rewards, torch.tensor(self.illegal_action_logreward, device=dev)).float() + clip_log_R = torch.maximum( + log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev) + ).float() cond_info = batch.cond_info invalid_mask = 1 - batch.is_valid @@ -285,7 +288,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx]) else: fwd_cat, per_graph_out = model(batch, cond_info[batch_idx]) @@ -296,7 +299,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # Compute trajectory balance objective log_Z = model.logZ(cond_info)[:, 0] # Compute the log prob of each action in the trajectory - if self.correct_idempotent: + if self.cfg.do_correct_idempotent: # If we want to correct for idempotent actions, we need to sum probabilities # i.e. to compute P(s' | s) = sum_{a that lead to s'} P(a|s) # here we compute the indices of the graph that each action corresponds to, ip_lens @@ -312,7 +315,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # scatter(small number) = 0 on CUDA log_p_F = p.clamp(1e-30).log() - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # Now we repeat this but for the backward policy bck_ip_batch_idces = torch.arange(batch.bck_ip_lens.shape[0], device=dev).repeat_interleave( batch.bck_ip_lens @@ -325,14 +328,14 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n else: # Else just naively take the logprob of the actions we took log_p_F = fwd_cat.log_prob(batch.actions) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: log_p_B = bck_cat.log_prob(batch.bck_actions) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # If we're modeling P_B then trajectories are padded with a virtual terminal state sF, # zero-out the logP_F of those states log_p_F[final_graph_idx] = 0 - if self.is_doing_subTB: + if self.cfg.do_subtb: # Force the pad states' F(s) prediction to be R per_graph_out[final_graph_idx, 0] = clip_log_R @@ -355,12 +358,12 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n traj_log_p_F = scatter(log_p_F, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") traj_log_p_B = scatter(log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") - if self.is_doing_subTB: + if self.cfg.do_subtb: # SubTB interprets the per_graph_out predictions to predict the state flow F(s) traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens) # The position of the first graph of each trajectory first_graph_idx = torch.zeros_like(batch.traj_lens) - first_graph_idx = torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) log_Z = per_graph_out[first_graph_idx, 0] else: # Compute log numerator and denominator of the TB objective @@ -374,9 +377,9 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # (thus the `numerator - 1`). Why 1? Intuition? denominator = denominator * (1 - invalid_mask) + invalid_mask * (numerator.detach() - 1) - if self.epsilon is not None: + if self.cfg.epsilon is not None: # Numerical stability epsilon - epsilon = torch.tensor([self.epsilon], device=dev).float() + epsilon = torch.tensor([self.cfg.epsilon], device=dev).float() numerator = torch.logaddexp(numerator, epsilon) denominator = torch.logaddexp(denominator, epsilon) if self.tb_loss_is_mae: @@ -399,17 +402,17 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # undercount (by 2N) the contribution of each loss traj_losses = factor * traj_losses * num_trajs - if self.bootstrap_own_reward: + if self.cfg.bootstrap_own_reward: num_bootstrap = num_bootstrap or len(log_rewards) if self.reward_loss_is_mae: reward_losses = abs(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap]) else: reward_losses = (log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap]).pow(2) - reward_loss = reward_losses.mean() + reward_loss = reward_losses.mean() * self.cfg.reward_loss_multiplier else: reward_loss = 0 - loss = traj_losses.mean() + reward_loss * self.reward_loss_multiplier + loss = traj_losses.mean() + reward_loss info = { "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, @@ -499,7 +502,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths): for ep in range(traj_lengths.shape[0]): offset = cumul_lens[ep] T = int(traj_lengths[ep]) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # The length of the trajectory is the padded length, reduce by 1 T -= 1 idces, dests = self._precomp[T - 1] diff --git a/src/gflownet/config.py b/src/gflownet/config.py new file mode 100644 index 00000000..be4fa879 --- /dev/null +++ b/src/gflownet/config.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from typing import Optional + +from omegaconf import MISSING + +from gflownet.algo.config import AlgoConfig +from gflownet.data.config import ReplayConfig +from gflownet.models.config import ModelConfig +from gflownet.tasks.config import TasksConfig +from gflownet.utils.config import ConditionalsConfig + + +@dataclass +class OptimizerConfig: + """Generic configuration for optimizers + + Attributes + ---------- + opt : str + The optimizer to use (either "adam" or "sgd") + learning_rate : float + The learning rate + lr_decay : float + The learning rate decay (in steps, f = 2 ** (-steps / self.cfg.opt.lr_decay)) + weight_decay : float + The L2 weight decay + momentum : float + The momentum parameter value + clip_grad_type : str + The type of gradient clipping to use (either "norm" or "value") + clip_grad_param : float + The parameter for gradient clipping + adam_eps : float + The epsilon parameter for Adam + """ + + opt: str = "adam" + learning_rate: float = 1e-4 + lr_decay: float = 20_000 + weight_decay: float = 1e-8 + momentum: float = 0.9 + clip_grad_type: str = "norm" + clip_grad_param: float = 10.0 + adam_eps: float = 1e-8 + + +@dataclass +class Config: + """Base configuration for training + + Attributes + ---------- + log_dir : str + The directory where to store logs, checkpoints, and samples. + device : str + The device to use for training (either "cpu" or "cuda[:]") + seed : int + The random seed + validate_every : int + The number of training steps after which to validate the model + checkpoint_every : Optional[int] + The number of training steps after which to checkpoint the model + print_every : int + The number of training steps after which to print the training loss + start_at_step : int + The training step to start at (default: 0) + num_final_gen_steps : Optional[int] + After training, the number of steps to generate graphs for + num_training_steps : int + The number of training steps + num_workers : int + The number of workers to use for creating minibatches (0 = no multiprocessing) + hostname : Optional[str] + The hostname of the machine on which the experiment is run + pickle_mp_messages : bool + Whether to pickle messages sent between processes (only relevant if num_workers > 0) + git_hash : Optional[str] + The git hash of the current commit + overwrite_existing_exp : bool + Whether to overwrite the contents of the log_dir if it already exists + """ + + log_dir: str = MISSING + device: str = "cuda" + seed: int = 0 + validate_every: int = 1000 + checkpoint_every: Optional[int] = None + print_every: int = 100 + start_at_step: int = 0 + num_final_gen_steps: Optional[int] = None + num_training_steps: int = 10_000 + num_workers: int = 0 + hostname: Optional[str] = None + pickle_mp_messages: bool = False + git_hash: Optional[str] = None + overwrite_existing_exp: bool = True + algo: AlgoConfig = AlgoConfig() + model: ModelConfig = ModelConfig() + opt: OptimizerConfig = OptimizerConfig() + replay: ReplayConfig = ReplayConfig() + task: TasksConfig = TasksConfig() + cond: ConditionalsConfig = ConditionalsConfig() diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py new file mode 100644 index 00000000..fab5d036 --- /dev/null +++ b/src/gflownet/data/config.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ReplayConfig: + """Replay buffer configuration + + Attributes + ---------- + use : bool + Whether to use a replay buffer + capacity : int + The capacity of the replay buffer + warmup : int + The number of samples to collect before starting to sample from the replay buffer + hindsight_ratio : float + The ratio of hindsight samples within a batch + """ + + use: bool = False + capacity: Optional[int] = None + warmup: Optional[int] = None + hindsight_ratio: float = 0 diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index 28dc7f78..b26c29d2 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import rdkit.Chem as Chem +import torch from torch.utils.data import Dataset @@ -39,4 +40,23 @@ def __len__(self): return len(self.idcs) def __getitem__(self, idx): - return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]]) + return ( + Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), + torch.tensor([self.df[self.target][self.idcs[idx]]]).float(), + ) + + +def convert_h5(): + # File obtained from + # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 + # (from http://quantum-machine.org/datasets/) + f = tarfile.TarFile("qm9.xyz.tar", "r") + labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] + all_mols = [] + for pt in f: + pt = f.extractfile(pt) # type: ignore + data = pt.read().decode().splitlines() # type: ignore + all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) + df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) + store = pd.HDFStore("qm9.h5", "w") + store["df"] = df diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py new file mode 100644 index 00000000..541656fd --- /dev/null +++ b/src/gflownet/data/replay_buffer.py @@ -0,0 +1,42 @@ +from typing import List + +import numpy as np +import torch + +from gflownet.config import Config + + +class ReplayBuffer(object): + def __init__(self, cfg: Config, rng: np.random.Generator = None): + self.capacity = cfg.replay.capacity + self.warmup = cfg.replay.warmup + assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" + + self.buffer: List[tuple] = [] + self.position = 0 + self.rng = rng + + def push(self, *args): + if len(self.buffer) == 0: + self._input_size = len(args) + else: + assert self._input_size == len(args), "ReplayBuffer input size must be constant" + if len(self.buffer) < self.capacity: + self.buffer.append(None) + self.buffer[self.position] = args + self.position = (self.position + 1) % self.capacity + + def sample(self, batch_size): + idxs = self.rng.choice(len(self.buffer), batch_size) + out = list(zip(*[self.buffer[idx] for idx in idxs])) + for i in range(len(out)): + # stack if all elements are numpy arrays or torch tensors + # (this is much more efficient to send arrays through multiprocessing queues) + if all([isinstance(x, np.ndarray) for x in out[i]]): + out[i] = np.stack(out[i], axis=0) + elif all([isinstance(x, torch.Tensor) for x in out[i]]): + out[i] = torch.stack(out[i], dim=0) + return tuple(out) + + def __len__(self): + return len(self.buffer) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index ff95691e..90b8b4db 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -1,6 +1,7 @@ import os import sqlite3 from collections.abc import Iterable +from copy import deepcopy from typing import Callable, List import networkx as nx @@ -10,6 +11,9 @@ from rdkit import Chem, RDLogger from torch.utils.data import Dataset, IterableDataset +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.envs.graph_building_env import GraphActionCategorical + class SamplingIterator(IterableDataset): """This class allows us to parallelise and train faster. @@ -25,16 +29,20 @@ def __init__( self, dataset: Dataset, model: nn.Module, - batch_size: int, ctx, algo, task, device, - ratio=0.5, - stream=True, + batch_size: int = 1, + illegal_action_logreward: float = -50, + ratio: float = 0.5, + stream: bool = True, + replay_buffer: ReplayBuffer = None, log_dir: str = None, - sample_cond_info=True, - random_action_prob=0.0, + sample_cond_info: bool = True, + random_action_prob: float = 0.0, + hindsight_ratio: float = 0.0, + init_train_iter: int = 0, ): """Parameters ---------- @@ -43,12 +51,21 @@ def __init__( model: nn.Module The model we sample from (must be on CUDA already or share_memory() must be called so that parameters are synchronized between each worker) + ctx: + The context for the environment, e.g. a MolBuildingEnvContext instance + algo: + The training algorithm, e.g. a TrajectoryBalance instance + task: GFNTask + A Task instance, e.g. a MakeRingsTask instance + device: torch.device + The device the model is on + replay_buffer: ReplayBuffer + The replay buffer for training on past data batch_size: int The number of trajectories, each trajectory will be comprised of many graphs, so this is _not_ the batch size in terms of the number of graphs (that will depend on the task) - algo: - The training algorithm, e.g. a TrajectoryBalance instance - task: ConditionalTask + illegal_action_logreward: float + The logreward for invalid trajectories ratio: float The ratio of offline trajectories in the batch. stream: bool @@ -59,13 +76,18 @@ def __init__( sample_cond_info: bool If True (default), then the dataset is a dataset of points used in offline training. If False, then the dataset is a dataset of preferences (e.g. used to validate the model) - + random_action_prob: float + The probability of taking a random action, passed to the graph sampler + init_train_iter: int + The initial training iteration, incremented and passed to task.sample_conditional_information """ self.data = dataset self.model = model + self.replay_buffer = replay_buffer self.batch_size = batch_size - self.offline_batch_size = int(np.ceil(batch_size * ratio)) - self.online_batch_size = int(np.floor(batch_size * (1 - ratio))) + self.illegal_action_logreward = illegal_action_logreward + self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) + self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) self.ratio = ratio self.ctx = ctx self.algo = algo @@ -75,17 +97,22 @@ def __init__( self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely self.sample_cond_info = sample_cond_info self.random_action_prob = random_action_prob + self.hindsight_ratio = hindsight_ratio + self.train_it = init_train_iter + self.do_validate_batch = False # Turn this on for debugging self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag + + # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) + # then "offline" now refers to cond info and online to x, so no duplication and we don't end + # up with 2*batch_size accidentally if not sample_cond_info: - # Slightly weird semantics, but if we're sampling x given some fixed (data) cond info - # then "offline" refers to cond info and online to x, so no duplication and we don't end - # up with 2*batch_size accidentally - self.offline_batch_size = self.online_batch_size = batch_size - self.log_dir = log_dir + self.offline_batch_size = self.online_batch_size = self.batch_size + # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we # don't want to initialize per-worker things just yet, such as where the log the worker writes # to. This must be done in __iter__, which is called by the DataLoader once this instance # has been copied into a new python process. + self.log_dir = log_dir self.log = SQLiteLog() self.log_hooks: List[Callable] = [] @@ -105,9 +132,12 @@ def _idx_iterator(self): if n == 0: yield np.arange(0, 0) return - if worker_info is None: + assert ( + self.offline_batch_size > 0 + ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" + if worker_info is None: # no multi-processing start, end, wid = 0, n, -1 - else: + else: # split the data into chunks (per-worker) nw = worker_info.num_workers wid = worker_info.id start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) @@ -143,7 +173,11 @@ def __iter__(self): # Sample conditional info such as temperature, trade-off weights, etc. if self.sample_cond_info: - cond_info = self.task.sample_conditional_information(num_offline + self.online_batch_size) + num_online = self.online_batch_size + cond_info = self.task.sample_conditional_information( + num_offline + self.online_batch_size, self.train_it + ) + # Sample some dataset data mols, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) flat_rewards = ( @@ -151,15 +185,17 @@ def __iter__(self): ) graphs = [self.ctx.mol_to_graph(m) for m in mols] trajs = self.algo.create_training_data_from_graphs(graphs) - num_online = self.online_batch_size + else: # If we're not sampling the conditionals, then the idcs refer to listed preferences num_online = num_offline num_offline = 0 - cond_info = self.task.encode_conditional_information(torch.stack([self.data[i] for i in idcs])) + cond_info = self.task.encode_conditional_information( + steer_info=torch.stack([self.data[i] for i in idcs]) + ) trajs, flat_rewards = [], [] - is_valid = torch.ones(num_offline + num_online).bool() # Sample some on-policy data + is_valid = torch.ones(num_offline + num_online).bool() if num_online > 0: with torch.no_grad(): trajs += self.algo.create_training_data_from_own_samples( @@ -181,17 +217,15 @@ def __iter__(self): # fetch the valid trajectories endpoints mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] # ask the task to compute their reward - preds, m_is_valid = self.task.compute_flat_rewards(mols) - assert preds.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) + assert ( + online_flat_rew.ndim == 2 + ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" # The task may decide some of the mols are invalid, we have to again filter those valid_idcs = valid_idcs[m_is_valid] valid_mols = [m for m, v in zip(mols, m_is_valid) if v] - pred_reward = torch.zeros((num_online, preds.shape[1])) - pred_reward[valid_idcs - num_offline] = preds - # TODO: reintegrate bootstrapped reward predictions - # if preds.shape[0] > 0: - # for i in range(self.number_of_objectives): - # pred_reward[valid_idcs - num_offline, i] = preds[range(preds.shape[0]), i] + pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) + pred_reward[valid_idcs - num_offline] = online_flat_rew is_valid[num_offline:] = False is_valid[valid_idcs] = True flat_rewards += list(pred_reward) @@ -201,51 +235,125 @@ def __iter__(self): if self.log_molecule_smis: for i, m in zip(valid_idcs, valid_mols): trajs[i]["smi"] = Chem.MolToSmiles(m) - flat_rewards = torch.stack(flat_rewards) + # Compute scalar rewards from conditional information & flat rewards + flat_rewards = torch.stack(flat_rewards) log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) - log_rewards[torch.logical_not(is_valid)] = self.algo.illegal_action_logreward - # Construct batch - batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) - batch.num_offline = num_offline - batch.num_online = num_online - batch.flat_rewards = flat_rewards - batch.mols = mols - batch.preferences = cond_info.get("preferences", None) - # TODO: we could very well just pass the cond_info dict to construct_batch above, - # and the algo can decide what it wants to put in the batch object + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + # Computes some metrics + extra_info = {} if not self.sample_cond_info: # If we're using a dataset of preferences, the user may want to know the id of the preference for i, j in zip(trajs, idcs): i["data_idx"] = j - - # Converts back into natural rewards for logging purposes - # (allows to take averages and plot in objective space) - # TODO: implement that per-task (in case they don't apply the same beta and log transformations) + # note: we convert back into natural rewards for logging purposes + # (allows to take averages and plot in objective space) + # TODO: implement that per-task (in case they don't apply the same beta and log transformations) rewards = torch.exp(log_rewards / cond_info["beta"]) - if num_online > 0 and self.log_dir is not None: self.log_generated( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) if num_online > 0: - extra_info = {} for hook in self.log_hooks: extra_info.update( hook( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) ) - batch.extra_info = extra_info + + if self.replay_buffer is not None: + # If we have a replay buffer, we push the online trajectories in it + # and resample immediately such that the "online" data in the batch + # comes from a more stable distribution (try to avoid forgetting) + + # cond_info is a dict, so we need to convert it to a list of dicts + cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] + + # push the online trajectories in the replay buffer and sample a new 'online' batch + for i in range(num_offline, len(trajs)): + self.replay_buffer.push( + deepcopy(trajs[i]), + deepcopy(log_rewards[i]), + deepcopy(flat_rewards[i]), + deepcopy(cond_info[i]), + deepcopy(is_valid[i]), + ) + replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( + num_online + ) + + # append the online trajectories to the offline ones + trajs[num_offline:] = replay_trajs + log_rewards[num_offline:] = replay_logr + flat_rewards[num_offline:] = replay_fr + cond_info[num_offline:] = replay_condinfo + is_valid[num_offline:] = replay_valid + + # convert cond_info back to a dict + cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} + + if self.hindsight_ratio > 0.0: + # Relabels some of the online trajectories with hindsight + assert hasattr( + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + # samples indexes of trajectories without repeats + hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline + cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( + cond_info, log_rewards, flat_rewards, hindsight_idxs + ) + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + # Construct batch + batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) + batch.num_offline = num_offline + batch.num_online = num_online + batch.flat_rewards = flat_rewards + batch.preferences = cond_info.get("preferences", None) + batch.focus_dir = cond_info.get("focus_dir", None) + batch.extra_info = extra_info + # TODO: we could very well just pass the cond_info dict to construct_batch above, + # and the algo can decide what it wants to put in the batch object + + # Only activate for debugging your environment or dataset (e.g. the dataset could be + # generating trajectories with illegal actions) + if self.do_validate_batch: + self.validate_batch(batch, trajs) + + self.train_it += worker_info.num_workers if worker_info is not None else 1 yield batch + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [self.model._action_type_to_mask(t, batch) for t in atypes], + [self.model._action_type_to_key[t] for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) + def log_generated(self, trajs, rewards, flat_rewards, cond_info): if self.log_molecule_smis: mols = [ @@ -258,18 +366,26 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info): flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() rewards = rewards.data.numpy().tolist() preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences"]] + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] data = [ - [mols[i], rewards[i]] + flat_rewards[i] + preferences[i] + [cond_info[k][i].item() for k in logged_keys] + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] for i in range(len(trajs)) ] + data_labels = ( ["smi", "r"] + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + [f"ci_{k}" for k in logged_keys] ) + self.log.insert_many(data, data_labels) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index cb20fcd6..82818f83 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -65,6 +65,8 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu # for u and the second half for v. Each logit i in the first half for a given edge # corresponds to setting the stem atom of fragment u used to attach between u and v to be i # (named f'{u}_attach') and vice versa for the second half and v, u. + # Note to self: this choice results in a special case in generate_forward_trajectory for these + # edge attributes. See PR#83 for details. self.num_edge_attr_logits = most_stems * 2 # There are thus up to 2 edge attributes, the stem of u and the stem of v. self.num_edge_attrs = 2 diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 4ac72415..fa7b284b 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -271,9 +271,9 @@ def add_parent(a, new_g): GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]), new_g, ) - if len(g.nodes) == 1: + if len(g.nodes) == 1 and len(g.nodes[i]) == 1: # The final node is degree 0, need this special case to remove it - # and end up with S0, the empty graph root + # and end up with S0, the empty graph root (but only if it has no attrs except 'v') add_parent( GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]), graph_without_node(g, i), @@ -348,14 +348,26 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G if len(i) > 1: # i is an edge e = relabeling_map.get(i[0], None), relabeling_map.get(i[1], None) if e in gn.edges: - # i exists in the new graph, that means some of its attributes need to be added - attrs = [j for j in g.edges[i] if j not in gn.edges[e]] + # i exists in the new graph, that means some of its attributes need to be added. + # + # This remap is a special case for the fragment environment, due to the (poor) design + # choice of treating directed edges as undirected edges. Until we have routines for + # directed graphs, this may need to stay. + def possibly_remap(attr): + if attr == f"{i[0]}_attach": + return f"{e[0]}_attach" + elif attr == f"{i[1]}_attach": + return f"{e[1]}_attach" + return attr + + attrs = [j for j in g.edges[i] if possibly_remap(j) not in gn.edges[e]] if len(attrs) == 0: continue # If nodes are in cycles edges leading to them get stack multiple times, disregard - attr = attrs[np.random.randint(len(attrs))] - gn.edges[e][attr] = g.edges[i][attr] + iattr = attrs[np.random.randint(len(attrs))] + eattr = possibly_remap(iattr) + gn.edges[e][eattr] = g.edges[i][iattr] act = GraphAction( - GraphActionType.SetEdgeAttr, source=e[0], target=e[1], attr=attr, value=g.edges[i][attr] + GraphActionType.SetEdgeAttr, source=e[0], target=e[1], attr=eattr, value=g.edges[i][iattr] ) else: # i doesn't exist, add the edge @@ -499,9 +511,11 @@ def __init__( self.logprobs = None if deduplicate_edge_index and "edge_index" in keys: - idx = keys.index("edge_index") - self.batch[idx] = self.batch[idx][::2] - self.slice[idx] = self.slice[idx].div(2, rounding_mode="floor") + for idx, k in enumerate(keys): + if k != "edge_index": + continue + self.batch[idx] = self.batch[idx][::2] + self.slice[idx] = self.slice[idx].div(2, rounding_mode="floor") def detach(self): new = copy.copy(self) diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 44c99e4a..13535bbf 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -8,7 +8,13 @@ from rdkit.Chem import Mol from rdkit.Chem.rdchem import BondType, ChiralType -from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionType, + GraphBuildingEnvContext, + graph_without_edge, +) from gflownet.utils.graphs import random_walk_probs DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW] @@ -28,6 +34,7 @@ def __init__( charges=[0, 1, -1], expl_H_range=[0, 1], allow_explicitly_aromatic=False, + allow_5_valence_nitrogen=False, num_rw_feat=8, max_nodes=None, max_edges=None, @@ -76,19 +83,22 @@ def __init__( # The size of the input vector for each atom self.atom_attr_size = sum(len(i) for i in self.atom_attr_values.values()) self.atom_attrs = sorted(self.atom_attr_values.keys()) + # 'v' is set separately when creating the node, so there's no point in having a SetNodeAttr logit for it + self.settable_atom_attrs = [i for i in self.atom_attrs if i != "v"] # The beginning position within the input vector of each attribute self.atom_attr_slice = [0] + list(np.cumsum([len(self.atom_attr_values[i]) for i in self.atom_attrs])) # The beginning position within the logit vector of each attribute - num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.atom_attrs] + num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.settable_atom_attrs] self.atom_attr_logit_slice = { k: (s, e) - for k, s, e in zip(self.atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits)) + for k, s, e in zip( + self.settable_atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits) + ) } # The attribute and value each logit dimension maps back to self.atom_attr_logit_map = [ (k, v) - for k in self.atom_attrs - if k != "v" + for k in self.settable_atom_attrs # index 0 is skipped because it is the default value for v in self.atom_attr_values[k][1:] ] @@ -118,9 +128,11 @@ def __init__( BondType.AROMATIC: 1.5, } pt = Chem.GetPeriodicTable() + self.allow_5_valence_nitrogen = allow_5_valence_nitrogen self._max_atom_valence = { **{a: max(pt.GetValenceList(a)) for a in atoms}, - "N": 3, # We'll handle nitrogen valence later explicitly in graph_to_Data + # We'll handle nitrogen valence later explicitly in graph_to_Data + "N": 3 if not allow_5_valence_nitrogen else 5, "*": 0, # wildcard atoms have 0 valence until filled in } @@ -144,12 +156,21 @@ def __init__( GraphActionType.AddEdge, GraphActionType.SetEdgeAttr, ] + self.bck_action_type_order = [ + GraphActionType.RemoveNode, + GraphActionType.RemoveNodeAttr, + GraphActionType.RemoveEdge, + GraphActionType.RemoveEdgeAttr, + ] self.device = torch.device("cpu") def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" act_type, act_row, act_col = [int(i) for i in action_idx] - t = self.action_type_order[act_type] + if fwd: + t = self.action_type_order[act_type] + else: + t = self.bck_action_type_order[act_type] if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: @@ -161,12 +182,34 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: a, b = g.non_edge_index[:, act_row] return GraphAction(t, source=a.item(), target=b.item()) elif t is GraphActionType.SetEdgeAttr: - a, b = g.edge_index[:, act_row * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits + # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. + # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one + # to another we can safely divide or multiply by two. + a, b = g.edge_index[:, act_row * 2] attr, val = self.bond_attr_logit_map[act_col] return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val) + elif t is GraphActionType.RemoveNode: + return GraphAction(t, source=act_row) + elif t is GraphActionType.RemoveNodeAttr: + attr = self.settable_atom_attrs[act_col] + return GraphAction(t, source=act_row, attr=attr) + elif t is GraphActionType.RemoveEdge: + a, b = g.edge_index[:, act_row * 2] # see note above about edge_index + return GraphAction(t, source=a.item(), target=b.item()) + elif t is GraphActionType.RemoveEdgeAttr: + a, b = g.edge_index[:, act_row * 2] # see note above about edge_index + attr = self.bond_attrs[act_col] + return GraphAction(t, source=a.item(), target=b.item(), attr=attr) def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: """Translate a GraphAction to an index tuple""" + for u in [self.action_type_order, self.bck_action_type_order]: + if action.action in u: + type_idx = u.index(action.action) + break + else: + raise ValueError(f"Unknown action type {action.action}") + if action.action is GraphActionType.Stop: row = col = 0 elif action.action is GraphActionType.AddNode: @@ -188,17 +231,33 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int ).argmax() col = 0 elif action.action is GraphActionType.SetEdgeAttr: - # Here the edges are duplicated, both (i,j) and (j,i) are in edge_index - # so no need for a double check. - # row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) + - # (g.edge_index.T == torch.tensor([(action.target, action.source)])).prod(1)).argmax() + # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. + # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one + # to another we can safely divide or multiply by two. row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax() - # Because edges are duplicated but logits aren't, divide by two row = row.div(2, rounding_mode="floor") # type: ignore col = ( self.bond_attr_values[action.attr].index(action.value) - 1 + self.bond_attr_logit_slice[action.attr][0] ) - type_idx = self.action_type_order.index(action.action) + elif action.action is GraphActionType.RemoveNode: + row = action.source + col = 0 + elif action.action is GraphActionType.RemoveNodeAttr: + row = action.source + col = self.settable_atom_attrs.index(action.attr) + elif action.action is GraphActionType.RemoveEdge: + row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax() + # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. + # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one + # to another we can safely divide or multiply by two. + row = int(row) // 2 + col = 0 + elif action.action is GraphActionType.RemoveEdgeAttr: + row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax() + row = row.div(2, rounding_mode="floor") # type: ignore + col = self.bond_attrs.index(action.attr) + else: + raise ValueError(f"Unknown action type {action.action}") return (type_idx, int(row), int(col)) def graph_to_Data(self, g: Graph) -> gd.Data: @@ -208,6 +267,9 @@ def graph_to_Data(self, g: Graph) -> gd.Data: add_node_mask = torch.ones((x.shape[0], self.num_new_node_values)) if self.max_nodes is not None and len(g.nodes) >= self.max_nodes: add_node_mask *= 0 + remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0) + remove_node_attr_mask = torch.zeros((x.shape[0], len(self.settable_atom_attrs))) + explicit_valence = {} max_valence = {} set_node_attr_mask = torch.ones((x.shape[0], self.num_node_attr_logits)) @@ -215,25 +277,41 @@ def graph_to_Data(self, g: Graph) -> gd.Data: set_node_attr_mask *= 0 for i, n in enumerate(g.nodes): ad = g.nodes[n] + if g.degree(n) <= 1 and len(ad) == 1 and all([len(g[n][neigh]) == 0 for neigh in g.neighbors(n)]): + # If there's only the 'v' key left and the node is a leaf, and the edge that connect to the node have + # no attributes set, we can remove it + remove_node_mask[i] = 1 for k, sl in zip(self.atom_attrs, self.atom_attr_slice): + # idx > 0 means that the attribute is not the default value idx = self.atom_attr_values[k].index(ad[k]) if k in ad else 0 x[i, sl + idx] = 1 - # If the attribute is already there, mask out logits - # (or if the attribute is a negative attribute and has been filled) + if k == "v": + continue + # If the attribute + # - is already there (idx > 0), + # - or the attribute is a negative attribute and has been filled + # - or the attribute is a negative attribute and is not fillable (i.e. not a key of ad) + # then mask forward logits. + # For backward logits, positively mask if the attribute is there (idx > 0). if k in self.negative_attrs: if k in ad and idx > 0 or k not in ad: s, e = self.atom_attr_logit_slice[k] set_node_attr_mask[i, s:e] = 0 + # We don't want to make the attribute removable if it's not fillable (i.e. not a key of ad) + if k in ad: + remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1 elif k in ad: s, e = self.atom_attr_logit_slice[k] set_node_attr_mask[i, s:e] = 0 + remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1 # Account for charge and explicit Hs in atom as limiting the total valence max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]] # Special rule for Nitrogen if ad["v"] == "N" and ad.get("charge", 0) == 1: - # This is definitely a heuristic, but to keep things simple we'll limit Nitrogen's valence to 3 (as + # This is definitely a heuristic, but to keep things simple we'll limit* Nitrogen's valence to 3 (as # per self._max_atom_valence) unless it is charged, then we make it 5. # This keeps RDKit happy (and is probably a good idea anyway). + # (* unless allow_5_valence_nitrogen is True, then it's just always 5) max_atom_valence = 5 max_valence[n] = max_atom_valence - abs(ad.get("charge", 0)) - ad.get("expl_H", 0) # Compute explicitly defined valence: @@ -252,8 +330,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data: s, e = self.atom_attr_logit_slice["expl_H"] set_node_attr_mask[i, s:e] = 0 + remove_edge_mask = torch.zeros((len(g.edges), 1)) + for i, (u, v) in enumerate(g.edges): + if g.degree(u) > 1 and g.degree(v) > 1: + if nx.algorithms.is_connected(graph_without_edge(g, (u, v))): + remove_edge_mask[i] = 1 edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim)) set_edge_attr_mask = torch.zeros((len(g.edges), self.num_edge_attr_logits)) + remove_edge_attr_mask = torch.zeros((len(g.edges), len(self.bond_attrs))) for i, e in enumerate(g.edges): ad = g.edges[e] for k, sl in zip(self.bond_attrs, self.bond_attr_slice): @@ -263,6 +347,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data: if k in ad: # If the attribute is already there, mask out logits s, e = self.bond_attr_logit_slice[k] set_edge_attr_mask[i, s:e] = 0 + remove_edge_attr_mask[i, self.bond_attrs.index(k)] = 1 # Check which bonds don't bust the valence of their atoms if "type" not in ad: # Only if type isn't already set sl, _ = self.bond_attr_logit_slice["type"] @@ -281,18 +366,23 @@ def is_ok_non_edge(e): non_edge_index = torch.zeros((2, 0), dtype=torch.long) else: gc = nx.complement(g) - non_edge_index = torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).T.reshape( - (2, -1) + non_edge_index = ( + torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).reshape((-1, 2)).T ) data = gd.Data( x, edge_index, edge_attr, non_edge_index=non_edge_index, + stop_mask=torch.ones((1, 1)) * (len(g.nodes) > 0), # Can only stop if there's at least a node add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge set_edge_attr_mask=set_edge_attr_mask, + remove_node_mask=remove_node_mask, + remove_node_attr_mask=remove_node_attr_mask, + remove_edge_mask=remove_edge_mask, + remove_edge_attr_mask=remove_edge_attr_mask, ) if self.num_rw_feat > 0: data.x = torch.cat([data.x, random_walk_probs(data, self.num_rw_feat, skip_odd=True)], 1) diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py new file mode 100644 index 00000000..833a2bba --- /dev/null +++ b/src/gflownet/models/config.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + + +@dataclass +class GraphTransformerConfig: + num_heads: int = 2 + ln_type: str = "pre" + num_mlp_layers: int = 0 + + +@dataclass +class ModelConfig: + """Generic configuration for models + + Attributes + ---------- + num_layers : int + The number of layers in the model + num_emb : int + The number of dimensions of the embedding + """ + + num_layers: int = 3 + num_emb: int = 128 + graph_transformer: GraphTransformerConfig = GraphTransformerConfig() diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 6dea00a2..05f9b0e4 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -6,6 +6,7 @@ import torch_geometric.nn as gnn from torch_geometric.utils import add_self_loops +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType @@ -139,20 +140,36 @@ def forward(self, g: gd.Batch, cond: torch.Tensor): class GraphTransformerGFN(nn.Module): - """GraphTransformer class for a GFlowNet which outputs a GraphActionCategorical. Meant for atom-wise - generation. + """GraphTransformer class for a GFlowNet which outputs a GraphActionCategorical. Outputs logits corresponding to the action types used by the env_ctx argument. """ + # The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the + # embeddings to the right MLP + _action_type_to_graph_part = { + GraphActionType.Stop: "graph", + GraphActionType.AddNode: "node", + GraphActionType.SetNodeAttr: "node", + GraphActionType.AddEdge: "non_edge", + GraphActionType.SetEdgeAttr: "edge", + GraphActionType.RemoveNode: "node", + GraphActionType.RemoveNodeAttr: "node", + GraphActionType.RemoveEdge: "edge", + GraphActionType.RemoveEdgeAttr: "edge", + } + # The torch_geometric batch key each graph part corresponds to + _graph_part_to_key = { + "graph": None, + "node": "x", + "non_edge": "non_edge_index", + "edge": "edge_index", + } + def __init__( self, env_ctx, - num_emb=64, - num_layers=3, - num_heads=2, - num_mlp_layers=0, - ln_type="pre", + cfg: Config, num_graph_out=1, do_bck=False, ): @@ -162,11 +179,12 @@ def __init__( x_dim=env_ctx.num_node_dim, e_dim=env_ctx.num_edge_dim, g_dim=env_ctx.num_cond_dim, - num_emb=num_emb, - num_layers=num_layers, - num_heads=num_heads, - ln_type=ln_type, + num_emb=cfg.model.num_emb, + num_layers=cfg.model.num_layers, + num_heads=cfg.model.graph_transformer.num_heads, + ln_type=cfg.model.graph_transformer.ln_type, ) + num_emb = cfg.model.num_emb num_final = num_emb num_glob_final = num_emb * 2 num_edge_feat = num_emb if env_ctx.edges_are_unordered else num_emb * 2 @@ -187,39 +205,22 @@ def __init__( GraphActionType.RemoveEdge: (num_edge_feat, 1), GraphActionType.RemoveEdgeAttr: (num_edge_feat, env_ctx.num_edge_attrs), } - # The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the - # embeddings to the right MLP - self._action_type_to_graph_part = { - GraphActionType.Stop: "graph", - GraphActionType.AddNode: "node", - GraphActionType.SetNodeAttr: "node", - GraphActionType.AddEdge: "non_edge", - GraphActionType.SetEdgeAttr: "edge", - GraphActionType.RemoveNode: "node", - GraphActionType.RemoveNodeAttr: "node", - GraphActionType.RemoveEdge: "edge", - GraphActionType.RemoveEdgeAttr: "edge", - } - # The torch_geometric batch key each graph part corresponds to - self._graph_part_to_key = { - "graph": None, - "node": "x", - "non_edge": "non_edge_index", - "edge": "edge_index", + self._action_type_to_key = { + at: self._graph_part_to_key[self._action_type_to_graph_part[at]] for at in self._action_type_to_graph_part } # Here we create only the embedding -> logit mapping MLPs that are required by the environment mlps = {} for atype in chain(env_ctx.action_type_order, env_ctx.bck_action_type_order if do_bck else []): num_in, num_out = self._action_type_to_num_inputs_outputs[atype] - mlps[atype.cname] = mlp(num_in, num_emb, num_out, num_mlp_layers) + mlps[atype.cname] = mlp(num_in, num_emb, num_out, cfg.model.graph_transformer.num_mlp_layers) self.mlps = nn.ModuleDict(mlps) self.do_bck = do_bck if do_bck: self.bck_action_type_order = env_ctx.bck_action_type_order - self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, num_mlp_layers) + self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) @@ -232,13 +233,14 @@ def _action_type_to_logit(self, t, emb, g): def _mask(self, x, m): # mask logit vector x with binary mask m, -1000 is a tiny log-value + # Note to self: we can't use torch.inf here, because inf * 0 is nan (but also see issue #99) return x * m + -1000 * (1 - m) def _make_cat(self, g, emb, action_types): return GraphActionCategorical( g, logits=[self._action_type_to_logit(t, emb, g) for t in action_types], - keys=[self._graph_part_to_key[self._action_type_to_graph_part[t]] for t in action_types], + keys=[self._action_type_to_key[t] for t in action_types], masks=[self._action_type_to_mask(t, g) for t in action_types], types=action_types, ) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py new file mode 100644 index 00000000..98791be5 --- /dev/null +++ b/src/gflownet/online_trainer.py @@ -0,0 +1,107 @@ +import copy +import os +import pathlib + +import git +import torch +from omegaconf import OmegaConf +from torch import Tensor + +from gflownet.algo.advantage_actor_critic import A2C +from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.soft_q_learning import SoftQLearning +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.models.graph_transformer import GraphTransformerGFN + +from .trainer import GFNTrainer + + +class StandardOnlineTrainer(GFNTrainer): + def setup_model(self): + self.model = GraphTransformerGFN( + self.ctx, + self.cfg, + do_bck=self.cfg.algo.tb.do_parameterize_p_b, + ) + + def setup_algo(self): + algo = self.cfg.algo.method + if algo == "TB": + algo = TrajectoryBalance + elif algo == "FM": + algo = FlowMatching + elif algo == "A2C": + algo = A2C + elif algo == "SQL": + algo = SoftQLearning + else: + raise ValueError(algo) + self.algo = algo(self.env, self.ctx, self.rng, self.cfg) + + def setup_data(self): + self.training_data = [] + self.test_data = [] + + def setup(self): + super().setup() + self.offline_ratio = 0 + self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None + + # Separate Z parameters from non-Z to allow for LR decay on the former + if hasattr(self.model, "logZ"): + Z_params = list(self.model.logZ.parameters()) + non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] + else: + Z_params = [] + non_Z_params = list(self.model.parameters()) + self.opt = torch.optim.Adam( + non_Z_params, + self.cfg.opt.learning_rate, + (self.cfg.opt.momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) + self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + + self.sampling_tau = self.cfg.algo.sampling_tau + if self.sampling_tau > 0: + self.sampling_model = copy.deepcopy(self.model) + else: + self.sampling_model = self.model + + self.mb_size = self.cfg.algo.global_batch_size + self.clip_grad_callback = { + "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), + "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), + "none": lambda x: None, + }[self.cfg.opt.clip_grad_type] + + # saving hyperparameters + git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] + self.cfg.git_hash = git_hash + + os.makedirs(self.cfg.log_dir, exist_ok=True) + print("\n\nHyperparameters:\n") + yaml = OmegaConf.to_yaml(self.cfg) + print(yaml) + with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + f.write(yaml) + + def step(self, loss: Tensor): + loss.backward() + for i in self.model.parameters(): + self.clip_grad_callback(i) + self.opt.step() + self.opt.zero_grad() + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched.step() + self.lr_sched_Z.step() + if self.sampling_tau > 0: + for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): + b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py new file mode 100644 index 00000000..a9f6ac3f --- /dev/null +++ b/src/gflownet/tasks/config.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + + +@dataclass +class SEHTaskConfig: + pass # SEH just uses a temperature conditional + + +@dataclass +class SEHMOOTaskConfig: + """Config for the SEHMOOTask + + Attributes + ---------- + use_steer_thermometer : bool + Whether to use a thermometer encoding for the steering. + preference_type : Optional[str] + The preference sampling distribution, defaults to "dirichlet". + focus_type : Union[list, str, None] + The type of focus distribtuion used, see SEHMOOTask.setup_focus_regions. + focus_cosim : float + The cosine similarity threshold for the focus distribution. + focus_limit_coef : float + The smoothing coefficient for the focus reward. + focus_model_training_limits : Optional[Tuple[int, int]] + The training limits for the focus sampling model (if used). + focus_model_state_space_res : Optional[int] + The state space resolution for the focus sampling model (if used). + max_train_it : Optional[int] + The maximum number of training iterations for the focus sampling model (if used). + n_valid : int + The number of valid cond_info tensors to sample + n_valid_repeats : int + The number of times to repeat the valid cond_info tensors + objectives : List[str] + The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. + """ + + use_steer_thermometer: bool = False + preference_type: Optional[str] = "dirichlet" + focus_type: Optional[str] = None + focus_dirs_listed: Optional[List[List[float]]] = None + focus_cosim: float = 0.0 + focus_limit_coef: float = 1.0 + focus_model_training_limits: Optional[Tuple[int, int]] = None + focus_model_state_space_res: Optional[int] = None + max_train_it: Optional[int] = None + n_valid: int = 15 + n_valid_repeats: int = 128 + objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) + + +@dataclass +class QM9TaskConfig: + h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py + model_path: str = "./data/qm9/qm9_model.pt" + + +@dataclass +class TasksConfig: + qm9: QM9TaskConfig = QM9TaskConfig() + seh: SEHTaskConfig = SEHTaskConfig() + seh_moo: SEHMOOTaskConfig = SEHMOOTaskConfig() diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py new file mode 100644 index 00000000..c3e8d0f9 --- /dev/null +++ b/src/gflownet/tasks/make_rings.py @@ -0,0 +1,90 @@ +import os +import socket +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from rdkit import Chem +from rdkit.Chem.rdchem import Mol as RDMol +from torch import Tensor + +from gflownet.config import Config +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar + + +class MakeRingsTask(GFNTask): + """A toy task where the reward is the number of rings in the molecule.""" + + def __init__( + self, + rng: np.random.Generator, + ): + self.rng = rng + + def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: + return FlatRewards(y) + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)} + + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log() + return RewardScalar(scalar_logreward.flatten()) + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float() + return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool() + + +class MakeRingsTrainer(StandardOnlineTrainer): + def set_default_hps(self, cfg: Config): + cfg.hostname = socket.gethostname() + cfg.num_workers = 8 + cfg.algo.global_batch_size = 64 + cfg.algo.offline_ratio = 0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + + cfg.algo.method = "TB" + cfg.algo.max_nodes = 6 + cfg.algo.sampling_tau = 0.9 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.train_random_action_prob = 0.0 + cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.tb.do_parameterize_p_b = True + + cfg.replay.use = False + + def setup_task(self): + self.task = MakeRingsTask(rng=self.rng) + + def setup_env_context(self): + self.ctx = MolBuildingEnvContext( + ["C"], + charges=[0], # disable charge + chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED], # disable chirality + num_rw_feat=0, + max_nodes=self.cfg.algo.max_nodes, + num_cond_dim=1, + ) + + +def main(): + hps = { + "log_dir": "./logs/debug_run_mr4", + "device": "cuda", + "num_training_steps": 10_000, + "num_workers": 8, + "algo": {"tb": {"do_parameterize_p_b": True}}, + } + os.makedirs(hps["log_dir"], exist_ok=True) + + trial = MakeRingsTrainer(hps) + trial.print_every = 1 + trial.run() + + +if __name__ == "__main__": + main() diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index d5ed182f..e5b1d29a 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,27 +1,22 @@ -import ast -import copy import os -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import numpy as np -import scipy.stats as stats import torch import torch.nn as nn import torch_geometric.data as gd -from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet -from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset -from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar -from gflownet.utils.transforms import thermometer +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils.conditioning import TemperatureConditional class QM9GapTask(GFNTask): @@ -30,19 +25,15 @@ class QM9GapTask(GFNTask): def __init__( self, dataset: Dataset, - temperature_distribution: str, - temperature_parameters: Tuple[float, float], - num_thermometer_dim: int, + cfg: Config, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models() + self.models = self.load_task_models(cfg.task.qm9.model_path) self.dataset = dataset - self.temperature_sample_dist = temperature_distribution - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim + self.temperature_conditional = TemperatureConditional(cfg, rng) # TODO: fix interface self._min, self._max, self._percentile_95 = self.dataset.get_stats(percentile=0.05) # type: ignore self._width = self._max - self._min @@ -69,49 +60,20 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self): + def load_task_models(self, path): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? - state_dict = torch.load("/data/chem/qm9/mxmnet_gap_model.pt") + state_dict = torch.load(path) gap_model.load_state_dict(state_dict) gap_model.cuda() - gap_model, self.device = self._wrap_model(gap_model) + gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) return {"mxmnet_gap": gap_model} - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: - beta = None - if self.temperature_sample_dist == "constant": - assert type(self.temperature_dist_params) in [float, int] - beta = np.array(self.temperature_dist_params).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, self.num_thermometer_dim)) - else: - if self.temperature_sample_dist == "gamma": - loc, scale = self.temperature_dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) - upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) - elif self.temperature_sample_dist == "uniform": - beta = self.rng.uniform(*self.temperature_dist_params, n).astype(np.float32) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "loguniform": - low, high = np.log(self.temperature_dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "beta": - beta = self.rng.beta(*self.temperature_dist_params, n).astype(np.float32) - upper_bound = 1 - beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) - - assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": beta, "encoding": beta_enc} + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - if isinstance(flat_reward, list): - flat_reward = torch.tensor(flat_reward) - scalar_logreward = flat_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] @@ -126,103 +88,59 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class QM9GapTrainer(GFNTrainer): - def default_hps(self) -> Dict[str, Any]: - return { - "bootstrap_own_reward": False, - "learning_rate": 1e-4, - "global_batch_size": 64, - "num_emb": 128, - "num_layers": 4, - "tb_epsilon": None, - "illegal_action_logreward": -75, - "reward_loss_multiplier": 1, - "temperature_sample_dist": "uniform", - "temperature_dist_params": (0.5, 32.0), - "weight_decay": 1e-8, - "num_data_loader_workers": 8, - "momentum": 0.9, - "adam_eps": 1e-8, - "lr_decay": 20000, - "Z_lr_decay": 20000, - "clip_grad_type": "norm", - "clip_grad_param": 10, - "random_action_prob": 0.001, - "sampling_tau": 0.0, - "num_thermometer_dim": 32, - } - - def setup(self): - hps = self.hps - RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) - self.env = GraphBuildingEnv() - self.ctx = MolBuildingEnvContext(["H", "C", "N", "F", "O"], num_cond_dim=32) - self.training_data = QM9Dataset(hps["qm9_h5_path"], train=True, target="gap") - self.test_data = QM9Dataset(hps["qm9_h5_path"], train=False, target="gap") - - model = GraphTransformerGFN(self.ctx, num_emb=hps["num_emb"], num_layers=hps["num_layers"]) - self.model = model - # Separate Z parameters from non-Z to allow for LR decay on the former - Z_params = list(model.logZ.parameters()) - non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] - self.opt = torch.optim.Adam( - non_Z_params, - hps["learning_rate"], - (hps["momentum"], 0.999), - weight_decay=hps["weight_decay"], - eps=hps["adam_eps"], +class QM9GapTrainer(StandardOnlineTrainer): + def set_default_hps(self, cfg: Config): + cfg.num_workers = 8 + cfg.num_training_steps = 100000 + cfg.opt.learning_rate = 1e-4 + cfg.opt.weight_decay = 1e-8 + cfg.opt.momentum = 0.9 + cfg.opt.adam_eps = 1e-8 + cfg.opt.lr_decay = 20000 + cfg.opt.clip_grad_type = "norm" + cfg.opt.clip_grad_param = 10 + cfg.algo.max_nodes = 9 + cfg.algo.global_batch_size = 64 + cfg.algo.train_random_action_prob = 0.001 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.sampling_tau = 0.0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + cfg.cond.temperature.sample_dist = "uniform" + cfg.cond.temperature.dist_params = [0.5, 32.0] + cfg.cond.temperature.num_thermometer_dim = 32 + + def setup_env_context(self): + self.ctx = MolBuildingEnvContext( + ["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True ) - self.opt_Z = torch.optim.Adam(Z_params, hps["learning_rate"], (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / hps["lr_decay"])) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(self.opt_Z, lambda steps: 2 ** (-steps / hps["Z_lr_decay"])) - - self.sampling_tau = hps["sampling_tau"] - if self.sampling_tau > 0: - self.sampling_model = copy.deepcopy(model) - else: - self.sampling_model = self.model - eps = hps["tb_epsilon"] - hps["tb_epsilon"] = ast.literal_eval(eps) if isinstance(eps, str) else eps - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, hps, max_nodes=9) - + # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories + # from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action + # for setting the explicit hydrogen is used before the positive charge is set, it will be considered + # an invalid action. However, generate_forward_trajectory does not consider this implementation detail, + # it assumes that attribute-setting will always be valid. For the molecular environment, as of writing + # (PR #98) this edge case is the only case where the ordering in which attributes are set can matter. + + def setup_data(self): + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") + + def setup_task(self): self.task = QM9GapTask( dataset=self.training_data, - temperature_distribution=hps["temperature_sample_dist"], - temperature_parameters=hps["temperature_dist_params"], - num_thermometer_dim=hps["num_thermometer_dim"], - wrap_model=self._wrap_model_mp, + cfg=self.cfg, + rng=self.rng, + wrap_model=self._wrap_for_mp, ) - self.mb_size = hps["global_batch_size"] - self.clip_grad_param = hps["clip_grad_param"] - self.clip_grad_callback = { - "value": (lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param)), - "norm": (lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param)), - "none": (lambda x: None), - }[hps["clip_grad_type"]] - - def step(self, loss: Tensor): - loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) - self.opt.step() - self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() - self.lr_sched.step() - self.lr_sched_Z.step() - if self.sampling_tau > 0: - for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): - b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) def main(): - """Example of how this model can be run outside of Determined""" + """Example of how this model can be run.""" yaml = YAML(typ="safe", pure=True) config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") with open(config_file, "r") as f: hps = yaml.load(f) - trial = QM9GapTrainer(hps, torch.device("cpu")) + trial = QM9GapTrainer(hps) trial.run() diff --git a/src/gflownet/tasks/qm9/qm9.yaml b/src/gflownet/tasks/qm9/qm9.yaml index 01ea17f4..19701fac 100644 --- a/src/gflownet/tasks/qm9/qm9.yaml +++ b/src/gflownet/tasks/qm9/qm9.yaml @@ -1,5 +1,10 @@ -lr_decay: 10000 -qm9_h5_path: /data/chem/qm9/qm9.h5 -log_dir: /scratch/logs/qm9_gap_mxmnet +opt: + lr_decay: 10000 +task: + qm9: + h5_path: /rxrx/data/chem/qm9/qm9.h5 + model_path: /rxrx/data/chem/qm9/mxmnet_gap_model.pt num_training_steps: 100000 validate_every: 100 +log_dir: ./logs/debug_qm9 +num_workers: 0 diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index e51e21d9..4d0cc624 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,27 +1,22 @@ -import ast -import copy import os import shutil import socket -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import numpy as np -import scipy.stats as stats import torch import torch.nn as nn import torch_geometric.data as gd -from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset -from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.models import bengio2021flow -from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar -from gflownet.utils.transforms import thermometer +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils.conditioning import TemperatureConditional class SEHTask(GFNTask): @@ -31,15 +26,13 @@ class SEHTask(GFNTask): The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. This setup essentially reproduces the results of the Trajectory Balance paper when using the TB - objective, or of the original paper when using Flow Matching (TODO: port to this repo). + objective, or of the original paper when using Flow Matching. """ def __init__( self, dataset: Dataset, - temperature_distribution: str, - temperature_parameters: Tuple[float, float], - num_thermometer_dim: int, + cfg: Config, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): @@ -47,9 +40,8 @@ def __init__( self.rng = rng self.models = self._load_task_models() self.dataset = dataset - self.temperature_sample_dist = temperature_distribution - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim + self.temperature_conditional = TemperatureConditional(cfg, rng) + self.num_cond_dim = self.temperature_conditional.encoding_size() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -59,43 +51,14 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model) + model, self.device = self._wrap_model(model, send_to_device=True) return {"seh": model} - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: - beta = None - if self.temperature_sample_dist == "constant": - assert type(self.temperature_dist_params) is float - beta = np.array(self.temperature_dist_params).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, self.num_thermometer_dim)) - else: - if self.temperature_sample_dist == "gamma": - loc, scale = self.temperature_dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) - upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) - elif self.temperature_sample_dist == "uniform": - beta = self.rng.uniform(*self.temperature_dist_params, n).astype(np.float32) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "loguniform": - low, high = np.log(self.temperature_dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "beta": - beta = self.rng.beta(*self.temperature_dist_params, n).astype(np.float32) - upper_bound = 1 - beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) - - assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": beta, "encoding": beta_enc} + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - if isinstance(flat_reward, list): - flat_reward = torch.tensor(flat_reward) - scalar_logreward = flat_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -110,128 +73,74 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class SEHFragTrainer(GFNTrainer): - def default_hps(self) -> Dict[str, Any]: - return { - "hostname": socket.gethostname(), - "bootstrap_own_reward": False, - "learning_rate": 1e-4, - "Z_learning_rate": 1e-4, - "global_batch_size": 64, - "num_emb": 128, - "num_layers": 4, - "tb_epsilon": None, - "tb_p_b_is_parameterized": False, - "illegal_action_logreward": -75, - "reward_loss_multiplier": 1, - "temperature_sample_dist": "uniform", - "temperature_dist_params": (0.5, 32.0), - "weight_decay": 1e-8, - "num_data_loader_workers": 8, - "momentum": 0.9, - "adam_eps": 1e-8, - "lr_decay": 20000, - "Z_lr_decay": 20000, - "clip_grad_type": "norm", - "clip_grad_param": 10, - "random_action_prob": 0.0, - "valid_random_action_prob": 0.0, - "sampling_tau": 0.0, - "max_nodes": 9, - "num_thermometer_dim": 32, - } - - def setup_algo(self): - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.hps, max_nodes=self.hps["max_nodes"]) +class SEHFragTrainer(StandardOnlineTrainer): + task: SEHTask + + def set_default_hps(self, cfg: Config): + cfg.hostname = socket.gethostname() + cfg.pickle_mp_messages = False + cfg.num_workers = 8 + cfg.opt.learning_rate = 1e-4 + cfg.opt.weight_decay = 1e-8 + cfg.opt.momentum = 0.9 + cfg.opt.adam_eps = 1e-8 + cfg.opt.lr_decay = 20_000 + cfg.opt.clip_grad_type = "norm" + cfg.opt.clip_grad_param = 10 + cfg.algo.global_batch_size = 64 + cfg.algo.offline_ratio = 0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + + cfg.algo.method = "TB" + cfg.algo.max_nodes = 9 + cfg.algo.sampling_tau = 0.9 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.train_random_action_prob = 0.0 + cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.valid_offline_ratio = 0 + cfg.algo.tb.epsilon = None + cfg.algo.tb.bootstrap_own_reward = False + cfg.algo.tb.Z_learning_rate = 1e-3 + cfg.algo.tb.Z_lr_decay = 50_000 + cfg.algo.tb.do_parameterize_p_b = False + + cfg.replay.use = False + cfg.replay.capacity = 10_000 + cfg.replay.warmup = 1_000 def setup_task(self): self.task = SEHTask( dataset=self.training_data, - temperature_distribution=self.hps["temperature_sample_dist"], - temperature_parameters=self.hps["temperature_dist_params"], - num_thermometer_dim=self.hps["num_thermometer_dim"], - wrap_model=self._wrap_model_mp, + cfg=self.cfg, + rng=self.rng, + wrap_model=self._wrap_for_mp, ) - def setup_model(self): - self.model = GraphTransformerGFN(self.ctx, num_emb=self.hps["num_emb"], num_layers=self.hps["num_layers"]) - def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext( - max_frags=self.hps["max_nodes"], num_cond_dim=self.hps["num_thermometer_dim"] - ) - - def setup(self): - hps = self.hps - RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) - self.env = GraphBuildingEnv() - self.training_data = [] - self.test_data = [] - self.offline_ratio = 0 - self.valid_offline_ratio = 0 - self.setup_env_context() - self.setup_algo() - self.setup_task() - self.setup_model() - - # Separate Z parameters from non-Z to allow for LR decay on the former - Z_params = list(self.model.logZ.parameters()) - non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] - self.opt = torch.optim.Adam( - non_Z_params, - hps["learning_rate"], - (hps["momentum"], 0.999), - weight_decay=hps["weight_decay"], - eps=hps["adam_eps"], - ) - self.opt_Z = torch.optim.Adam(Z_params, hps["Z_learning_rate"], (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / hps["lr_decay"])) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(self.opt_Z, lambda steps: 2 ** (-steps / hps["Z_lr_decay"])) - - self.sampling_tau = hps["sampling_tau"] - if self.sampling_tau > 0: - self.sampling_model = copy.deepcopy(self.model) - else: - self.sampling_model = self.model - eps = hps["tb_epsilon"] - hps["tb_epsilon"] = ast.literal_eval(eps) if isinstance(eps, str) else eps - - self.mb_size = hps["global_batch_size"] - self.clip_grad_param = hps["clip_grad_param"] - self.clip_grad_callback = { - "value": (lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param)), - "norm": (lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param)), - "none": (lambda x: None), - }[hps["clip_grad_type"]] - - def step(self, loss: Tensor): - loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) - self.opt.step() - self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() - self.lr_sched.step() - self.lr_sched_Z.step() - if self.sampling_tau > 0: - for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): - b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) def main(): """Example of how this model can be run outside of Determined""" hps = { - "log_dir": "./logs/debug_run", + "log_dir": "./logs/debug_run_seh_frag", + "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), "overwrite_existing_exp": True, - "qm9_h5_path": "/data/chem/qm9/qm9.h5", "num_training_steps": 10_000, - "validate_every": 1, - "lr_decay": 20000, - "sampling_tau": 0.99, - "num_data_loader_workers": 8, - "temperature_dist_params": (0.0, 64.0), + "num_workers": 8, + "opt": { + "lr_decay": 20000, + }, + "algo": { + "sampling_tau": 0.99, + }, + "cond": { + "temperature": { + "sample_dist": "uniform", + "dist_params": [0, 64.0], + } + }, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -240,8 +149,8 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) - trial.verbose = True + trial = SEHFragTrainer(hps) + trial.print_every = 1 trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index be09a5e6..8ad73320 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,10 +1,8 @@ -import json import os import pathlib import shutil from typing import Any, Callable, Dict, List, Tuple, Union -import git import numpy as np import torch import torch.nn as nn @@ -12,22 +10,18 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.distributions.dirichlet import Dirichlet from torch.utils.data import Dataset -from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce -from gflownet.algo.soft_q_learning import SoftQLearning -from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow -from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask -from gflownet.train import FlatRewards, RewardScalar +from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore +from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook -from gflownet.utils.transforms import thermometer class SEHMOOTask(SEHTask): @@ -42,27 +36,30 @@ class SEHMOOTask(SEHTask): def __init__( self, - objectives: List[str], dataset: Dataset, - temperature_sample_dist: str, - temperature_parameters: Tuple[float, float], - num_thermometer_dim: int, - use_pref_thermometer: bool, + cfg: Config, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): - self._wrap_model = wrap_model - self.rng = rng - self.models = self._load_task_models() - self.objectives = objectives + super().__init__(dataset, cfg, rng, wrap_model) + self.cfg = cfg + mcfg = self.cfg.task.seh_moo + self.objectives = cfg.task.seh_moo.objectives self.dataset = dataset - self.temperature_sample_dist = temperature_sample_dist - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim - self.use_pref_thermometer = use_pref_thermometer - self.seeded_preference = None - self.experimental_dirichlet = False - assert set(objectives) <= {"seh", "qed", "sa", "mw"} and len(objectives) == len(set(objectives)) + if self.cfg.cond.focus_region.focus_type is not None: + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) + else: + self.focus_cond = None + self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) + self.temperature_sample_dist = cfg.cond.temperature.sample_dist + self.temperature_dist_params = cfg.cond.temperature.dist_params + self.num_thermometer_dim = cfg.cond.temperature.num_thermometer_dim + self.num_cond_dim = ( + self.temperature_conditional.encoding_size() + + self.pref_cond.encoding_size() + + (self.focus_cond.encoding_size() if self.focus_cond is not None else 0) + ) + assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y)) @@ -70,48 +67,83 @@ def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: def inverse_flat_reward_transform(self, rp): return rp - def _load_task_models(self): - model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model) - return {"seh": model} - - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: - cond_info = super().sample_conditional_information(n) - - if self.seeded_preference is not None: - preferences = torch.tensor([self.seeded_preference] * n).float() - elif self.experimental_dirichlet: - a = np.random.dirichlet([1] * len(self.objectives), n) - b = np.random.exponential(1, n)[:, None] - preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() - else: - m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) - preferences = m.sample([n]) - - preferences_enc = ( - thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1) - if self.use_pref_thermometer - else preferences + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + cond_info = super().sample_conditional_information(n, train_it) + pref_ci = self.pref_cond.sample(n) + focus_ci = ( + self.focus_cond.sample(n, train_it) if self.focus_cond is not None else {"encoding": torch.zeros(n, 0)} ) - cond_info["encoding"] = torch.cat([cond_info["encoding"], preferences_enc], 1) - cond_info["preferences"] = preferences + cond_info = { + **cond_info, + **pref_ci, + **focus_ci, + "encoding": torch.cat([cond_info["encoding"], pref_ci["encoding"], focus_ci["encoding"]], dim=1), + } return cond_info - def encode_conditional_information(self, preferences: Tensor) -> Dict[str, Tensor]: - n = len(preferences) + def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor]: + """ + Encode conditional information at validation-time + We use the maximum temperature beta for inference + Args: + steer_info: Tensor of shape (Batch, 2 * n_objectives) containing the preferences and focus_dirs + in that order + Returns: + Dict[str, Tensor]: Dictionary containing the encoded conditional information + """ + n = len(steer_info) if self.temperature_sample_dist == "constant": - beta = torch.ones(n) * self.temperature_dist_params + beta = torch.ones(n) * self.temperature_dist_params[0] beta_enc = torch.zeros((n, self.num_thermometer_dim)) else: beta = torch.ones(n) * self.temperature_dist_params[-1] beta_enc = torch.ones((n, self.num_thermometer_dim)) assert len(beta.shape) == 1, f"beta should be of shape (Batch,), got: {beta.shape}" - if self.use_pref_thermometer: - encoding = torch.cat([beta_enc, thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1)], 1) + + # TODO: positional assumption here, should have something cleaner + preferences = steer_info[:, : len(self.objectives)].float() + focus_dir = steer_info[:, len(self.objectives) :].float() + + preferences_enc = self.pref_cond.encode(preferences) + if self.focus_cond is not None: + focus_enc = self.focus_cond.encode(focus_dir) + encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() else: - encoding = torch.cat([beta_enc, preferences], 1) - return {"beta": beta, "encoding": encoding.float(), "preferences": preferences.float()} + encoding = torch.cat([beta_enc, preferences_enc], 1).float() + return { + "beta": beta, + "encoding": encoding, + "preferences": preferences, + "focus_dir": focus_dir, + } + + def relabel_condinfo_and_logrewards( + self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor + ): + # TODO: we seem to be relabeling tensors in place, could that cause a problem? + if self.focus_cond is None: + raise NotImplementedError("Hindsight relabeling only implemented for focus conditioning") + if self.focus_cond.cfg.focus_type is None: + return cond_info, log_rewards + # only keep hindsight_idxs that actually correspond to a violated constraint + _, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim + ) + out_focus_mask = torch.logical_not(in_focus_mask) + hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] + + # relabels the focus_dirs and log_rewards + cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) + + preferences_enc = self.pref_cond.encode(cond_info["preferences"]) + focus_enc = self.focus_cond.encode(cond_info["focus_dir"]) + cond_info["encoding"] = torch.cat( + [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 + ) + + log_rewards = self.cond_info_to_logreward(cond_info, flat_rewards) + return cond_info, log_rewards def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: if isinstance(flat_reward, list): @@ -119,11 +151,15 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat flat_reward = torch.stack(flat_reward) else: flat_reward = torch.tensor(flat_reward) - scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + + scalarized_reward = self.pref_cond.transform(cond_info, flat_reward) + focused_reward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_reward) + if self.focus_cond is not None + else scalarized_reward + ) + tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward) + return RewardScalar(tempered_reward) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -132,13 +168,13 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid else: - flat_rewards: List[Tensor] = [] + flat_r: List[Tensor] = [] if "seh" in self.objectives: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) batch.to(self.device) seh_preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() / 8 seh_preds[seh_preds.isnan()] = 0 - flat_rewards.append(seh_preds) + flat_r.append(seh_preds) def safe(f, x, default): try: @@ -148,132 +184,138 @@ def safe(f, x, default): if "qed" in self.objectives: qeds = torch.tensor([safe(QED.qed, i, 0) for i, v in zip(mols, is_valid) if v.item()]) - flat_rewards.append(qeds) + flat_r.append(qeds) if "sa" in self.objectives: sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i, v in zip(mols, is_valid) if v.item()]) sas = (10 - sas) / 9 # Turn into a [0-1] reward - flat_rewards.append(sas) + flat_r.append(sas) if "mw" in self.objectives: molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i, v in zip(mols, is_valid) if v.item()]) molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 - flat_rewards.append(molwts) + flat_r.append(molwts) - flat_rewards = torch.stack(flat_rewards, dim=1) + flat_rewards = torch.stack(flat_r, dim=1) return FlatRewards(flat_rewards), is_valid class SEHMOOFragTrainer(SEHFragTrainer): - def default_hps(self) -> Dict[str, Any]: - return { - **super().default_hps(), - "use_fixed_weight": False, - "objectives": ["seh", "qed", "sa", "mw"], - "sampling_tau": 0.95, - "valid_sample_cond_info": False, - "n_valid_prefs": 15, - "n_valid_repeats_per_pref": 128, - "preference_type": "dirichlet", - "use_pref_thermometer": False, - } + task: SEHMOOTask + ctx: FragMolBuildingEnvContext + + def set_default_hps(self, cfg: Config): + super().set_default_hps(cfg) + cfg.algo.sampling_tau = 0.95 + # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) + # sampling and set the offline ratio to 1 + cfg.algo.valid_sample_cond_info = False + cfg.algo.valid_offline_ratio = 1 def setup_algo(self): - hps = self.hps - if hps["algo"] == "TB": - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "SQL": - self.algo = SoftQLearning(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "A2C": - self.algo = A2C(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "MOREINFORCE": - self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "MOQL": - self.algo = EnvelopeQLearning(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) + algo = self.cfg.algo.method + if algo == "MOREINFORCE": + self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, self.cfg) + elif algo == "MOQL": + self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) + else: + super().setup_algo() def setup_task(self): self.task = SEHMOOTask( - objectives=self.hps["objectives"], dataset=self.training_data, - temperature_sample_dist=self.hps["temperature_sample_dist"], - temperature_parameters=self.hps["temperature_dist_params"], - num_thermometer_dim=self.hps["num_thermometer_dim"], - wrap_model=self._wrap_model_mp, - use_pref_thermometer=self.hps["use_pref_thermometer"], + cfg=self.cfg, + rng=self.rng, + wrap_model=self._wrap_for_mp, ) + def setup_env_context(self): + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) + def setup_model(self): - if self.hps["algo"] == "MOQL": - model = GraphTransformerFragEnvelopeQL( + if self.cfg.algo.method == "MOQL": + self.model = GraphTransformerFragEnvelopeQL( self.ctx, - num_emb=self.hps["num_emb"], - num_layers=self.hps["num_layers"], - num_objectives=len(self.hps["objectives"]), + num_emb=self.cfg.model.num_emb, + num_layers=self.cfg.model.num_layers, + num_heads=self.cfg.model.graph_transformer.num_heads, + num_objectives=len(self.cfg.task.seh_moo.objectives), ) else: - model = GraphTransformerGFN( - self.ctx, - num_emb=self.hps["num_emb"], - num_layers=self.hps["num_layers"], - do_bck=self.hps["tb_p_b_is_parameterized"], - ) - - if self.hps["algo"] in ["A2C", "MOQL"]: - model.do_mask = False - self.model = model - - def setup_env_context(self): - if self.hps.get("use_pref_thermometer", False): - ncd = self.hps["num_thermometer_dim"] * (1 + len(self.hps["objectives"])) - else: - ncd = self.hps["num_thermometer_dim"] + len(self.hps["objectives"]) - self.ctx = FragMolBuildingEnvContext(max_frags=9, num_cond_dim=ncd) + super().setup_model() def setup(self): super().setup() self.sampling_hooks.append( - MultiObjectiveStatsHook(256, self.hps["log_dir"], compute_igd=True, compute_pc_entropy=True) + MultiObjectiveStatsHook( + 256, + self.cfg.log_dir, + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.cfg.task.seh_moo.focus_type is not None else False, + focus_cosim=self.cfg.task.seh_moo.focus_cosim, + ) ) - - n_obj = len(self.hps["objectives"]) - - # create fixed preference vectors for validation - if self.hps["preference_type"] is None: - valid_preferences = np.ones((self.hps["n_valid_prefs"], n_obj)) - elif self.hps["preference_type"] == "dirichlet": - valid_preferences = metrics.partition_hypersphere(d=n_obj, k=self.hps["n_valid_prefs"], normalisation="l1") - elif self.hps["preference_type"] == "seeded_single": - seeded_prefs = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet( - [1] * n_obj, self.hps["n_valid_prefs"] + # instantiate preference and focus conditioning vectors for validation + + tcfg = self.cfg.task.seh_moo + n_obj = len(tcfg.objectives) + + # making sure hyperparameters for preferences and focus regions are consistent + if not ( + tcfg.focus_type is None + or tcfg.focus_type == "centered" + or (type(tcfg.focus_type) is list and len(tcfg.focus_type) == 1) + ): + assert tcfg.preference_type is None, ( + f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} " + f"and preference_type={tcfg.preference_type}" ) + + if type(tcfg.focus_type) is list and len(tcfg.focus_type) > 1: + n_valid = len(tcfg.focus_type) + else: + n_valid = tcfg.n_valid + + # preference vectors + if tcfg.preference_type is None: + valid_preferences = np.ones((n_valid, n_obj)) + elif tcfg.preference_type == "dirichlet": + valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") + elif tcfg.preference_type == "seeded_single": + seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) valid_preferences = seeded_prefs[0].reshape((1, n_obj)) self.task.seeded_preference = valid_preferences[0] - elif self.hps["preference_type"] == "seeded_many": - valid_preferences = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet( - [1] * n_obj, self.hps["n_valid_prefs"] + elif tcfg.preference_type == "seeded_many": + valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) + else: + raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}") + + # TODO: this was previously reported, would be nice to serialize it + # hps["fixed_focus_dirs"] = ( + # np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None + # ) + if self.task.focus_cond is not None: + assert self.task.focus_cond.valid_focus_dirs.shape == ( + n_valid, + n_obj, + ), ( + "Invalid shape for valid_preferences, " + f"{self.task.focus_cond.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" ) - self._top_k_hook = TopKHook(10, self.hps["n_valid_repeats_per_pref"], len(valid_preferences)) - self.test_data = RepeatedPreferenceDataset(valid_preferences, self.hps["n_valid_repeats_per_pref"]) + # combine preferences and focus directions (fixed focus cosim) since they could be used together + # (not either/or). TODO: this relies on positional assumptions, should have something cleaner + valid_cond_vector = np.concatenate([valid_preferences, self.task.focus_cond.valid_focus_dirs], axis=1) + else: + valid_cond_vector = valid_preferences + + self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task - git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] - self.hps["gflownet_git_hash"] = git_hash - - os.makedirs(self.hps["log_dir"], exist_ok=True) - torch.save( - { - "hps": self.hps, - }, - open(pathlib.Path(self.hps["log_dir"]) / "hps.pt", "wb"), - ) - fmt_hps = "\n".join([f"{k}:\t({type(v).__name__})\t{v}".expandtabs(40) for k, v in self.hps.items()]) - print(f"\n\nHyperparameters:\n{'-'*50}\n{fmt_hps}\n{'-'*50}\n\n") - with open(pathlib.Path(self.hps["log_dir"]) / "hps.json", "w") as fd: - json.dump(self.hps, fd, sort_keys=True, indent=4) - def build_callbacks(self): # We use this class-based setup to be compatible with the DeterminedAI API, but no direct # dependency is required. @@ -288,45 +330,90 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + if self.task.focus_cond is not None: + self.task.focus_cond.step_focus_model(batch, train_it) + return super().train_batch(batch, epoch_idx, batch_idx, train_it) + + def _save_state(self, it): + if self.task.focus_cond is not None and self.task.focus_cond.focus_model is not None: + self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) + return super()._save_state(it) + -class RepeatedPreferenceDataset: - def __init__(self, preferences, repeat): - self.prefs = preferences +class RepeatedCondInfoDataset: + def __init__(self, cond_info_vectors, repeat): + self.cond_info_vectors = cond_info_vectors self.repeat = repeat def __len__(self): - return len(self.prefs) * self.repeat + return len(self.cond_info_vectors) * self.repeat def __getitem__(self, idx): assert 0 <= idx < len(self) - return torch.tensor(self.prefs[int(idx // self.repeat)]) + return torch.tensor(self.cond_info_vectors[int(idx // self.repeat)]) def main(): - """Example of how this model can be run outside of Determined""" + """Example of how this model can be run.""" hps = { - "log_dir": "./logs/debug_run", + "log_dir": "./logs/debug_run_sfm", + "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), + "pickle_mp_messages": True, "overwrite_existing_exp": True, "seed": 0, - "global_batch_size": 64, - "num_training_steps": 20_000, - "validate_every": 1, - "num_layers": 4, - "algo": "TB", - "objectives": ["seh", "qed"], - "learning_rate": 1e-4, - "Z_learning_rate": 1e-3, - "lr_decay": 20000, - "Z_lr_decay": 50000, - "sampling_tau": 0.95, - "random_action_prob": 0.1, - "num_data_loader_workers": 8, - "temperature_sample_dist": "constant", - "temperature_dist_params": 60.0, - "num_thermometer_dim": 32, - "preference_type": "dirichlet", - "n_valid_prefs": 15, - "n_valid_repeats_per_pref": 128, + "num_training_steps": 500, + "num_final_gen_steps": 50, + "validate_every": 100, + "num_workers": 0, + "algo": { + "global_batch_size": 64, + "method": "TB", + "sampling_tau": 0.95, + "train_random_action_prob": 0.01, + "tb": { + "Z_learning_rate": 1e-3, + "Z_lr_decay": 50000, + }, + }, + "model": { + "num_layers": 2, + "num_emb": 256, + }, + "task": { + "seh_moo": { + "objectives": ["seh", "qed"], + "n_valid": 15, + "n_valid_repeats": 128, + }, + }, + "opt": { + "learning_rate": 1e-4, + "lr_decay": 20000, + }, + "cond": { + "temperature": { + "sample_dist": "constant", + "dist_params": [60.0], + "num_thermometer_dim": 32, + }, + "weighted_prefs": { + "preference_type": "dirichlet", + }, + "focus_region": { + "focus_type": None, # "learned-tabular", + "focus_cosim": 0.98, + "focus_limit_coef": 1e-1, + "focus_model_training_limits": (0.25, 0.75), + "focus_model_state_space_res": 30, + "max_train_it": 5_000, + }, + }, + "replay": { + "use": False, + "warmup": 1000, + "hindsight_ratio": 0.0, + }, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -335,8 +422,8 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHMOOFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) - trial.verbose = True + trial = SEHMOOFragTrainer(hps) + trial.print_every = 1 trial.run() diff --git a/src/gflownet/train.py b/src/gflownet/trainer.py similarity index 58% rename from src/gflownet/train.py rename to src/gflownet/trainer.py index e73cae92..93e0e0a5 100644 --- a/src/gflownet/train.py +++ b/src/gflownet/trainer.py @@ -2,18 +2,24 @@ import pathlib from typing import Any, Callable, Dict, List, NewType, Optional, Tuple +import numpy as np import torch import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd +from omegaconf import OmegaConf +from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet.data.replay_buffer import ReplayBuffer from gflownet.data.sampling_iterator import SamplingIterator from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.utils.misc import create_logger -from gflownet.utils.multiprocessing_proxy import wrap_model_mp +from gflownet.utils.multiprocessing_proxy import mp_object_wrapper + +from .config import Config # This type represents an unprocessed list of reward signals/conditioning information FlatRewards = NewType("FlatRewards", Tensor) # type: ignore @@ -28,6 +34,7 @@ def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: """Computes the loss for a batch of data, and proves logging informations + Parameters ---------- model: nn.Module @@ -36,6 +43,7 @@ def compute_batch_losses( A batch of graphs num_bootstrap: Optional[int] The number of trajectories with reward targets in the batch (if applicable). + Returns ------- loss: Tensor @@ -82,13 +90,13 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class GFNTrainer: - def __init__(self, hps: Dict[str, Any], device: torch.device): + def __init__(self, hps: Dict[str, Any]): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters ---------- hps: Dict[str, Any] - A dictionary of hyperparameters. These override default values obtained by the `default_hps` method. + A dictionary of hyperparameters. These override default values obtained by the `set_default_hps` method. device: torch.device The torch device of the main worker. """ @@ -99,112 +107,170 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): # `sampling_model` is used by the data workers to sample new objects from the model. Can be # the same as `model`. self.sampling_model: nn.Module + self.replay_buffer: Optional[ReplayBuffer] self.mb_size: int self.env: GraphBuildingEnv self.ctx: GraphBuildingEnvContext self.task: GFNTask self.algo: GFNAlgorithm - # Override default hyperparameters with the constructor arguments - self.hps = {**self.default_hps(), **hps} - self.device = device - # The number of processes spawned to sample object and do CPU work - self.num_workers: int = self.hps.get("num_data_loader_workers", 0) - # The ratio of samples drawn from `self.training_data` during training. The rest is drawn from - # `self.sampling_model`. - self.offline_ratio = self.hps.get("offline_ratio", 0.5) - # idem, but from `self.test_data` during validation. - self.valid_offline_ratio = 1 - # If True, print messages during training - self.verbose = False + # There are three sources of config values + # - The default values specified in individual config classes + # - The default values specified in the `default_hps` method, typically what is defined by a task + # - The values passed in the constructor, typically what is called by the user + # The final config is obtained by merging the three sources + self.cfg: Config = OmegaConf.structured(Config()) + self.set_default_hps(self.cfg) + # OmegaConf returns a fancy object but we can still pretend it's a Config instance + self.cfg = OmegaConf.merge(self.cfg, hps) # type: ignore + + self.device = torch.device(self.cfg.device) + # Print the loss every `self.print_every` iterations + self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data self.sampling_hooks: List[Callable] = [] self.valid_sampling_hooks: List[Callable] = [] # Will check if parameters are finite at every iteration (can be costly) self._validate_parameters = False - # Pickle messages to reduce load on shared memory (conversely, increases load on CPU) - self.pickle_messages = hps.get("mp_pickle_messages", False) self.setup() - def default_hps(self) -> Dict[str, Any]: + def set_default_hps(self, base: Config): raise NotImplementedError() - def setup(self): + def setup_env_context(self): + raise NotImplementedError() + + def setup_task(self): + raise NotImplementedError() + + def setup_model(self): + raise NotImplementedError() + + def setup_algo(self): raise NotImplementedError() + def setup_data(self): + pass + def step(self, loss: Tensor): raise NotImplementedError() - def _wrap_model_mp(self, model): - """Wraps a nn.Module instance so that it can be shared to `DataLoader` workers.""" - model.to(self.device) - if self.num_workers > 0: - placeholder = wrap_model_mp( - model, - self.num_workers, + def setup(self): + RDLogger.DisableLog("rdApp.*") + self.rng = np.random.default_rng(142857) + self.env = GraphBuildingEnv() + self.setup_data() + self.setup_task() + self.setup_env_context() + self.setup_algo() + self.setup_model() + + def _wrap_for_mp(self, obj, send_to_device=False): + """Wraps an object in a placeholder whose reference can be sent to a + data worker process (only if the number of workers is non-zero).""" + if send_to_device: + obj.to(self.device) + if self.cfg.num_workers > 0 and obj is not None: + placeholder = mp_object_wrapper( + obj, + self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical), - pickle_messages=self.pickle_messages, + pickle_messages=self.cfg.pickle_mp_messages, ) return placeholder, torch.device("cpu") - return model, self.device + else: + return obj, self.device def build_callbacks(self): return {} def build_training_data_loader(self) -> DataLoader: - model, dev = self._wrap_model_mp(self.sampling_model) + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) iterator = SamplingIterator( self.training_data, model, - self.mb_size, self.ctx, self.algo, self.task, dev, - ratio=self.offline_ratio, - log_dir=os.path.join(self.hps["log_dir"], "train"), - random_action_prob=self.hps.get("random_action_prob", 0.0), + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, + replay_buffer=replay_buffer, + ratio=self.cfg.algo.offline_ratio, + log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), + random_action_prob=self.cfg.algo.train_random_action_prob, + hindsight_ratio=self.cfg.replay.hindsight_ratio, ) for hook in self.sampling_hooks: iterator.add_log_hook(hook) return torch.utils.data.DataLoader( iterator, batch_size=None, - num_workers=self.num_workers, - persistent_workers=self.num_workers > 0, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, # The 2 here is an odd quirk of torch 1.10, it is fixed and # replaced by None in torch 2. - prefetch_factor=1 if self.num_workers else 2, + prefetch_factor=1 if self.cfg.num_workers else 2, ) def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_model_mp(self.model) + model, dev = self._wrap_for_mp(self.model, send_to_device=True) iterator = SamplingIterator( self.test_data, model, - self.mb_size, self.ctx, self.algo, self.task, dev, - ratio=self.valid_offline_ratio, - log_dir=os.path.join(self.hps["log_dir"], "valid"), - sample_cond_info=self.hps.get("valid_sample_cond_info", True), + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, + ratio=self.cfg.algo.valid_offline_ratio, + log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), + sample_cond_info=self.cfg.algo.valid_sample_cond_info, stream=False, - random_action_prob=self.hps.get("valid_random_action_prob", 0.0), + random_action_prob=self.cfg.algo.valid_random_action_prob, ) for hook in self.valid_sampling_hooks: iterator.add_log_hook(hook) return torch.utils.data.DataLoader( iterator, batch_size=None, - num_workers=self.num_workers, - persistent_workers=self.num_workers > 0, - prefetch_factor=1 if self.num_workers else 2, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + prefetch_factor=1 if self.cfg.num_workers else 2, + ) + + def build_final_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + iterator = SamplingIterator( + self.training_data, + model, + self.ctx, + self.algo, + self.task, + dev, + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, + replay_buffer=None, + ratio=0.0, + log_dir=os.path.join(self.cfg.log_dir, "final"), + random_action_prob=0.0, + hindsight_ratio=0.0, + init_train_iter=self.cfg.num_training_steps, + ) + for hook in self.sampling_hooks: + iterator.add_log_hook(hook) + return torch.utils.data.DataLoader( + iterator, + batch_size=None, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + prefetch_factor=1 if self.cfg.num_workers else 2, ) - def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int) -> Dict[str, Any]: + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: try: loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): @@ -213,8 +279,8 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int) -> Dict[s if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") except ValueError as e: - os.makedirs(self.hps["log_dir"], exist_ok=True) - torch.save([self.model.state_dict(), batch, loss, info], open(self.hps["log_dir"] + "/dump.pkl", "wb")) + os.makedirs(self.cfg.log_dir, exist_ok=True) + torch.save([self.model.state_dict(), batch, loss, info], open(self.cfg.log_dir + "/dump.pkl", "wb")) raise e if step_info is not None: @@ -234,24 +300,32 @@ def run(self, logger=None): validation every `validate_every` minibatches. """ if logger is None: - logger = create_logger(logfile=self.hps["log_dir"] + "/train.log") + logger = create_logger(logfile=self.cfg.log_dir + "/train.log") self.model.to(self.device) self.sampling_model.to(self.device) epoch_length = max(len(self.training_data), 1) - valid_freq = self.hps.get("validate_every", 0) + valid_freq = self.cfg.validate_every # If checkpoint_every is not specified, checkpoint at every validation epoch - ckpt_freq = self.hps.get("checkpoint_every", valid_freq) + ckpt_freq = self.cfg.checkpoint_every if self.cfg.checkpoint_every is not None else valid_freq train_dl = self.build_training_data_loader() valid_dl = self.build_validation_data_loader() + if self.cfg.num_final_gen_steps: + final_dl = self.build_final_data_loader() callbacks = self.build_callbacks() - start = self.hps.get("start_at_step", 0) + 1 + start = self.cfg.start_at_step + 1 + num_training_steps = self.cfg.num_training_steps logger.info("Starting training") - for it, batch in zip(range(start, 1 + self.hps["num_training_steps"]), cycle(train_dl)): + for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): epoch_idx = it // epoch_length batch_idx = it % epoch_length - info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx) + if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: + logger.info( + f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" + ) + continue + info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) self.log(info, it, "train") - if self.verbose: + if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) if valid_freq > 0 and it % valid_freq == 0: @@ -266,21 +340,31 @@ def run(self, logger=None): self.log(end_metrics, it, "valid_end") if ckpt_freq > 0 and it % ckpt_freq == 0: self._save_state(it) - self._save_state(self.hps["num_training_steps"]) + self._save_state(num_training_steps) + + num_final_gen_steps = self.cfg.num_final_gen_steps + if num_final_gen_steps: + logger.info(f"Generating final {num_final_gen_steps} batches ...") + for it, batch in zip( + range(num_training_steps, num_training_steps + num_final_gen_steps + 1), + cycle(final_dl), + ): + pass + logger.info("Final generation steps completed.") def _save_state(self, it): torch.save( { "models_state_dict": [self.model.state_dict()], - "hps": self.hps, + "cfg": self.cfg, "step": it, }, - open(pathlib.Path(self.hps["log_dir"]) / "model_state.pt", "wb"), + open(pathlib.Path(self.cfg.log_dir) / "model_state.pt", "wb"), ) def log(self, info, index, key): if not hasattr(self, "_summary_writer"): - self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.hps["log_dir"]) + self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py new file mode 100644 index 00000000..5893ef3b --- /dev/null +++ b/src/gflownet/utils/conditioning.py @@ -0,0 +1,246 @@ +import abc +from copy import deepcopy +from typing import Dict + +import numpy as np +import torch +from scipy import stats +from torch import Tensor +from torch.distributions.dirichlet import Dirichlet +from torch_geometric import data as gd + +from gflownet.config import Config +from gflownet.utils import metrics +from gflownet.utils.focus_model import TabularFocusModel +from gflownet.utils.transforms import thermometer + + +class Conditional(abc.ABC): + def sample(self, n): + raise NotImplementedError() + + @abc.abstractmethod + def transform(self, cond_info: Dict[str, Tensor], properties: Tensor) -> Tensor: + raise NotImplementedError() + + def encoding_size(self): + raise NotImplementedError() + + def encode(self, conditional: Tensor) -> Tensor: + raise NotImplementedError() + + +class TemperatureConditional(Conditional): + def __init__(self, cfg: Config, rng: np.random.Generator): + self.cfg = cfg + tmp_cfg = self.cfg.cond.temperature + self.rng = rng + self.upper_bound = 1024 + if tmp_cfg.sample_dist == "gamma": + loc, scale = tmp_cfg.dist_params + self.upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) + elif tmp_cfg.sample_dist == "uniform": + self.upper_bound = tmp_cfg.dist_params[1] + elif tmp_cfg.sample_dist == "loguniform": + self.upper_bound = tmp_cfg.dist_params[1] + elif tmp_cfg.sample_dist == "beta": + self.upper_bound = 1 + + def encoding_size(self): + return self.cfg.cond.temperature.num_thermometer_dim + + def sample(self, n): + cfg = self.cfg.cond.temperature + beta = None + if cfg.sample_dist == "constant": + assert type(cfg.dist_params[0]) is float + beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32) + beta_enc = torch.zeros((n, cfg.num_thermometer_dim)) + else: + if cfg.sample_dist == "gamma": + loc, scale = cfg.dist_params + beta = self.rng.gamma(loc, scale, n).astype(np.float32) + elif cfg.sample_dist == "uniform": + a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) + beta = self.rng.uniform(a, b, n).astype(np.float32) + elif cfg.sample_dist == "loguniform": + low, high = np.log(cfg.dist_params) + beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) + elif cfg.sample_dist == "beta": + a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) + beta = self.rng.beta(a, b, n).astype(np.float32) + beta_enc = thermometer(torch.tensor(beta), cfg.num_thermometer_dim, 0, self.upper_bound) + + assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" + return {"beta": torch.tensor(beta), "encoding": beta_enc} + + def transform(self, cond_info: Dict[str, Tensor], linear_reward: Tensor) -> Tensor: + scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log() + assert len(scalar_logreward.shape) == len( + cond_info["beta"].shape + ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" + return scalar_logreward * cond_info["beta"] + + def encode(self, conditional: Tensor) -> Tensor: + cfg = self.cfg.cond.temperature + if cfg.sample_dist == "constant": + return torch.zeros((conditional.shape[0], cfg.num_thermometer_dim)) + return thermometer(torch.tensor(conditional), cfg.num_thermometer_dim, 0, self.upper_bound) + + +class MultiObjectiveWeightedPreferences(Conditional): + def __init__(self, cfg: Config): + self.cfg = cfg.cond.weighted_prefs + self.num_objectives = cfg.cond.moo.num_objectives + self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim + if self.cfg.preference_type == "seeded": + self.seeded_prefs = np.random.default_rng(142857 + int(cfg.seed)).dirichlet([1] * self.num_objectives) + + def sample(self, n): + if self.cfg.preference_type is None: + preferences = torch.ones((n, self.num_objectives)) + elif self.cfg.preference_type == "seeded": + preferences = torch.tensor(self.seeded_prefs).float().repeat(n, 1) + elif self.cfg.preference_type == "dirichlet_exponential": + a = np.random.dirichlet([1] * self.num_objectives, n) + b = np.random.exponential(1, n)[:, None] + preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() + elif self.cfg.preference_type == "dirichlet": + m = Dirichlet(torch.FloatTensor([1.0] * self.num_objectives)) + preferences = m.sample([n]) + else: + raise ValueError(f"Unknown preference type {self.cfg.preference_type}") + preferences = torch.as_tensor(preferences).float() + return {"preferences": preferences, "encoding": self.encode(preferences)} + + def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor: + scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() + assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}" + return scalar_logreward + + def encoding_size(self): + return max(1, self.num_thermometer_dim * self.num_objectives) + + def encode(self, conditional: Tensor) -> Tensor: + if self.num_thermometer_dim > 0: + return thermometer(conditional, self.num_thermometer_dim, 0, 1).reshape(conditional.shape[0], -1) + else: + return conditional.unsqueeze(1) + + +class FocusRegionConditional(Conditional): + def __init__(self, cfg: Config, n_valid: int, rng: np.random.Generator): + self.cfg = cfg.cond.focus_region + self.n_valid = n_valid + self.n_objectives = cfg.cond.moo.num_objectives + self.ocfg = cfg + self.rng = rng + self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim if self.cfg.use_steer_thermomether else 0 + + focus_type = self.cfg.focus_type + if focus_type is not None and "learned" in focus_type: + if focus_type == "learned-tabular": + self.focus_model = TabularFocusModel( + # TODO: proper device propagation + device=torch.device("cpu"), + n_objectives=cfg.cond.moo.num_objectives, + state_space_res=self.cfg.focus_model_state_space_res, + ) + else: + raise NotImplementedError("Unknown focus model type {self.focus_type}") + else: + self.focus_model = None + self.setup_focus_regions() + + def encoding_size(self): + if self.num_thermometer_dim > 0: + return self.num_thermometer_dim * self.n_objectives + return self.n_objectives + + def setup_focus_regions(self): + # focus regions + if self.cfg.focus_type is None: + valid_focus_dirs = np.zeros((self.n_valid, self.n_objectives)) + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type == "centered": + valid_focus_dirs = np.ones((self.n_valid, self.n_objectives)) + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type == "partitioned": + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l2") + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type in ["dirichlet", "learned-gfn"]: + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l1") + self.fixed_focus_dirs = None + elif self.cfg.focus_type in ["hyperspherical", "learned-tabular"]: + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l2") + self.fixed_focus_dirs = None + elif type(self.cfg.focus_type) is list: + if len(self.cfg.focus_type) == 1: + valid_focus_dirs = np.array([self.cfg.focus_type[0]] * self.n_valid) + self.fixed_focus_dirs = valid_focus_dirs + else: + valid_focus_dirs = np.array(self.cfg.focus_type) + self.fixed_focus_dirs = valid_focus_dirs + else: + raise NotImplementedError( + f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " + f"focus_type, but here: {self.cfg.focus_type}" + ) + self.valid_focus_dirs = valid_focus_dirs + + def sample(self, n: int, train_it: int = None): + train_it = train_it or 0 + if self.fixed_focus_dirs is not None: + focus_dir = torch.tensor( + np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) + ) + elif self.cfg.focus_type == "dirichlet": + m = Dirichlet(torch.FloatTensor([1.0] * self.n_objectives)) + focus_dir = m.sample([n]) + elif self.cfg.focus_type == "hyperspherical": + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") + ).float() + elif self.cfg.focus_type is not None and "learned" in self.cfg.focus_type: + if ( + self.focus_model is not None + and train_it >= self.cfg.focus_model_training_limits[0] * self.cfg.max_train_it + ): + focus_dir = self.focus_model.sample_focus_directions(n) + else: + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") + ).float() + else: + raise NotImplementedError(f"Unsupported focus_type={type(self.cfg.focus_type)}") + + return {"focus_dir": focus_dir, "encoding": self.encode(focus_dir)} + + def encode(self, conditional: Tensor) -> Tensor: + return ( + thermometer(conditional, self.ocfg.cond.moo.num_thermometer_dim, 0, 1).reshape(conditional.shape[0], -1) + if self.cfg.use_steer_thermomether + else conditional + ) + + def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_logreward: Tensor = None) -> Tensor: + focus_coef, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef + ) + if scalar_logreward is None: + scalar_logreward = torch.log(focus_coef) + else: + scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) + scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward + + return scalar_logreward + + def step_focus_model(self, batch: gd.Batch, train_it: int): + focus_model_training_limits = self.cfg.focus_model_training_limits + max_train_it = self.ocfg.num_training_steps + if ( + self.focus_model is not None + and train_it >= focus_model_training_limits[0] * max_train_it + and train_it <= focus_model_training_limits[1] * max_train_it + ): + self.focus_model.update_belief(deepcopy(batch.focus_dir), deepcopy(batch.flat_rewards)) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py new file mode 100644 index 00000000..db3d3905 --- /dev/null +++ b/src/gflownet/utils/config.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass, field +from typing import Any, List, Optional + + +@dataclass +class TempCondConfig: + """Config for the temperature conditional. + + Attributes + ---------- + + sample_dist : str + The distribution to sample the inverse temperature from. Can be one of: + - "uniform": uniform distribution + - "loguniform": log-uniform distribution + - "gamma": gamma distribution + - "constant": constant temperature + - "beta": beta distribution + dist_params : List[Any] + The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. + num_thermometer_dim : int + The number of thermometer encoding dimensions to use. + """ + + sample_dist: str = "uniform" + dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) + num_thermometer_dim: int = 32 + + +@dataclass +class MultiObjectiveConfig: + num_objectives: int = 2 + num_thermometer_dim: int = 16 + + +@dataclass +class WeightedPreferencesConfig: + """Config for the weighted preferences conditional. + + Attributes + ---------- + preference_type : str + The preference sampling distribution, defaults to "dirichlet". Can be one of: + - "dirichlet": Dirichlet distribution + - "dirichlet_exponential": Dirichlet distribution with exponential temperature + - "seeded": Enumerated preferences + - None: All rewards equally weighted""" + + preference_type: Optional[str] = "dirichlet" + + +@dataclass +class FocusRegionConfig: + """Config for the focus region conditional. + + Attributes + ---------- + focus_type : str + The type of focus distribtuion used, see FocusRegionConditon.setup_focus_regions. Can be one of: + [None, "centered", "partitioned", "dirichlet", "hyperspherical", "learned-gfn", "learned-tabular"] + """ + + focus_type: Optional[str] = "learned-tabular" + use_steer_thermomether: bool = False + focus_cosim: float = 0.98 + focus_limit_coef: float = 0.1 + focus_model_training_limits: tuple[float, float] = (0.25, 0.75) + focus_model_state_space_res: int = 30 + max_train_it: int = 20_000 + + +@dataclass +class ConditionalsConfig: + temperature: TempCondConfig = TempCondConfig() + moo: MultiObjectiveConfig = MultiObjectiveConfig() + weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig() + focus_region: FocusRegionConfig = FocusRegionConfig() diff --git a/src/gflownet/utils/focus_model.py b/src/gflownet/utils/focus_model.py new file mode 100644 index 00000000..14bf6c71 --- /dev/null +++ b/src/gflownet/utils/focus_model.py @@ -0,0 +1,117 @@ +from pathlib import Path + +import torch +import torch.nn as nn + +from gflownet.utils.metrics import get_limits_of_hypercube + + +class FocusModel: + """ + Abstract class for a belief model over focus directions for goal-conditioned GFNs. + Goal-conditioned GFNs allow for more control over the objective-space region from which + we wish to sample. However due to the growing number of emtpy regions in the objective space, + if we naively sample focus-directions from the entire objective space, we will condition + our GFN with a lot of infeasible directions which significantly harms its sample efficiency + compared to a more simple preference-conditioned model. + To alleviate this problem, we introduce a focus belief model which is used to sample + focus directions from a subset of the objective space. The belief model is + trained to predict the probability of a focus direction being feasible. The likelihood + to sample a focus direction is then proportional to its population. Directions that have never + been sampled should be given the maximum likelihood. + """ + + def __init__(self, device: torch.device, n_objectives: int, state_space_res: int) -> None: + """ + args: + device: torch device + n_objectives: number of objectives + state_space_res: resolution of the state space discretisation. The number of focus directions to consider + grows within O(state_space_res ** n_objectives) and depends on the amount of filtering we apply + (e.g. valid focus-directions should sum to 1 [dirichlet], should contain a 1 [limits], etc.) + """ + self.device = device + self.n_objectives = n_objectives + self.state_space_res = state_space_res + + self.feasible_flow = 1.0 + self.infeasible_flow = 0.1 + + def update_belief(self, focus_dirs: torch.Tensor, flat_rewards: torch.Tensor): + raise NotImplementedError + + def sample_focus_directions(self, n: int): + raise NotImplementedError + + +class TabularFocusModel(FocusModel): + """ + Tabular model of the feasibility of focus directions for goal-condtioning. + We keep a count of the number of times each focus direction has been sampled and whether + this direction succesfully lead to a sample in this region of the objective space. The (unormalized) likelihood + of a focus direction being feasible is then given by the ratio of these numbers. + If a focus direction has not been sampled yet it obtains the maximum likelihood of one. + """ + + def __init__(self, device: torch.device, n_objectives: int, state_space_res: int) -> None: + super().__init__(device, n_objectives, state_space_res) + self.n_objectives = n_objectives + self.state_space_res = state_space_res + self.focus_dir_dataset = ( + nn.functional.normalize(torch.tensor(get_limits_of_hypercube(n_objectives, state_space_res)), dim=1) + .float() + .to(self.device) + ) + self.focus_dir_count = torch.zeros(self.focus_dir_dataset.shape[0]).to(self.device) + self.focus_dir_population_count = torch.zeros(self.focus_dir_dataset.shape[0]).to(self.device) + + def update_belief(self, focus_dirs: torch.Tensor, flat_rewards: torch.Tensor): + """ + Updates the focus model with the focus directions and rewards + of the last batch. + """ + focus_dirs = nn.functional.normalize(focus_dirs, dim=1) + flat_rewards = nn.functional.normalize(flat_rewards, dim=1) + + focus_dirs_indices = torch.argmin(torch.cdist(focus_dirs, self.focus_dir_dataset), dim=1) + flat_rewards_indices = torch.argmin(torch.cdist(flat_rewards, self.focus_dir_dataset), dim=1) + + for idxs, count in zip( + [focus_dirs_indices, flat_rewards_indices], + [self.focus_dir_count, self.focus_dir_population_count], + ): + idx_increments = torch.bincount(idxs, minlength=len(count)) + count += idx_increments + + def sample_focus_directions(self, n: int): + """ + Samples n focus directions from the focus model. + """ + sampling_likelihoods = torch.zeros_like(self.focus_dir_count).float().to(self.device) + sampling_likelihoods[self.focus_dir_count == 0] = self.feasible_flow + sampling_likelihoods[ + torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0) + ] = self.feasible_flow + sampling_likelihoods[ + torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0) + ] = self.infeasible_flow + focus_dir_indices = torch.multinomial(sampling_likelihoods, n, replacement=True) + return self.focus_dir_dataset[focus_dir_indices].to("cpu") + + def save(self, path: Path): + params = { + "n_objectives": self.n_objectives, + "state_space_res": self.state_space_res, + "focus_dir_dataset": self.focus_dir_dataset.to("cpu"), + "focus_dir_count": self.focus_dir_count.to("cpu"), + "focus_dir_population_count": self.focus_dir_population_count.to("cpu"), + } + torch.save(params, open(path / "tabular_focus_model.pt", "wb")) + + def load(self, device: torch.device, path: Path): + params = torch.load(open(path / "tabular_focus_model.pt", "rb")) + self.n_objectives = params["n_objectives"] + self.state_space_res = params["state_space_res"] + self.focus_dir_dataset = params["focus_dir_dataset"].to(device) + self.focus_dir_count = params["focus_dir_count"].to(device) + self.focus_dir_population_count = params["focus_dir_population_count"].to(device) diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index abf123df..cc37c127 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.nn as nn from botorch.utils.multi_objective import infer_reference_point, pareto from botorch.utils.multi_objective.hypervolume import Hypervolume from rdkit import Chem, DataStructs @@ -11,6 +12,47 @@ from sklearn.cluster import KMeans +def compute_focus_coef( + flat_rewards: torch.Tensor, focus_dirs: torch.Tensor, focus_cosim: float, focus_limit_coef: float = 1.0 +): + """ + The focus direction is defined as a hypercone in the objective space centered around an focus_dir. + The focus coefficient (between 0 and 1) scales the reward associated to a given sample. + It should be 1 when the sample is exactly at the focus direction, equal to the focus_limit_coef + when the sample is at on the limit of the focus region and 0 when it is outside the focus region + we can use an exponential decay of the focus coefficient between the center and the limit of the focus region + i.e. cosim(sample, focus_dir) ** focus_gamma_param = focus_limit_coef + Note that we work in the positive quadrant (each reward is positive) and thus the cosine similarity is in [0, 1] + + :param focus_dirs: the focus directions, shape (batch_size, num_objectives) + :param flat_rewards: the flat rewards, shape (batch_size, num_objectives) + :param focus_cosim: the cosine similarity threshold to define the focus region + :param focus_limit_coef: the focus coefficient at the limit of the focus region + """ + assert focus_cosim >= 0.0 and focus_cosim <= 1.0, f"focus_cosim must be in [0, 1], now {focus_cosim}" + assert ( + focus_limit_coef > 0.0 and focus_limit_coef <= 1.0 + ), f"focus_limit_coef must be in (0, 1], now {focus_limit_coef}" + focus_gamma_param = torch.tensor(np.log(focus_limit_coef) / np.log(focus_cosim)).float() + cosim = nn.functional.cosine_similarity(flat_rewards, focus_dirs, dim=1) + in_focus_mask = cosim >= focus_cosim + focus_coef = torch.where(in_focus_mask, cosim**focus_gamma_param, 0.0) + return focus_coef, in_focus_mask + + +def get_focus_accuracy(flat_rewards, focus_dirs, focus_cosim): + _, in_focus_mask = compute_focus_coef(focus_dirs, flat_rewards, focus_cosim, focus_limit_coef=1.0) + return in_focus_mask.float().sum() / len(flat_rewards) + + +def get_limits_of_hypercube(n_dims, n_points_per_dim=10): + """Discretise the faces that are at the extremity of a unit hypercube""" + linear_spaces = [np.linspace(0.0, 1.0, n_points_per_dim) for _ in range(n_dims)] + grid = np.array(list(product(*linear_spaces))) + extreme_points = grid[np.any(grid == 1, axis=1)] + return extreme_points + + def get_IGD(samples, ref_front: np.ndarray = None): """ Computes the Inverse Generational Distance of a set of samples w.r.t a reference pareto front. @@ -28,14 +70,6 @@ def get_IGD(samples, ref_front: np.ndarray = None): Returns: float: The IGD value. """ - - def get_limits_of_hypercube(n_dims, n_points_per_dim=10): - """Discretise the faces that are at the extremity of a unit hypercube""" - linear_spaces = [np.linspace(0.0, 1.0, n_points_per_dim) for _ in range(n_dims)] - grid = np.array(list(product(*linear_spaces))) - extreme_points = grid[np.any(grid == 1, axis=1)] - return extreme_points - n_objectives = samples.shape[1] if ref_front is None: ref_front = get_limits_of_hypercube(n_dims=n_objectives) @@ -71,14 +105,6 @@ def get_PC_entropy(samples, ref_front=None): Returns: float: The IGD value. """ - - def get_limits_of_hypercube(n_dims, n_points_per_dim=10): - """Discretise the faces that are at the extremity of a unit hypercube""" - linear_spaces = [np.linspace(0.0, 1.0, n_points_per_dim) for _ in range(n_dims)] - grid = np.array(list(product(*linear_spaces))) - extreme_points = grid[np.any(grid == 1, axis=1)] - return extreme_points - n_objectives = samples.shape[1] if ref_front is None: ref_front = get_limits_of_hypercube(n_dims=n_objectives) @@ -100,6 +126,18 @@ def get_limits_of_hypercube(n_dims, n_points_per_dim=10): return float(pc_ent) +def sample_positiveQuadrant_ndim_sphere(n=10, d=2, normalisation="l2"): + points = np.random.randn(n, d) + points = np.abs(points) # positive quadrant + if normalisation == "l2": + points /= np.linalg.norm(points, axis=1, keepdims=True) + elif normalisation == "l1": + points /= np.sum(points, axis=1, keepdims=True) + else: + raise ValueError(f"Unknown normalisation {normalisation}") + return points + + def partition_hypersphere(k: int, d: int, n_samples: int = 10000, normalisation: str = "l2"): """ Partition a hypersphere into k clusters. @@ -119,18 +157,6 @@ def partition_hypersphere(k: int, d: int, n_samples: int = 10000, normalisation: v: np.ndarray Array of shape (k, d) containing the cluster centers """ - - def sample_positiveQuadrant_ndim_sphere(n=10, d=2, normalisation="l2"): - points = np.random.randn(n, d) - points = np.abs(points) # positive quadrant - if normalisation == "l2": - points /= np.linalg.norm(points, axis=1, keepdims=True) - elif normalisation == "l1": - points /= np.sum(points, axis=1, keepdims=True) - else: - raise ValueError(f"Unknown normalisation {normalisation}") - return points - points = sample_positiveQuadrant_ndim_sphere(n_samples, d, normalisation) v = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(points).cluster_centers_ if normalisation == "l2": diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 844e3174..4862c6c7 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -23,6 +23,8 @@ def __init__( compute_normed=False, compute_igd=False, compute_pc_entropy=False, + compute_focus_accuracy=False, + focus_cosim=None, ): # This __init__ is only called in the main process. This object is then (potentially) cloned # in pytorch data worker processed and __call__'ed from within those processes. This means @@ -36,8 +38,11 @@ def __init__( self.compute_normed = compute_normed self.compute_igd = compute_igd self.compute_pc_entropy = compute_pc_entropy + self.compute_focus_accuracy = compute_focus_accuracy + self.focus_cosim = focus_cosim self.all_flat_rewards: List[Tensor] = [] + self.all_focus_dirs: List[Tensor] = [] self.all_smi: List[str] = [] self.pareto_queue: mp.Queue = mp.Queue() self.pareto_front = None @@ -115,12 +120,17 @@ def _run_pareto_accumulation(self): def __call__(self, trajs, rewards, flat_rewards, cond_info): # locally (in-process) accumulate flat rewards to build a better pareto estimate self.all_flat_rewards = self.all_flat_rewards + list(flat_rewards) + if self.compute_focus_accuracy: + self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) self.all_smi = self.all_smi + list([i.get("smi", None) for i in trajs]) if len(self.all_flat_rewards) > self.num_to_keep: self.all_flat_rewards = self.all_flat_rewards[-self.num_to_keep :] + self.all_focus_dirs = self.all_focus_dirs[-self.num_to_keep :] self.all_smi = self.all_smi[-self.num_to_keep :] flat_rewards = torch.stack(self.all_flat_rewards).numpy() + if self.compute_focus_accuracy: + focus_dirs = torch.stack(self.all_focus_dirs).numpy() # collects empirical pareto front from in-process samples pareto_idces = metrics.is_pareto_efficient(-flat_rewards, return_mask=False) @@ -176,6 +186,14 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): "PCent": pc_ent, "lifetime_PCent_frontOnly": self.pareto_metrics[3], } + if self.compute_focus_accuracy: + focus_acc = metrics.get_focus_accuracy( + torch.tensor(flat_rewards), torch.tensor(focus_dirs), self.focus_cosim + ) + info = { + **info, + "focus_acc": focus_acc, + } return info diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index c49138e9..cb220f4b 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -1,14 +1,16 @@ import pickle import queue import threading +import traceback import torch import torch.multiprocessing as mp -class MPModelPlaceholder: - """This class can be used as a Model in a worker process, and - translates calls to queries to the main process""" +class MPObjectPlaceholder: + """This class can be used for example as a model or dataset placeholder + in a worker process, and translates calls to the object-placeholder into + queries for the main process to execute on the real object.""" def __init__(self, in_queues, out_queues, pickle_messages=False): self.qs = in_queues, out_queues @@ -31,42 +33,51 @@ def encode(self, m): def decode(self, m): if self.pickle_messages: - return pickle.loads(m) + m = pickle.loads(m) + if isinstance(m, Exception): + print("Received exception from main process, reraising.") + raise m return m - # TODO: make a generic method for this based on __getattr__ - def logZ(self, *a, **kw): - self._check_init() - self.in_queue.put(self.encode(("logZ", a, kw))) - return self.decode(self.out_queue.get()) + def __getattr__(self, name): + def method_wrapper(*a, **kw): + self._check_init() + self.in_queue.put(self.encode((name, a, kw))) + return self.decode(self.out_queue.get()) + + return method_wrapper def __call__(self, *a, **kw): self._check_init() self.in_queue.put(self.encode(("__call__", a, kw))) return self.decode(self.out_queue.get()) + def __len__(self): + self._check_init() + self.in_queue.put(("__len__", (), {})) + return self.out_queue.get() + -class MPModelProxy: - """This class maintains a reference to an in-cuda-memory model, and +class MPObjectProxy: + """This class maintains a reference to some object and creates a `placeholder` attribute which can be safely passed to multiprocessing DataLoader workers. - This placeholder model sends messages accross multiprocessing - queues, which are received by this proxy instance, which calls the - model and sends the return value back to the worker. - - Starts its own (daemon) thread. Always passes CPU tensors between - processes. + The placeholders in each process send messages accross multiprocessing + queues which are received by this proxy instance. The proxy instance then + runs the calls on our object and sends the return value back to the worker. + Starts its own (daemon) thread. + Always passes CPU tensors between processes. """ - def __init__(self, model: torch.nn.Module, num_workers: int, cast_types: tuple, pickle_messages: bool = False): - """Construct a multiprocessing model proxy for torch DataLoaders. + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False): + """Construct a multiprocessing object proxy. Parameters ---------- - model: torch.nn.Module - A torch model which lives in the main process to which method calls are passed + obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple @@ -80,9 +91,12 @@ def __init__(self, model: torch.nn.Module, num_workers: int, cast_types: tuple, self.in_queues = [mp.Queue() for i in range(num_workers)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPModelPlaceholder(self.in_queues, self.out_queues, pickle_messages) - self.model = model - self.device = next(model.parameters()).device + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages) + self.obj = obj + if hasattr(obj, "parameters"): + self.device = next(obj.parameters()).device + else: + self.device = torch.device("cpu") self.cuda_types = (torch.Tensor,) + cast_types self.stop = threading.Event() self.thread = threading.Thread(target=self.run, daemon=True) @@ -114,12 +128,36 @@ def run(self): except ConnectionError: break attr, args, kwargs = r - f = getattr(self.model, attr) + f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} - result = f(*args, **kwargs) + try: + # There's no need to compute gradients, since we can't transfer them back to the worker + with torch.no_grad(): + result = f(*args, **kwargs) + except Exception as e: + result = e + exc_str = traceback.format_exc() + try: + pickle.dumps(e) + except Exception: + result = RuntimeError("Exception raised in MPModelProxy, but it cannot be pickled.\n" + exc_str) if isinstance(result, (list, tuple)): msg = [self.to_cpu(i) for i in result] + elif isinstance(result, dict): + msg = {k: self.to_cpu(i) for k, i in result.items()} + else: + msg = self.to_cpu(result) + self.out_queues[qi].put(self.encode(msg)) + + +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): + """Construct a multiprocessing object proxy for torch DataLoaders so + that it does not need to be copied in every worker's memory. For example, + this can be used to wrap a model such that only the main process makes + cuda calls by forwarding data through the model, or a replay buffer + such that the new data is pushed in from the worker processes but only the + main process has to hold the full buffer in memory. self.out_queues[qi].put(self.encode(msg)) elif isinstance(result, dict): msg = {k: self.to_cpu(i) for k, i in result.items()} @@ -128,16 +166,10 @@ def run(self): msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) - -def wrap_model_mp(model, num_workers, cast_types, pickle_messages: bool = False): - """Construct a multiprocessing model proxy for torch DataLoaders so - that only one process ends up making cuda calls and holding cuda - tensors in memory. - Parameters ---------- - model: torch.Module - A torch model which lives in the main process to which method calls are passed + obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple @@ -150,8 +182,8 @@ def wrap_model_mp(model, num_workers, cast_types, pickle_messages: bool = False) Returns ------- - placeholder: MPModelPlaceholder - A placeholder model whose method calls route arguments to the main process + placeholder: MPObjectPlaceholder + A placeholder object whose method calls route arguments to the main process """ - return MPModelProxy(model, num_workers, cast_types, pickle_messages).placeholder + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages).placeholder diff --git a/tests/test_frag_env.py b/tests/test_envs.py similarity index 74% rename from tests/test_frag_env.py rename to tests/test_envs.py index 414eaad8..204a17cb 100644 --- a/tests/test_frag_env.py +++ b/tests/test_envs.py @@ -1,16 +1,19 @@ import base64 import pickle -from collections import defaultdict import networkx as nx import pytest +from omegaconf import OmegaConf from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.models import bengio2021flow -def build_two_node_states(): +def build_two_node_states(ctx): # TODO: This is actually fairly generic code that will probably be reused by other tests in the future. # Having a proper class to handle graph-indexed hash maps would probably be good. graph_cache = {} @@ -20,7 +23,6 @@ def build_two_node_states(): # We're enumerating all states of length two, but we could've just as well randomly sampled # some states. env = GraphBuildingEnv() - ctx = FragMolBuildingEnvContext(max_frags=2) def g2h(g): gc = g.to_directed() @@ -72,11 +74,19 @@ def expand(s, idx): return [graph_by_idx[i] for i in list(nx.topological_sort(mdp_graph))] +def get_frag_env_ctx(): + return FragMolBuildingEnvContext(max_frags=2, fragments=bengio2021flow.FRAGMENTS[:20]) + + +def get_atom_env_ctx(): + return MolBuildingEnvContext(atoms=["C", "N"], expl_H_range=[0], charges=[0], max_nodes=2) + + @pytest.fixture -def two_node_states(request): +def two_node_states_frags(request): data = request.config.cache.get("frag_env/two_node_states", None) if data is None: - data = build_two_node_states() + data = build_two_node_states(get_frag_env_ctx()) # pytest caches through JSON so we have to make a clean enough string request.config.cache.set("frag_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode()) else: @@ -84,13 +94,24 @@ def two_node_states(request): return data -def test_backwards_mask_equivalence(two_node_states): +@pytest.fixture +def two_node_states_atoms(request): + data = request.config.cache.get("atom_env/two_node_states", None) + if data is None: + data = build_two_node_states(get_atom_env_ctx()) + # pytest caches through JSON so we have to make a clean enough string + request.config.cache.set("atom_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode()) + else: + data = pickle.loads(base64.b64decode(data)) + return data + + +def _test_backwards_mask_equivalence(two_node_states, ctx): """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. """ env = GraphBuildingEnv() - ctx = FragMolBuildingEnvContext(max_frags=2) for i in range(1, len(two_node_states)): g = two_node_states[i] n = env.count_backward_transitions(g, check_idempotent=False) @@ -103,7 +124,7 @@ def test_backwards_mask_equivalence(two_node_states): raise ValueError() -def test_backwards_mask_equivalence_ipa(two_node_states): +def _test_backwards_mask_equivalence_ipa(two_node_states, ctx): """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. @@ -111,8 +132,9 @@ def test_backwards_mask_equivalence_ipa(two_node_states): This test also accounts for idempotent actions. """ env = GraphBuildingEnv() - ctx = FragMolBuildingEnvContext(max_frags=2) - algo = TrajectoryBalance(env, ctx, None, defaultdict(int), max_nodes=2) + cfg = OmegaConf.structured(Config) + cfg.algo.max_nodes = 2 + algo = TrajectoryBalance(env, ctx, None, cfg) for i in range(1, len(two_node_states)): g = two_node_states[i] n = env.count_backward_transitions(g, check_idempotent=True) @@ -138,3 +160,19 @@ def test_backwards_mask_equivalence_ipa(two_node_states): equivalence_classes.append(ipa) if n != len(equivalence_classes): raise ValueError() + + +def test_backwards_mask_equivalence_frag(two_node_states_frags): + _test_backwards_mask_equivalence(two_node_states_frags, get_frag_env_ctx()) + + +def test_backwards_mask_equivalence_ipa_frag(two_node_states_frags): + _test_backwards_mask_equivalence_ipa(two_node_states_frags, get_frag_env_ctx()) + + +def test_backwards_mask_equivalence_atom(two_node_states_atoms): + _test_backwards_mask_equivalence(two_node_states_atoms, get_atom_env_ctx()) + + +def test_backwards_mask_equivalence_ipa_atom(two_node_states_atoms): + _test_backwards_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx()) diff --git a/tox.ini b/tox.ini index 617b57be..85d593d5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,17 +1,16 @@ [tox] -envlist = py3{8,9}, report +envlist = py3{9}, report [testenv] commands = pytest skip_install = true depends = - report: py3{8,9} + report: py3{9} setenv = - py3{8,9,10}: COVERAGE_FILE = .coverage.{envname} + py3{9,10}: COVERAGE_FILE = .coverage.{envname} install_command = - pip install -U {opts} {packages} --find-links https://data.pyg.org/whl/torch-1.10.0+cu113.html --find-links https://data.pyg.org/whl/torch-1.10.0+cpu.html + pip install -U {opts} {packages} --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html deps = - py38: -r requirements/dev_3.8.txt py39: -r requirements/dev_3.9.txt