Skip to content

Commit

Permalink
Merge branch 'trunk' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
SobhanMP authored Aug 15, 2023
2 parents 4f9de6b + ec857a5 commit 0bcfb24
Show file tree
Hide file tree
Showing 41 changed files with 2,485 additions and 1,367 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera

## Repo overview

- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations (only [Trajectory Balance](https://arxiv.org/abs/2201.13259) for now), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities.
- [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data.
- [examples](docs/examples), contains simple example implementations of GFlowNet.
Expand All @@ -30,8 +30,11 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera
- [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward.
- [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein
- [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives).
- [utils](src/gflownet/utils), contains utilities (multiprocessing).
- [`train.py`](src/gflownet/train.py), defines a general harness for training GFlowNet models.
- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning).
- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models.
- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop.

See [implementation notes](docs/implementation_notes.md) for more.

## Getting started

Expand All @@ -44,19 +47,21 @@ A good place to get started is with the [sEH fragment-based MOO task](src/gflown
This package is installable as a PIP package, but since it depends on some torch-geometric package wheels, the `--find-links` arguments must be specified as well:

```bash
pip install -e . --find-links https://data.pyg.org/whl/torch-1.10.0+cu113.html
pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cu117.html
```
Or for CPU use:

```bash
pip install -e . --find-links https://data.pyg.org/whl/torch-1.10.0+cpu.html
pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html
```

To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository-pip-dependency/) a specific tag, for example here `v0.0.10`, use the following scheme:
```bash
pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-links ...
```

If package dependencies seem not to work, you may need to install the exact frozen versions listed `requirements/`, i.e. `pip install -r requirements/main_3.9.txt`.

## Developing & Contributing

TODO: Write Contributing.md.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
MAJOR="0"
MINOR="0"
MINOR="1"
36 changes: 36 additions & 0 deletions docs/implementation_notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Implementation notes

This repo is centered around training GFlowNets that produce graphs. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments.

## Environment, Context, Task, Trainers

We separate experiment concerns in four categories:
- The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP.
- The Context provides an interface between the agent and the environment, it
- maps graphs to torch_geometric `Data`
instances
- maps GraphActions to action indices
- produces action masks
- communicates to the model what inputs it should expect
- The Task class is responsible for computing the reward of a state, and for sampling conditioning information
- The Trainer class is responsible for instanciating everything, and running the training & testing loop

Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.


## Graphs

This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations.

Some notes:
- graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs.
- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding.


### Graph policies & graph action categoricals

The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits accross a minibatch.

Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
17 changes: 9 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,26 @@ keywords = ["gflownet"]
requires-python = ">=3.8,<3.10"
dynamic = ["version"]
dependencies = [
"torch==1.10.2",
"torch==1.13.1",
# These pins are specific on purpose, some of these packages have
# unstable APIs since they are fairly new. We could instead pin
# them as >= in dev until something breaks?
"torch-geometric==2.0.3",
"torch-scatter==2.0.9",
"torch-sparse==0.6.13",
"torch-cluster==1.6.0",
"torch-scatter",
"torch-sparse",
"torch-cluster",
"rdkit",
"tables",
"scipy",
"networkx",
"tensorboard",
"cvxopt",
"pyarrow",
"botorch==0.6.6", # pin because of the torch==1.10.2 dependency, botorch>=0.7 requires torch>=1.11
# pins to help depencency resolution, because of the above pin
"pyro-ppl==1.8.0",
"gpytorch==1.8.1",
"gitpython",
"botorch",
"pyro-ppl",
"gpytorch",
"omegaconf>=2.3",
]

[project.optional-dependencies]
Expand Down
Loading

0 comments on commit 0bcfb24

Please sign in to comment.