Commit
This PR refactors tasks into reusable components that will make it easier to define new tasks and methods. The PR is pretty big, but a lot of it is moving code around (or deleting repeated code):
- Creates a notion of `Conditional`, implements `TemperatureConditional`, `MultiObjectiveWeightedPreferences`, `FocusRegionConditional`
- Separates the commonalities between `seh_frag` and `seh_frag_moo` into a more generic `StandardOnlineTrainer` that is meant to be easily subclassed for new tasks.
- Adds some implementation notes and comments
- Adds a `validate_batch` routine that's useful for debugging, e.g. new environments and datasets

Also fixes some bugs:
- Makes `valid_offline_ratio` a flag and sets it explicitly in tasks where it wasn't properly set
- `first_graph_idx` was incorrectly calculated in SubTB (affected logging of logZ values)
- QM9Dataset was returning the wrong shape for its rewards
- Adds an `allow_5_valence_nitrogen` flag to `MolBuildingEnvContext`; this is needed in some cases, see `tasks/qm9.py`.
- Adds an explicit `stop_mask` to `MolBuildingEnvContext.graph_to_Data`
- Fixes incorrect default objective name in `seh_frag_moo`
- Fixes the default configurations in the tasks' `main` that hadn't been updated
- Fixes a number of routines where `focus_cond` was assumed to exist (but we can now turn it off).
commits:
* little test
* new config structure - in progress
* trying config by names
* further refactor progress
* better pyi sort + fix moo example
* tox
* import fixes
* add SQL + fix n_valid
* tox test
* fix mypy hook and convert qm9 to new cfg
* better config generation
* use generated config.py
* use generated config.py
* fix rng call types
* fix test + tox
* better config doc
* fix deps
* tox
* re-fix deps
* minor fixes for seh_frag_moo
* tox
* beginning of refactor + impl notes
* multiobject weighted prefs
* focus conditional in progress
* switch to OmegaConf
* switch to omegaconf
* fix pre-commit-config
* add omegaconf dep
* fix list defaults to fields
* remove comment
* finish focus conditional
* various fixes + switch to new config
* tox
* update README
* make string configs into Literals
* switch task construction order
* OmegaConf does not support Literal :(
* tox
* remove dead code + guard against no focus used
* fix for no replay
* refactor qm9 + remove unused configs
* many fixes to QM9, some debugging code and other various fixes
* explicit valid_offline_ratio flag
* do_validate_batch off by default
* addressing PR comments
* made device configurable
Showing 21 changed files with 782 additions and 543 deletions.
@@ -0,0 +1,18 @@
# Implementation notes

This repo is centered around training GFlowNets that produce graphs. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments.

## Environment, Context, Task, Trainers

We separate experiment concerns into four categories:
- The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP.
- The Context provides an interface between the agent and the environment. It
    - maps graphs to torch_geometric `Data` instances
    - maps GraphActions to action indices
    - produces action masks
    - communicates to the model what inputs it should expect
- The Task class is responsible for computing the reward of a state, and for sampling conditioning information
- The Trainer class is responsible for instantiating everything, and running the training & testing loop
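
To make the Context's role more concrete, here is a minimal, self-contained sketch. It does not use this repo's classes: `ToyAction` and `ToyContext` are hypothetical names, the "graph" is reduced to a plain list of node types, and a list of counts stands in for a torch_geometric `Data` instance. It only illustrates the four duties listed above under those assumptions.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical toy types for illustration only; these are not the repo's classes.
@dataclass(frozen=True)
class ToyAction:
    kind: str       # "add_node" or "stop"
    value: int = 0  # node type to add (ignored for "stop")


class ToyContext:
    """Translates between the agent's flat action indices and environment actions."""

    def __init__(self, num_node_types: int, max_nodes: int) -> None:
        self.num_node_types = num_node_types
        self.max_nodes = max_nodes
        # What the model should expect: input width and number of action logits.
        self.num_features = num_node_types
        self.num_actions = 1 + num_node_types  # index 0 is "stop"

    def graph_to_features(self, graph: List[int]) -> List[int]:
        # Count of each node type: a stand-in for a torch_geometric `Data` object.
        return [graph.count(t) for t in range(self.num_node_types)]

    def action_to_index(self, action: ToyAction) -> int:
        return 0 if action.kind == "stop" else 1 + action.value

    def index_to_action(self, index: int) -> ToyAction:
        return ToyAction("stop") if index == 0 else ToyAction("add_node", index - 1)

    def action_mask(self, graph: List[int]) -> List[bool]:
        # Stopping is always allowed; adding is allowed until the graph is full.
        can_add = len(graph) < self.max_nodes
        return [True] + [can_add] * self.num_node_types


ctx = ToyContext(num_node_types=3, max_nodes=2)
graph = [0, 2]  # a "graph" reduced to a list of node types
print(ctx.graph_to_features(graph))  # [1, 0, 1]
print(ctx.action_mask(graph))        # [True, False, False, False]
```

Keeping the index mapping and the masks in one place means the model only ever sees fixed-size inputs and a flat action space, while the environment keeps its own structured actions.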

Typically one would set up a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.
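
The subclassing pattern can be sketched without the library as follows. This is an illustration under assumptions, not the actual `GFNTask` or `GFNTrainer` interface: `ToyTask`, `ToyTrainer`, and every method name below are hypothetical, and no learning happens (states are sampled at random where a real trainer would query a model and apply a loss).

```python
import random
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyTask(ABC):
    """Hypothetical task interface: rewards plus conditioning information."""

    @abstractmethod
    def sample_cond(self, n: int) -> List[Dict[str, float]]:
        """Sample one conditioning dict per trajectory."""

    @abstractmethod
    def reward(self, state: List[int], cond: Dict[str, float]) -> float:
        """Score a terminal state under the given conditioning."""


class CountOnesTask(ToyTask):
    """Rewards states with many 1s, sharpened by an inverse temperature beta."""

    def __init__(self, rng: random.Random) -> None:
        self.rng = rng

    def sample_cond(self, n: int) -> List[Dict[str, float]]:
        return [{"beta": self.rng.uniform(1.0, 4.0)} for _ in range(n)]

    def reward(self, state: List[int], cond: Dict[str, float]) -> float:
        base = 1.0 + sum(1 for x in state if x == 1)
        return base ** cond["beta"]


class ToyTrainer(ABC):
    """Hypothetical trainer interface: builds components, then runs the loop."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)
        self.task = self.setup_task()

    @abstractmethod
    def setup_task(self) -> ToyTask:
        """Subclasses choose which task to instantiate."""

    def sample_state(self) -> List[int]:
        # Placeholder for sampling a trajectory from a learned policy.
        return [self.rng.randint(0, 1) for _ in range(5)]

    def run(self, num_steps: int, batch_size: int) -> List[float]:
        mean_rewards = []
        for _ in range(num_steps):
            conds = self.task.sample_cond(batch_size)
            states = [self.sample_state() for _ in conds]
            rewards = [self.task.reward(s, c) for s, c in zip(states, conds)]
            # A real trainer would compute a loss and update the model here.
            mean_rewards.append(sum(rewards) / len(rewards))
        return mean_rewards


class CountOnesTrainer(ToyTrainer):
    def setup_task(self) -> ToyTask:
        return CountOnesTask(self.rng)


trainer = CountOnesTrainer(seed=0)
history = trainer.run(num_steps=3, batch_size=4)
print(len(history))  # 3
```

In this sketch a new experiment only has to supply the reward and conditioning logic plus a trainer subclass that wires them together; the loop itself is inherited unchanged.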