From a7058baecd7bb8e3bce9b204d248b1bf3c329f8c Mon Sep 17 00:00:00 2001 From: Julien Roy Date: Fri, 26 May 2023 11:38:33 -0400 Subject: [PATCH 1/9] Adding support for goal-conditioning, replay-buffer and upgrading to torch1.13 (#89) * fix: removed np.int (deprecated) * Speed and stability improvements to GraphActionCategorical + some tests * tox * include self as test dependency * make TB use new GraphActionCategorical.log_prob * slice tensor to dev * fix: changed default log_dir * fix: TMP -- REVERT ME (ugly fix to stop saving molecules while I have to save data in home/ because of filesystem issues on scratch/ * feat: added option hps.temperature_sample_dist='constant' which correspond to self.temperature_dist_params= * set epsilon to min normal float32 * docs: added hyperparmeters in printing and in logs as json file * fix: added case for self.temperature_sample_dist == 'constant' in encode_conditional_information() * fix: moved config saving from GFNTrainer.run() to SEHMOOFragTrainer.setup() to be caught by Determined agents as well * feat: added git_hash to saved hyperparams * tmp: REVERT ME * chore: flake8 adjustments * feat: added hyperparams 'objectives' and 'num_thermometer_dim' * objectives: the list of objectives to optimise among {'seh', 'qed', 'sa', 'mw'} * num_thermometer_dim: the length of the thermometer vector for encoding temperature * chore: pep8 * Revert "fix: TMP -- REVERT ME (ugly fix to stop saving molecules while I have to save data in home/ because of filesystem issues on scratch/" This reverts commit 5861f3ac319a1c0fcb40d3de892fd7f8d95ea928. * fix: use 'objectives' hyperparam for validation set construction * Revert "Revert "fix: TMP -- REVERT ME (ugly fix to stop saving molecules while I have to save data in home/ because of filesystem issues on scratch/"" This reverts commit 5c737e31975ff610b21c3e21e26d0aa9db507bee. * feat: added 'hostname' in hyperparameters for cluster troubleshooting * feat: use log-rewards for numerical stability (forgotten commit from PR39) * docs: changed name of logged data to reflect true value (r to logR) * fix: added security-check to thermometer and avoid passing both vmin=vmax=0 when beta=0 * improve fragment environment for validity + include stop_mask + include original fragment stems * tox * remove temp code * Remove json prettifier pre-commit hook * fix mask to device in GraphActionCategorical * docs: renamed functions to logrewards (to avoid TrajectoryBalance calling task.cond_info_to_reward() for tasks that wouldn't have been updated to using logrewards rather than rewards) * fix: added clamping of logprobs from PR#49 * more comments and slightly better architecture adaptation * Revert "fix: added clamping of logprobs from PR#49" This reverts commit 76507ca36fdc6d6f7a7b65a22e742839fe96e735. * chore: updated README.md and dependencies to facilitate installation on mila-cluster * fix: isolating sensitive imports (cvxopt) that are tricky to install on Mac M1 chip * feat: added hyperparameter 'Z_params' * (fix): fixed encode_conditional_information() for constant beta (was passing full thermometer instead of empty one) * fix: convert logrewards back into regular rewards for logging * fix: removed batch_size*2 for training sampling_iterator * feat: added argument 'valid_random_action_prob' and moved random_action_prob to be a argument of graph_sampler.sample_from_model() * feat: added a logger to run() * refactor: updated default hyperparameters to stabilise learning with high temperature * feat: added hyperpam 'objective_region_dir' and 'objective_region_cosim' to restrict the feasible solution region through reward manipulation * fix: defined objective_region_dir as a string to allow to search over using Determined * fix: convert to tensor deeper * (minor): added prints * fix: minor fix * fix: set strea=True for valid_sampling_iterator and remove stream-related condition for logging * Added hyperparameters: * n_valid_prefs * n_valid_repeats_per_pref * fix: (minor) intialise beta as ndarray to avoid torch warning over tensor(tensor(beta)) * chore: pep8 * refactor: renamed objective_region -> focus * docs: improved function signatures for clarity * feat: added support for preference_type == None (no preferences) * refactor: updated default hyperparameters to stabilise learning with high temperature * Added hyperparameters: * n_valid_prefs * n_valid_repeats_per_pref * feat: now support focus_dir='dirichlet' for focus-conditioned GFN * chore: pep8 * refactor: save hps at begining of setup() * fix: added metrics.partition_hypersphere() which allows to partition a hypersphere into a pre-defined number of k vectors * chore: pep8 * fix: wrong alignment between data and column_names * chore: tox * feat: renamed hps.focus_dir -> hps.focus_dist and added support for focus_dist being a list of fixed focus_directions * fix: added shape check on valid_focus_dirs * minor: updated default hyperparams * feat: added support for focus_type == 'centered', 'partitioned', None * refactor: save hps in parent class seh_frag * feat: implemented reach_metric() for computing distance between samples and ideal pareto front * feat: added reach_metric() to MultiObjectiveStatsHooks() * chore: tox * docs: updated docstring and function signature for reach_metric * fix: removed duplicate of hyperparams printing * fix: now properly closes hps.json after saving hps on disk * fix: convert beta into torch.tensor() * refactor: changed MPModel* -> MPObject* to allow these classes to be used for other objects like a replay_buffer (now uses __getattr__ to avoid having to have a per-objectType class in which we enumerate the possible calls on the placeholder) * chore: improved code-readability a bit in SamplingIterator() * refactor: removed 'mols' from batch attributes -- seemed unused * feat: implemented replay_buffer * chore: tox * docs: added buffer-length to monitoring * minor: added num_workers to logged curves to easily see which runs manage to complete (error: too many opened files) * chore: cleaning up * fix: turning off hypervolume by default (may cause delays in training with many objectives?) * feat: changed self.verbose to self.print_every * feat: renamed reach-metric -> IGD and added PC-entropy metric * (minor): clean-up, compute metrics before constructing the batch (constructing the batch should be the last thing that we do in that method, also this approach will better accomodate for the use of a replay buffer where the metrics are computed on the online distribution but the batch is from the buffer distribution) * fix: moved replay buffer sampling after metric computation (to have metrics computed on the online distribution, not the buffer distribution) * feat: made the replay buffer more general and added stacking of numpy and torch arrays to make multiprocessing communication (through shared memory files) more efficient * feat: added logrewards in the replay buffer * A few edits in MultiObjectiveStatsHook: * fixed labeling error (the IGD and PC-entropy that were computed were local, not lifetime) * added actual lifetime computation of IGD and PC-entropy * added comments * cleaned up (grouping logic, etc.) * feat: added get_focus_accuracy() * feat: added hyperparam --hindsight_ratio * fix: adding is_valid to ReplayBuffer * this would have been an issue when replay was used with hindsight since the is_valid masks from the on-policy batch was used to update the log-rewards (not the mask corresponding to the samples in the replay buffer) * however I don't think in practice this has been an issue since for the Fragment task all molecules are supposed to be valid (so is_valid=True for all trajs) * fix: dirichlet-sampled focus vectors should be normalised to fit the fixed focus-vector distribution from the validation dataset * fix: changed focus_type='dirichlet' to 'sampled' and now samples from a hyperspehere positive quadrant surface (not a dirichlet distribution) * fix: deepcopy items before pushing them into the replay-buffer * Merge branch fix_multiprocessing_files in julien-replayBuffer * feat: simply returns the cond_info and log_rewards unchanged when using hindsight without focus_type * feat: added elapsed_time to plots * attempt: try to fix wrong focus_accuracy measurement by computing metrics on deep-copies of the online data * chore: tox * feat: added 'negative' hindsight ER * feat: updated default hyperparams based on results from 2023_04_11_4objectives_search1_capacity, 2023_04_11_4objectives_search2_focusCosim, 2023_04_11_4objectives_search3_hindsight * Revert "feat: added 'negative' hindsight ER" This reverts commit 46c777508f49e863c821acda449b80f88782c794. * feat: added param --focus_limit_coef and the notion of a soft exponentially decaying focus coefficient inside a hard-constrained focus-region * feat: re-added focus_type='dirichlet' option * feat: implemented TabularFocusModel() * fix: added device for GPU training * feat: modified the sampling likelihoods to avoid a rich-get-richer effect * minor: changed 'sampled' focus_type name to 'hyperspherical' * minor: changed 'sampled' focus_type name to 'hyperspherical' * feat: simplified TabularFocusModel * fix: check if focus_type is None to accomodate string searching * chore: black this branch * fix: forgotten deprecated focus_type name * fix: still should be sampling from the learned-focus-model at the end of training * feat: added final generation steps (to get molecules from final distribution only and without exploration) * fix: now save smiles again * fix: final_dl should sample from focus_dirs from the learned focus model * chore: black * chore: tox * chore: tox * chore: cleaning up some changes in dependencies * fix: support thermometer-encoding for both preferences and focus-directions * chore: cleaned up main.in * feat: upgrading to torch==1.13.1 * fix: added missing dependency * fix: updating torch-version in tox.ini * fix: relaxing cuda dependencies provided by roadie (they cause build-and-test to fail on Mac and Windows) * fix: up-pined versions for botorch, pyro-ppl, gpytorch since we upgraded torch * chore: cleaning up requirements files for python3.8 * chore: re-generating requirements * attempt: copied dev_3.9.txt -> dev_3.8.txt (same for main) * fix: relaxing cuda dependencies * chore: removing support for python3.8 * fix: attempt to relax versioning for torch-geometric dependencies * fix: removed torch-cuda from tox.ini * chore: minor clean up in requirements comments * docs: added comments to highlight where positional assumptions are made about the ordering of preferences and focus-directions * docs: fixed incorrect comment * chore: sorting hps before saving * chore: (minor) avoid breaking comment above * fix: avoid creating the final_data_loader if we are not doing any final-dist sampling * chore: set the sampling likelihoods of feasible/infeasible focus-directions of learned-focus-models as instance attributes (hardcoded for now) * chore: improved documentation of FocusModel() and removed deprecated argument focus_cosim --------- Co-authored-by: Julien Roy Co-authored-by: Emmanuel Bengio Co-authored-by: julienroyd --- README.md | 4 +- pyproject.toml | 16 +- requirements/dev_3.8.txt | 297 ---------------- requirements/dev_3.9.txt | 6 +- requirements/main_3.8.txt | 186 ---------- requirements/main_3.9.txt | 15 +- src/gflownet/data/replay_buffer.py | 40 +++ src/gflownet/data/sampling_iterator.py | 167 ++++++--- src/gflownet/tasks/qm9/qm9.py | 6 +- src/gflownet/tasks/seh_frag.py | 40 ++- src/gflownet/tasks/seh_frag_moo.py | 363 +++++++++++++++----- src/gflownet/train.py | 82 ++++- src/gflownet/utils/focus_model.py | 117 +++++++ src/gflownet/utils/metrics.py | 82 +++-- src/gflownet/utils/multiobjective_hooks.py | 16 + src/gflownet/utils/multiprocessing_proxy.py | 88 +++-- tox.ini | 9 +- 17 files changed, 812 insertions(+), 722 deletions(-) delete mode 100644 requirements/dev_3.8.txt delete mode 100644 requirements/main_3.8.txt create mode 100644 src/gflownet/data/replay_buffer.py create mode 100644 src/gflownet/utils/focus_model.py diff --git a/README.md b/README.md index 00791982..5f3dc65c 100644 --- a/README.md +++ b/README.md @@ -44,12 +44,12 @@ A good place to get started is with the [sEH fragment-based MOO task](src/gflown This package is installable as a PIP package, but since it depends on some torch-geometric package wheels, the `--find-links` arguments must be specified as well: ```bash -pip install -e . --find-links https://data.pyg.org/whl/torch-1.10.0+cu113.html +pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cu117.html ``` Or for CPU use: ```bash -pip install -e . --find-links https://data.pyg.org/whl/torch-1.10.0+cpu.html +pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html ``` To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository-pip-dependency/) a specific tag, for example here `v0.0.10`, use the following scheme: diff --git a/pyproject.toml b/pyproject.toml index 25209cff..4467b276 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,14 +56,14 @@ keywords = ["gflownet"] requires-python = ">=3.8,<3.10" dynamic = ["version"] dependencies = [ - "torch==1.10.2", + "torch==1.13.1", # These pins are specific on purpose, some of these packages have # unstable APIs since they are fairly new. We could instead pin # them as >= in dev until something breaks? "torch-geometric==2.0.3", - "torch-scatter==2.0.9", - "torch-sparse==0.6.13", - "torch-cluster==1.6.0", + "torch-scatter", + "torch-sparse", + "torch-cluster", "rdkit", "tables", "scipy", @@ -71,10 +71,10 @@ dependencies = [ "tensorboard", "cvxopt", "pyarrow", - "botorch==0.6.6", # pin because of the torch==1.10.2 dependency, botorch>=0.7 requires torch>=1.11 - # pins to help depencency resolution, because of the above pin - "pyro-ppl==1.8.0", - "gpytorch==1.8.1", + "gitpython", + "botorch", + "pyro-ppl", + "gpytorch", ] [project.optional-dependencies] diff --git a/requirements/dev_3.8.txt b/requirements/dev_3.8.txt deleted file mode 100644 index 4b45ebb0..00000000 --- a/requirements/dev_3.8.txt +++ /dev/null @@ -1,297 +0,0 @@ -absl-py==1.4.0 - # via tensorboard -bandit[toml]==1.7.5 - # via gflownet (pyproject.toml) -black==23.3.0 - # via gflownet (pyproject.toml) -blosc2==2.0.0 - # via tables -botorch==0.6.6 - # via gflownet (pyproject.toml) -build==0.10.0 - # via pip-tools -cachetools==5.3.0 - # via google-auth -certifi==2022.12.7 - # via requests -cfgv==3.3.1 - # via pre-commit -charset-normalizer==3.1.0 - # via requests -click==8.1.3 - # via - # black - # pip-compile-multi - # pip-tools -coverage[toml]==7.2.5 - # via pytest-cov -cvxopt==1.3.0 - # via gflownet (pyproject.toml) -cython==0.29.34 - # via tables -distlib==0.3.6 - # via virtualenv -exceptiongroup==1.1.1 - # via pytest -filelock==3.12.0 - # via virtualenv -gitdb==4.0.10 - # via gitpython -gitpython==3.1.31 - # via - # bandit - # gflownet (pyproject.toml) -google-auth==2.17.3 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -googledrivedownloader==0.4 - # via torch-geometric -gpytorch==1.8.1 - # via - # botorch - # gflownet (pyproject.toml) -grpcio==1.54.0 - # via tensorboard -identify==2.5.23 - # via pre-commit -idna==3.4 - # via requests -importlib-metadata==6.6.0 - # via - # markdown - # typeguard -iniconfig==2.0.0 - # via pytest -isodate==0.6.1 - # via rdflib -isort==5.12.0 - # via gflownet (pyproject.toml) -jinja2==3.1.2 - # via torch-geometric -joblib==1.2.0 - # via scikit-learn -markdown==3.4.3 - # via tensorboard -markdown-it-py==2.2.0 - # via rich -markupsafe==2.1.2 - # via - # jinja2 - # werkzeug -mdurl==0.1.2 - # via markdown-it-py -msgpack==1.0.5 - # via blosc2 -multipledispatch==0.6.0 - # via botorch -mypy==1.2.0 - # via gflownet (pyproject.toml) -mypy-extensions==1.0.0 - # via - # black - # mypy -networkx==3.1 - # via - # gflownet (pyproject.toml) - # torch-geometric -nodeenv==1.7.0 - # via pre-commit -numexpr==2.8.4 - # via tables -numpy==1.24.3 - # via - # gpytorch - # numexpr - # opt-einsum - # pandas - # pyarrow - # pyro-ppl - # rdkit - # scikit-learn - # scipy - # tables - # tensorboard - # torch-geometric -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via pyro-ppl -packaging==23.1 - # via - # black - # build - # pytest - # tables -pandas==2.0.1 - # via torch-geometric -pathspec==0.11.1 - # via black -pbr==5.11.1 - # via stevedore -pillow==9.5.0 - # via rdkit -pip-compile-multi==2.6.2 - # via gflownet (pyproject.toml) -pip-tools==6.13.0 - # via pip-compile-multi -platformdirs==3.5.0 - # via - # black - # virtualenv -pluggy==1.0.0 - # via pytest -pre-commit==3.2.2 - # via gflownet (pyproject.toml) -protobuf==4.22.3 - # via tensorboard -py-cpuinfo==9.0.0 - # via tables -pyarrow==11.0.0 - # via gflownet (pyproject.toml) -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pygments==2.15.1 - # via rich -pyparsing==3.0.9 - # via - # rdflib - # torch-geometric -pyproject-hooks==1.0.0 - # via build -pyro-api==0.1.2 - # via pyro-ppl -pyro-ppl==1.8.0 - # via - # botorch - # gflownet (pyproject.toml) -pytest==7.3.1 - # via - # gflownet (pyproject.toml) - # pytest-cov -pytest-cov==4.0.0 - # via gflownet (pyproject.toml) -python-dateutil==2.8.2 - # via pandas -pytz==2023.3 - # via pandas -pyyaml==6.0 - # via - # bandit - # pre-commit - # torch-geometric - # yacs -rdflib==6.3.2 - # via torch-geometric -rdkit==2022.9.5 - # via gflownet (pyproject.toml) -requests==2.29.0 - # via - # requests-oauthlib - # tensorboard - # torch-geometric -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rich==13.3.5 - # via bandit -rsa==4.9 - # via google-auth -ruff==0.0.263 - # via gflownet (pyproject.toml) -scikit-learn==1.2.2 - # via - # gpytorch - # torch-geometric -scipy==1.10.1 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # scikit-learn - # torch-geometric - # torch-sparse -six==1.16.0 - # via - # google-auth - # isodate - # multipledispatch - # python-dateutil -smmap==5.0.0 - # via gitdb -stevedore==5.0.0 - # via bandit -tables==3.8.0 - # via gflownet (pyproject.toml) -tensorboard==2.12.2 - # via gflownet (pyproject.toml) -tensorboard-data-server==0.7.0 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.1.0 - # via scikit-learn -tomli==2.0.1 - # via - # bandit - # black - # build - # coverage - # mypy - # pytest -toposort==1.10 - # via pip-compile-multi -torch==1.10.2 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # pyro-ppl -torch-cluster==1.6.0 - # via gflownet (pyproject.toml) -torch-geometric==2.0.3 - # via gflownet (pyproject.toml) -torch-scatter==2.0.9 - # via gflownet (pyproject.toml) -torch-sparse==0.6.13 - # via gflownet (pyproject.toml) -tqdm==4.65.0 - # via - # pyro-ppl - # torch-geometric -typeguard==3.0.2 - # via gflownet (pyproject.toml) -types-pkg-resources==0.1.3 - # via gflownet (pyproject.toml) -typing-extensions==4.5.0 - # via - # black - # mypy - # rich - # torch - # typeguard -tzdata==2023.3 - # via pandas -urllib3==1.26.15 - # via requests -virtualenv==20.23.0 - # via pre-commit -werkzeug==2.3.3 - # via tensorboard -wheel==0.40.0 - # via - # pip-tools - # tensorboard -yacs==0.1.8 - # via torch-geometric -zipp==3.15.0 - # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools diff --git a/requirements/dev_3.9.txt b/requirements/dev_3.9.txt index 34784306..0af2d196 100644 --- a/requirements/dev_3.9.txt +++ b/requirements/dev_3.9.txt @@ -246,7 +246,7 @@ tomli==2.0.1 # pytest toposort==1.10 # via pip-compile-multi -torch==1.10.2 +torch==1.13.1 # via # botorch # gflownet (pyproject.toml) @@ -256,9 +256,9 @@ torch-cluster==1.6.0 # via gflownet (pyproject.toml) torch-geometric==2.0.3 # via gflownet (pyproject.toml) -torch-scatter==2.0.9 +torch-scatter==2.1.1 # via gflownet (pyproject.toml) -torch-sparse==0.6.13 +torch-sparse==0.6.17 # via gflownet (pyproject.toml) tqdm==4.65.0 # via diff --git a/requirements/main_3.8.txt b/requirements/main_3.8.txt deleted file mode 100644 index e9cff706..00000000 --- a/requirements/main_3.8.txt +++ /dev/null @@ -1,186 +0,0 @@ -absl-py==1.4.0 - # via tensorboard -blosc2==2.0.0 - # via tables -botorch==0.6.6 - # via gflownet (pyproject.toml) -cachetools==5.3.0 - # via google-auth -certifi==2022.12.7 - # via requests -charset-normalizer==3.1.0 - # via requests -cvxopt==1.3.0 - # via gflownet (pyproject.toml) -cython==0.29.34 - # via tables -google-auth==2.17.3 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -googledrivedownloader==0.4 - # via torch-geometric -gpytorch==1.8.1 - # via - # botorch - # gflownet (pyproject.toml) -grpcio==1.54.0 - # via tensorboard -idna==3.4 - # via requests -importlib-metadata==6.6.0 - # via markdown -isodate==0.6.1 - # via rdflib -jinja2==3.1.2 - # via torch-geometric -joblib==1.2.0 - # via scikit-learn -markdown==3.4.3 - # via tensorboard -markupsafe==2.1.2 - # via - # jinja2 - # werkzeug -msgpack==1.0.5 - # via blosc2 -multipledispatch==0.6.0 - # via botorch -networkx==3.1 - # via - # gflownet (pyproject.toml) - # torch-geometric -numexpr==2.8.4 - # via tables -numpy==1.24.3 - # via - # gpytorch - # numexpr - # opt-einsum - # pandas - # pyarrow - # pyro-ppl - # rdkit - # scikit-learn - # scipy - # tables - # tensorboard - # torch-geometric -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via pyro-ppl -packaging==23.1 - # via tables -pandas==2.0.1 - # via torch-geometric -pillow==9.5.0 - # via rdkit -protobuf==4.22.3 - # via tensorboard -py-cpuinfo==9.0.0 - # via tables -pyarrow==11.0.0 - # via gflownet (pyproject.toml) -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pyparsing==3.0.9 - # via - # rdflib - # torch-geometric -pyro-api==0.1.2 - # via pyro-ppl -pyro-ppl==1.8.0 - # via - # botorch - # gflownet (pyproject.toml) -python-dateutil==2.8.2 - # via pandas -pytz==2023.3 - # via pandas -pyyaml==6.0 - # via - # torch-geometric - # yacs -rdflib==6.3.2 - # via torch-geometric -rdkit==2022.9.5 - # via gflownet (pyproject.toml) -requests==2.29.0 - # via - # requests-oauthlib - # tensorboard - # torch-geometric -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rsa==4.9 - # via google-auth -scikit-learn==1.2.2 - # via - # gpytorch - # torch-geometric -scipy==1.10.1 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # scikit-learn - # torch-geometric - # torch-sparse -six==1.16.0 - # via - # google-auth - # isodate - # multipledispatch - # python-dateutil -tables==3.8.0 - # via gflownet (pyproject.toml) -tensorboard==2.12.2 - # via gflownet (pyproject.toml) -tensorboard-data-server==0.7.0 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.1.0 - # via scikit-learn -torch==1.10.2 - # via - # botorch - # gflownet (pyproject.toml) - # gpytorch - # pyro-ppl -torch-cluster==1.6.0 - # via gflownet (pyproject.toml) -torch-geometric==2.0.3 - # via gflownet (pyproject.toml) -torch-scatter==2.0.9 - # via gflownet (pyproject.toml) -torch-sparse==0.6.13 - # via gflownet (pyproject.toml) -tqdm==4.65.0 - # via - # pyro-ppl - # torch-geometric -typing-extensions==4.5.0 - # via torch -tzdata==2023.3 - # via pandas -urllib3==1.26.15 - # via requests -werkzeug==2.3.3 - # via tensorboard -wheel==0.40.0 - # via tensorboard -yacs==0.1.8 - # via torch-geometric -zipp==3.15.0 - # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/requirements/main_3.9.txt b/requirements/main_3.9.txt index cb0cbaa0..dcc19ca1 100644 --- a/requirements/main_3.9.txt +++ b/requirements/main_3.9.txt @@ -14,6 +14,10 @@ cvxopt==1.3.0 # via gflownet (pyproject.toml) cython==0.29.33 # via tables +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via gflownet (pyproject.toml) google-auth==2.16.0 # via # google-auth-oauthlib @@ -139,6 +143,8 @@ six==1.16.0 # isodate # multipledispatch # python-dateutil +smmap==5.0.0 + # via gitdb tables==3.8.0 # via gflownet (pyproject.toml) tensorboard==2.11.2 @@ -149,7 +155,7 @@ tensorboard-plugin-wit==1.8.1 # via tensorboard threadpoolctl==3.1.0 # via scikit-learn -torch==1.10.2 +torch==1.13.1 # via # botorch # gflownet (pyproject.toml) @@ -159,9 +165,9 @@ torch-cluster==1.6.0 # via gflownet (pyproject.toml) torch-geometric==2.0.3 # via gflownet (pyproject.toml) -torch-scatter==2.0.9 +torch-scatter==2.1.1 # via gflownet (pyproject.toml) -torch-sparse==0.6.13 +torch-sparse==0.6.17 # via gflownet (pyproject.toml) tqdm==4.64.1 # via @@ -174,7 +180,8 @@ urllib3==1.26.14 werkzeug==2.2.3 # via tensorboard wheel==0.38.4 - # via tensorboard + # via + # tensorboard yacs==0.1.8 # via torch-geometric zipp==3.12.0 diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py new file mode 100644 index 00000000..2f27de0a --- /dev/null +++ b/src/gflownet/data/replay_buffer.py @@ -0,0 +1,40 @@ +from typing import List + +import numpy as np +import torch + + +class ReplayBuffer(object): + def __init__(self, capacity: int = 100000, warmup: int = 0, rng: np.random.Generator = None): + self.capacity = capacity + self.warmup = warmup + assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" + + self.buffer: List[tuple] = [] + self.position = 0 + self.rng = rng + + def push(self, *args): + if len(self.buffer) == 0: + self._input_size = len(args) + else: + assert self._input_size == len(args), "ReplayBuffer input size must be constant" + if len(self.buffer) < self.capacity: + self.buffer.append(None) + self.buffer[self.position] = args + self.position = (self.position + 1) % self.capacity + + def sample(self, batch_size): + idxs = self.rng.choice(len(self.buffer), batch_size) + out = list(zip(*[self.buffer[idx] for idx in idxs])) + for i in range(len(out)): + # stack if all elements are numpy arrays or torch tensors + # (this is much more efficient to send arrays through multiprocessing queues) + if all([isinstance(x, np.ndarray) for x in out[i]]): + out[i] = np.stack(out[i], axis=0) + elif all([isinstance(x, torch.Tensor) for x in out[i]]): + out[i] = torch.stack(out[i], dim=0) + return tuple(out) + + def __len__(self): + return len(self.buffer) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index ff95691e..2de3f21a 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -1,6 +1,7 @@ import os import sqlite3 from collections.abc import Iterable +from copy import deepcopy from typing import Callable, List import networkx as nx @@ -10,6 +11,8 @@ from rdkit import Chem, RDLogger from torch.utils.data import Dataset, IterableDataset +from gflownet.data.replay_buffer import ReplayBuffer + class SamplingIterator(IterableDataset): """This class allows us to parallelise and train faster. @@ -30,11 +33,14 @@ def __init__( algo, task, device, - ratio=0.5, - stream=True, + ratio: float = 0.5, + stream: bool = True, + replay_buffer: ReplayBuffer = None, log_dir: str = None, - sample_cond_info=True, - random_action_prob=0.0, + sample_cond_info: bool = True, + random_action_prob: float = 0.0, + hindsight_ratio: float = 0.0, + init_train_iter: int = 0, ): """Parameters ---------- @@ -43,6 +49,8 @@ def __init__( model: nn.Module The model we sample from (must be on CUDA already or share_memory() must be called so that parameters are synchronized between each worker) + replay_buffer: ReplayBuffer + The replay buffer for training on past data batch_size: int The number of trajectories, each trajectory will be comprised of many graphs, so this is _not_ the batch size in terms of the number of graphs (that will depend on the task) @@ -63,6 +71,7 @@ def __init__( """ self.data = dataset self.model = model + self.replay_buffer = replay_buffer self.batch_size = batch_size self.offline_batch_size = int(np.ceil(batch_size * ratio)) self.online_batch_size = int(np.floor(batch_size * (1 - ratio))) @@ -75,17 +84,21 @@ def __init__( self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely self.sample_cond_info = sample_cond_info self.random_action_prob = random_action_prob + self.hindsight_ratio = hindsight_ratio + self.train_it = init_train_iter self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag + + # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) + # then "offline" now refers to cond info and online to x, so no duplication and we don't end + # up with 2*batch_size accidentally if not sample_cond_info: - # Slightly weird semantics, but if we're sampling x given some fixed (data) cond info - # then "offline" refers to cond info and online to x, so no duplication and we don't end - # up with 2*batch_size accidentally self.offline_batch_size = self.online_batch_size = batch_size - self.log_dir = log_dir + # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we # don't want to initialize per-worker things just yet, such as where the log the worker writes # to. This must be done in __iter__, which is called by the DataLoader once this instance # has been copied into a new python process. + self.log_dir = log_dir self.log = SQLiteLog() self.log_hooks: List[Callable] = [] @@ -105,9 +118,9 @@ def _idx_iterator(self): if n == 0: yield np.arange(0, 0) return - if worker_info is None: + if worker_info is None: # no multi-processing start, end, wid = 0, n, -1 - else: + else: # split the data into chunks (per-worker) nw = worker_info.num_workers wid = worker_info.id start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) @@ -143,7 +156,11 @@ def __iter__(self): # Sample conditional info such as temperature, trade-off weights, etc. if self.sample_cond_info: - cond_info = self.task.sample_conditional_information(num_offline + self.online_batch_size) + num_online = self.online_batch_size + cond_info = self.task.sample_conditional_information( + num_offline + self.online_batch_size, self.train_it + ) + # Sample some dataset data mols, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) flat_rewards = ( @@ -151,15 +168,17 @@ def __iter__(self): ) graphs = [self.ctx.mol_to_graph(m) for m in mols] trajs = self.algo.create_training_data_from_graphs(graphs) - num_online = self.online_batch_size + else: # If we're not sampling the conditionals, then the idcs refer to listed preferences num_online = num_offline num_offline = 0 - cond_info = self.task.encode_conditional_information(torch.stack([self.data[i] for i in idcs])) + cond_info = self.task.encode_conditional_information( + steer_info=torch.stack([self.data[i] for i in idcs]) + ) trajs, flat_rewards = [], [] - is_valid = torch.ones(num_offline + num_online).bool() # Sample some on-policy data + is_valid = torch.ones(num_offline + num_online).bool() if num_online > 0: with torch.no_grad(): trajs += self.algo.create_training_data_from_own_samples( @@ -181,17 +200,15 @@ def __iter__(self): # fetch the valid trajectories endpoints mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] # ask the task to compute their reward - preds, m_is_valid = self.task.compute_flat_rewards(mols) - assert preds.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) + assert ( + online_flat_rew.ndim == 2 + ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" # The task may decide some of the mols are invalid, we have to again filter those valid_idcs = valid_idcs[m_is_valid] valid_mols = [m for m, v in zip(mols, m_is_valid) if v] - pred_reward = torch.zeros((num_online, preds.shape[1])) - pred_reward[valid_idcs - num_offline] = preds - # TODO: reintegrate bootstrapped reward predictions - # if preds.shape[0] > 0: - # for i in range(self.number_of_objectives): - # pred_reward[valid_idcs - num_offline, i] = preds[range(preds.shape[0]), i] + pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) + pred_reward[valid_idcs - num_offline] = online_flat_rew is_valid[num_offline:] = False is_valid[valid_idcs] = True flat_rewards += list(pred_reward) @@ -201,49 +218,95 @@ def __iter__(self): if self.log_molecule_smis: for i, m in zip(valid_idcs, valid_mols): trajs[i]["smi"] = Chem.MolToSmiles(m) - flat_rewards = torch.stack(flat_rewards) + # Compute scalar rewards from conditional information & flat rewards + flat_rewards = torch.stack(flat_rewards) log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) log_rewards[torch.logical_not(is_valid)] = self.algo.illegal_action_logreward - # Construct batch - batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) - batch.num_offline = num_offline - batch.num_online = num_online - batch.flat_rewards = flat_rewards - batch.mols = mols - batch.preferences = cond_info.get("preferences", None) - # TODO: we could very well just pass the cond_info dict to construct_batch above, - # and the algo can decide what it wants to put in the batch object + # Computes some metrics if not self.sample_cond_info: # If we're using a dataset of preferences, the user may want to know the id of the preference for i, j in zip(trajs, idcs): i["data_idx"] = j - - # Converts back into natural rewards for logging purposes - # (allows to take averages and plot in objective space) - # TODO: implement that per-task (in case they don't apply the same beta and log transformations) + # note: we convert back into natural rewards for logging purposes + # (allows to take averages and plot in objective space) + # TODO: implement that per-task (in case they don't apply the same beta and log transformations) rewards = torch.exp(log_rewards / cond_info["beta"]) - if num_online > 0 and self.log_dir is not None: self.log_generated( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) if num_online > 0: extra_info = {} for hook in self.log_hooks: extra_info.update( hook( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) ) - batch.extra_info = extra_info + + if self.replay_buffer is not None: + # If we have a replay buffer, we push the online trajectories in it + # and resample immediately such that the "online" data in the batch + # comes from a more stable distribution (try to avoid forgetting) + + # cond_info is a dict, so we need to convert it to a list of dicts + cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] + + # push the online trajectories in the replay buffer and sample a new 'online' batch + for i in range(num_offline, len(trajs)): + self.replay_buffer.push( + deepcopy(trajs[i]), + deepcopy(log_rewards[i]), + deepcopy(flat_rewards[i]), + deepcopy(cond_info[i]), + deepcopy(is_valid[i]), + ) + replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( + num_online + ) + + # append the online trajectories to the offline ones + trajs[num_offline:] = replay_trajs + log_rewards[num_offline:] = replay_logr + flat_rewards[num_offline:] = replay_fr + cond_info[num_offline:] = replay_condinfo + is_valid[num_offline:] = replay_valid + + # convert cond_info back to a dict + cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} + + if self.hindsight_ratio > 0.0: + # Relabels some of the online trajectories with hindsight + assert hasattr( + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + # samples indexes of trajectories without repeats + hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline + cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( + cond_info, log_rewards, flat_rewards, hindsight_idxs + ) + log_rewards[torch.logical_not(is_valid)] = self.algo.illegal_action_logreward + + # Construct batch + batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) + batch.num_offline = num_offline + batch.num_online = num_online + batch.flat_rewards = flat_rewards + batch.preferences = cond_info.get("preferences", None) + batch.focus_dir = cond_info.get("focus_dir", None) + batch.extra_info = extra_info + # TODO: we could very well just pass the cond_info dict to construct_batch above, + # and the algo can decide what it wants to put in the batch object + + self.train_it += worker_info.num_workers if worker_info is not None else 1 yield batch def log_generated(self, trajs, rewards, flat_rewards, cond_info): @@ -258,18 +321,26 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info): flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() rewards = rewards.data.numpy().tolist() preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences"]] + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] data = [ - [mols[i], rewards[i]] + flat_rewards[i] + preferences[i] + [cond_info[k][i].item() for k in logged_keys] + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] for i in range(len(trajs)) ] + data_labels = ( ["smi", "r"] + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + [f"ci_{k}" for k in logged_keys] ) + self.log.insert_many(data, data_labels) diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index d5ed182f..39dd8309 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -75,7 +75,7 @@ def load_task_models(self): state_dict = torch.load("/data/chem/qm9/mxmnet_gap_model.pt") gap_model.load_state_dict(state_dict) gap_model.cuda() - gap_model, self.device = self._wrap_model(gap_model) + gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) return {"mxmnet_gap": gap_model} def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: @@ -102,7 +102,7 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": beta, "encoding": beta_enc} + return {"beta": torch.tensor(beta), "encoding": beta_enc} def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: if isinstance(flat_reward, list): @@ -191,7 +191,7 @@ def setup(self): temperature_distribution=hps["temperature_sample_dist"], temperature_parameters=hps["temperature_dist_params"], num_thermometer_dim=hps["num_thermometer_dim"], - wrap_model=self._wrap_model_mp, + wrap_model=self._wrap_for_mp, ) self.mb_size = hps["global_batch_size"] self.clip_grad_param = hps["clip_grad_param"] diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index e51e21d9..14009ef1 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,10 +1,13 @@ import ast import copy +import json import os +import pathlib import shutil import socket from typing import Any, Callable, Dict, List, Tuple, Union +import git import numpy as np import scipy.stats as stats import torch @@ -16,6 +19,7 @@ from torch.utils.data import Dataset from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.models import bengio2021flow @@ -59,10 +63,10 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model) + model, self.device = self._wrap_model(model, send_to_device=True) return {"seh": model} - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: beta = None if self.temperature_sample_dist == "constant": assert type(self.temperature_dist_params) is float @@ -86,7 +90,7 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": beta, "encoding": beta_enc} + return {"beta": torch.tensor(beta), "encoding": beta_enc} def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: if isinstance(flat_reward, list): @@ -116,7 +120,7 @@ def default_hps(self) -> Dict[str, Any]: "hostname": socket.gethostname(), "bootstrap_own_reward": False, "learning_rate": 1e-4, - "Z_learning_rate": 1e-4, + "Z_learning_rate": 1e-3, "global_batch_size": 64, "num_emb": 128, "num_layers": 4, @@ -131,14 +135,18 @@ def default_hps(self) -> Dict[str, Any]: "momentum": 0.9, "adam_eps": 1e-8, "lr_decay": 20000, - "Z_lr_decay": 20000, + "Z_lr_decay": 50000, "clip_grad_type": "norm", "clip_grad_param": 10, - "random_action_prob": 0.0, + "random_action_prob": 0.01, "valid_random_action_prob": 0.0, "sampling_tau": 0.0, "max_nodes": 9, "num_thermometer_dim": 32, + "use_replay_buffer": False, + "replay_buffer_size": 10000, + "replay_buffer_warmup": 10000, + "mp_pickle_messages": False, } def setup_algo(self): @@ -149,8 +157,9 @@ def setup_task(self): dataset=self.training_data, temperature_distribution=self.hps["temperature_sample_dist"], temperature_parameters=self.hps["temperature_dist_params"], + rng=self.rng, num_thermometer_dim=self.hps["num_thermometer_dim"], - wrap_model=self._wrap_model_mp, + wrap_model=self._wrap_for_mp, ) def setup_model(self): @@ -170,6 +179,11 @@ def setup(self): self.test_data = [] self.offline_ratio = 0 self.valid_offline_ratio = 0 + self.replay_buffer = ( + ReplayBuffer(self.hps["replay_buffer_size"], self.hps["replay_buffer_warmup"], self.rng) + if self.hps["use_replay_buffer"] + else None + ) self.setup_env_context() self.setup_algo() self.setup_task() @@ -205,6 +219,16 @@ def setup(self): "none": (lambda x: None), }[hps["clip_grad_type"]] + # saving hyperparameters + git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] + self.hps["gflownet_git_hash"] = git_hash + + os.makedirs(self.hps["log_dir"], exist_ok=True) + fmt_hps = "\n".join([f"{f'{k}':40}:\t{f'({type(v).__name__})':10}\t{v}" for k, v in sorted(self.hps.items())]) + print(f"\n\nHyperparameters:\n{'-'*50}\n{fmt_hps}\n{'-'*50}\n\n") + with open(pathlib.Path(self.hps["log_dir"]) / "hps.json", "w") as f: + json.dump(self.hps, f) + def step(self, loss: Tensor): loss.backward() for i in self.model.parameters(): @@ -241,7 +265,7 @@ def main(): os.makedirs(hps["log_dir"]) trial = SEHFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) - trial.verbose = True + trial.print_every = 1 trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index be09a5e6..1373d3d8 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -2,9 +2,9 @@ import os import pathlib import shutil +from copy import deepcopy from typing import Any, Callable, Dict, List, Tuple, Union -import git import numpy as np import torch import torch.nn as nn @@ -26,6 +26,7 @@ from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.train import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore +from gflownet.utils.focus_model import FocusModel, TabularFocusModel from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook from gflownet.utils.transforms import thermometer @@ -47,7 +48,16 @@ def __init__( temperature_sample_dist: str, temperature_parameters: Tuple[float, float], num_thermometer_dim: int, - use_pref_thermometer: bool, + use_steer_thermometer: bool = False, + preference_type: str = None, + focus_type: Union[list, str] = None, + focus_cosim: float = 0.0, + focus_limit_coef: float = 1.0, + fixed_focus_dirs: torch.Tensor = None, + illegal_action_logreward: float = None, + focus_model: FocusModel = None, + focus_model_training_limits: Tuple[int, int] = None, + max_train_it: int = None, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): @@ -59,9 +69,18 @@ def __init__( self.temperature_sample_dist = temperature_sample_dist self.temperature_dist_params = temperature_parameters self.num_thermometer_dim = num_thermometer_dim - self.use_pref_thermometer = use_pref_thermometer + self.use_steer_thermometer = use_steer_thermometer + self.preference_type = preference_type self.seeded_preference = None self.experimental_dirichlet = False + self.focus_type = focus_type + self.focus_cosim = focus_cosim + self.focus_limit_coef = focus_limit_coef + self.fixed_focus_dirs = fixed_focus_dirs + self.illegal_action_logreward = illegal_action_logreward + self.focus_model = focus_model + self.focus_model_training_limits = focus_model_training_limits + self.max_train_it = max_train_it assert set(objectives) <= {"seh", "qed", "sa", "mw"} and len(objectives) == len(set(objectives)) def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: @@ -72,33 +91,73 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model) + model, self.device = self._wrap_model(model, send_to_device=True) return {"seh": model} - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: - cond_info = super().sample_conditional_information(n) + def get_steer_encodings(self, preferences, focus_dirs): + n = len(preferences) + if self.use_steer_thermometer: + pref_enc = thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1) + focus_enc = thermometer(focus_dirs, self.num_thermometer_dim, 0, 1).reshape(n, -1) + else: + pref_enc = preferences + focus_enc = focus_dirs + return pref_enc, focus_enc + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + cond_info = super().sample_conditional_information(n, train_it) - if self.seeded_preference is not None: - preferences = torch.tensor([self.seeded_preference] * n).float() - elif self.experimental_dirichlet: - a = np.random.dirichlet([1] * len(self.objectives), n) - b = np.random.exponential(1, n)[:, None] - preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() + if self.preference_type is None: + preferences = torch.ones((n, len(self.objectives))) else: + if self.seeded_preference is not None: + preferences = torch.tensor([self.seeded_preference] * n).float() + elif self.experimental_dirichlet: + a = np.random.dirichlet([1] * len(self.objectives), n) + b = np.random.exponential(1, n)[:, None] + preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() + else: + m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) + preferences = m.sample([n]) + + if self.fixed_focus_dirs is not None: + focus_dir = torch.tensor( + np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) + ) + elif self.focus_type == "dirichlet": m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) - preferences = m.sample([n]) + focus_dir = m.sample([n]) + elif self.focus_type == "hyperspherical": + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, len(self.objectives), normalisation="l2") + ).float() + elif self.focus_type is not None and "learned" in self.focus_type: + if self.focus_model is not None and train_it >= self.focus_model_training_limits[0] * self.max_train_it: + focus_dir = self.focus_model.sample_focus_directions(n) + else: + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, len(self.objectives), normalisation="l2") + ).float() + else: + raise NotImplementedError(f"Unsupported focus_type={type(self.focus_type)}") - preferences_enc = ( - thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1) - if self.use_pref_thermometer - else preferences - ) - cond_info["encoding"] = torch.cat([cond_info["encoding"], preferences_enc], 1) + preferences_enc, focus_enc = self.get_steer_encodings(preferences, focus_dir) + cond_info["encoding"] = torch.cat([cond_info["encoding"], preferences_enc, focus_enc], 1) cond_info["preferences"] = preferences + cond_info["focus_dir"] = focus_dir return cond_info - def encode_conditional_information(self, preferences: Tensor) -> Dict[str, Tensor]: - n = len(preferences) + def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor]: + """ + Encode conditional information at validation-time + We use the maximum temperature beta for inference + Args: + steer_info: Tensor of shape (Batch, 2 * n_objectives) containing the preferences and focus_dirs + in that order + Returns: + Dict[str, Tensor]: Dictionary containing the encoded conditional information + """ + n = len(steer_info) if self.temperature_sample_dist == "constant": beta = torch.ones(n) * self.temperature_dist_params beta_enc = torch.zeros((n, self.num_thermometer_dim)) @@ -107,11 +166,41 @@ def encode_conditional_information(self, preferences: Tensor) -> Dict[str, Tenso beta_enc = torch.ones((n, self.num_thermometer_dim)) assert len(beta.shape) == 1, f"beta should be of shape (Batch,), got: {beta.shape}" - if self.use_pref_thermometer: - encoding = torch.cat([beta_enc, thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1)], 1) - else: - encoding = torch.cat([beta_enc, preferences], 1) - return {"beta": beta, "encoding": encoding.float(), "preferences": preferences.float()} + + # TODO: positional assumption here, should have something cleaner + preferences = steer_info[:, : len(self.objectives)].float() + focus_dir = steer_info[:, len(self.objectives) :].float() + + preferences_enc, focus_enc = self.get_steer_encodings(preferences, focus_dir) + encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() + + return { + "beta": beta, + "encoding": encoding, + "preferences": preferences, + "focus_dir": focus_dir, + } + + def relabel_condinfo_and_logrewards( + self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor + ): + if self.focus_type is None: + return cond_info, log_rewards + # only keep hindsight_idxs that actually correspond to a violated constraint + _, in_focus_mask = metrics.compute_focus_coef(flat_rewards, cond_info["focus_dir"], self.focus_cosim) + out_focus_mask = torch.logical_not(in_focus_mask) + hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] + + # relabels the focus_dirs and log_rewards + cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) + + preferences_enc, focus_enc = self.get_steer_encodings(cond_info["preferences"], cond_info["focus_dir"]) + cond_info["encoding"] = torch.cat( + [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 + ) + + log_rewards = self.cond_info_to_logreward(cond_info, flat_rewards) + return cond_info, log_rewards def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: if isinstance(flat_reward, list): @@ -123,6 +212,14 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat assert len(scalar_logreward.shape) == len( cond_info["beta"].shape ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" + + if self.focus_type is not None: + focus_coef, in_focus_mask = metrics.compute_focus_coef( + flat_reward, cond_info["focus_dir"], self.focus_cosim, self.focus_limit_coef + ) + scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) + scalar_logreward[~in_focus_mask] = self.illegal_action_logreward + return RewardScalar(scalar_logreward * cond_info["beta"]) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: @@ -132,13 +229,13 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid else: - flat_rewards: List[Tensor] = [] + flat_r: List[Tensor] = [] if "seh" in self.objectives: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) batch.to(self.device) seh_preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() / 8 seh_preds[seh_preds.isnan()] = 0 - flat_rewards.append(seh_preds) + flat_r.append(seh_preds) def safe(f, x, default): try: @@ -148,19 +245,19 @@ def safe(f, x, default): if "qed" in self.objectives: qeds = torch.tensor([safe(QED.qed, i, 0) for i, v in zip(mols, is_valid) if v.item()]) - flat_rewards.append(qeds) + flat_r.append(qeds) if "sa" in self.objectives: sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i, v in zip(mols, is_valid) if v.item()]) sas = (10 - sas) / 9 # Turn into a [0-1] reward - flat_rewards.append(sas) + flat_r.append(sas) if "mw" in self.objectives: molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i, v in zip(mols, is_valid) if v.item()]) molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 - flat_rewards.append(molwts) + flat_r.append(molwts) - flat_rewards = torch.stack(flat_rewards, dim=1) + flat_rewards = torch.stack(flat_r, dim=1) return FlatRewards(flat_rewards), is_valid @@ -172,10 +269,16 @@ def default_hps(self) -> Dict[str, Any]: "objectives": ["seh", "qed", "sa", "mw"], "sampling_tau": 0.95, "valid_sample_cond_info": False, - "n_valid_prefs": 15, - "n_valid_repeats_per_pref": 128, + "n_valid": 15, + "n_valid_repeats": 128, "preference_type": "dirichlet", - "use_pref_thermometer": False, + "focus_type": None, + "focus_cosim": None, + "focus_limit_coef": 1.0, + "hindsight_ratio": 0.0, + "focus_model_training_limits": None, + "focus_model_state_space_res": None, + "use_steer_thermometer": False, } def setup_algo(self): @@ -191,6 +294,18 @@ def setup_algo(self): elif hps["algo"] == "MOQL": self.algo = EnvelopeQLearning(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) + if hps["focus_type"] is not None and "learned" in hps["focus_type"]: + if hps["focus_type"] == "learned-tabular": + self.focus_model = TabularFocusModel( + device=self.device, + n_objectives=len(hps["objectives"]), + state_space_res=hps["focus_model_state_space_res"], + ) + else: + raise NotImplementedError("Unknown focus model type {self.focus_type}") + else: + self.focus_model = None + def setup_task(self): self.task = SEHMOOTask( objectives=self.hps["objectives"], @@ -198,8 +313,17 @@ def setup_task(self): temperature_sample_dist=self.hps["temperature_sample_dist"], temperature_parameters=self.hps["temperature_dist_params"], num_thermometer_dim=self.hps["num_thermometer_dim"], - wrap_model=self._wrap_model_mp, - use_pref_thermometer=self.hps["use_pref_thermometer"], + use_steer_thermometer=self.hps["use_steer_thermometer"], + preference_type=self.hps["preference_type"], + focus_type=self.hps["focus_type"], + focus_cosim=self.hps["focus_cosim"], + focus_limit_coef=self.hps["focus_limit_coef"], + illegal_action_logreward=self.hps["illegal_action_logreward"], + focus_model=self.focus_model, + focus_model_training_limits=self.hps["focus_model_training_limits"], + max_train_it=self.hps["num_training_steps"], + rng=self.rng, + wrap_model=self._wrap_for_mp, ) def setup_model(self): @@ -223,57 +347,108 @@ def setup_model(self): self.model = model def setup_env_context(self): - if self.hps.get("use_pref_thermometer", False): - ncd = self.hps["num_thermometer_dim"] * (1 + len(self.hps["objectives"])) + if self.hps.get("use_steer_thermometer", False): + ncd = self.hps["num_thermometer_dim"] * (1 + 2 * len(self.hps["objectives"])) else: - ncd = self.hps["num_thermometer_dim"] + len(self.hps["objectives"]) - self.ctx = FragMolBuildingEnvContext(max_frags=9, num_cond_dim=ncd) + ncd = self.hps["num_thermometer_dim"] + 2 * len( + self.hps["objectives"] + ) # 1 for prefs and 1 for focus region + self.ctx = FragMolBuildingEnvContext(max_frags=self.hps["max_nodes"], num_cond_dim=ncd) def setup(self): super().setup() self.sampling_hooks.append( - MultiObjectiveStatsHook(256, self.hps["log_dir"], compute_igd=True, compute_pc_entropy=True) + MultiObjectiveStatsHook( + 256, + self.hps["log_dir"], + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.hps["focus_type"] is not None else False, + focus_cosim=self.hps["focus_cosim"], + ) ) + # instantiate preference and focus conditioning vectors for validation n_obj = len(self.hps["objectives"]) + n_valid = self.hps["n_valid"] + + # making sure hyperparameters for preferences and focus regions are consistent + if not ( + self.hps["focus_type"] is None + or self.hps["focus_type"] == "centered" + or (type(self.hps["focus_type"]) is list and len(self.hps["focus_type"]) == 1) + ): + assert self.hps["preference_type"] is None, ( + f"Cannot use preferences with multiple focus regions, here focus_type={self.hps['focus_type']} " + f"and preference_type={self.hps['preference_type']}" + ) + + if type(self.hps["focus_type"]) is list and len(self.hps["focus_type"]) > 1: + n_valid = len(self.hps["focus_type"]) - # create fixed preference vectors for validation + # preference vectors if self.hps["preference_type"] is None: - valid_preferences = np.ones((self.hps["n_valid_prefs"], n_obj)) + valid_preferences = np.ones((n_valid, n_obj)) elif self.hps["preference_type"] == "dirichlet": - valid_preferences = metrics.partition_hypersphere(d=n_obj, k=self.hps["n_valid_prefs"], normalisation="l1") + valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") elif self.hps["preference_type"] == "seeded_single": - seeded_prefs = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet( - [1] * n_obj, self.hps["n_valid_prefs"] - ) + seeded_prefs = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet([1] * n_obj, n_valid) valid_preferences = seeded_prefs[0].reshape((1, n_obj)) self.task.seeded_preference = valid_preferences[0] elif self.hps["preference_type"] == "seeded_many": - valid_preferences = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet( - [1] * n_obj, self.hps["n_valid_prefs"] + valid_preferences = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet([1] * n_obj, n_valid) + else: + raise NotImplementedError(f"Unknown preference type {self.hps['preference_type']}") + + # focus regions + if self.hps["focus_type"] is None: + valid_focus_dirs = np.zeros((n_valid, n_obj)) + self.task.fixed_focus_dirs = valid_focus_dirs + elif self.hps["focus_type"] == "centered": + valid_focus_dirs = np.ones((n_valid, n_obj)) + self.task.fixed_focus_dirs = valid_focus_dirs + elif self.hps["focus_type"] == "partitioned": + valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") + self.task.fixed_focus_dirs = valid_focus_dirs + elif self.hps["focus_type"] in ["dirichlet", "learned-gfn"]: + valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") + self.task.fixed_focus_dirs = None + elif self.hps["focus_type"] in ["hyperspherical", "learned-tabular"]: + valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") + self.task.fixed_focus_dirs = None + elif type(self.hps["focus_type"]) is list: + if len(self.hps["focus_type"]) == 1: + valid_focus_dirs = np.array([self.hps["focus_type"][0]] * n_valid) + self.task.fixed_focus_dirs = valid_focus_dirs + else: + valid_focus_dirs = np.array(self.hps["focus_type"]) + self.task.fixed_focus_dirs = valid_focus_dirs + else: + raise NotImplementedError( + f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " + f"focus_type, but here: {self.hps['focus_type']}" ) - self._top_k_hook = TopKHook(10, self.hps["n_valid_repeats_per_pref"], len(valid_preferences)) - self.test_data = RepeatedPreferenceDataset(valid_preferences, self.hps["n_valid_repeats_per_pref"]) + self.hps["fixed_focus_dirs"] = ( + np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None + ) + with open(pathlib.Path(self.hps["log_dir"]) / "hps.json", "w") as f: + json.dump(self.hps, f) + assert valid_focus_dirs.shape == ( + n_valid, + n_obj, + ), f"Invalid shape for valid_preferences, {valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + + # combine preferences and focus directions (fixed focus cosim) since they could be used together (not either/or) + # TODO: this relies on positional assumptions, should have something cleaner + valid_cond_vector = np.concatenate([valid_preferences, valid_focus_dirs], axis=1) + + self._top_k_hook = TopKHook(10, self.hps["n_valid_repeats"], len(valid_cond_vector)) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.hps["n_valid_repeats"]) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task - git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] - self.hps["gflownet_git_hash"] = git_hash - - os.makedirs(self.hps["log_dir"], exist_ok=True) - torch.save( - { - "hps": self.hps, - }, - open(pathlib.Path(self.hps["log_dir"]) / "hps.pt", "wb"), - ) - fmt_hps = "\n".join([f"{k}:\t({type(v).__name__})\t{v}".expandtabs(40) for k, v in self.hps.items()]) - print(f"\n\nHyperparameters:\n{'-'*50}\n{fmt_hps}\n{'-'*50}\n\n") - with open(pathlib.Path(self.hps["log_dir"]) / "hps.json", "w") as fd: - json.dump(self.hps, fd, sort_keys=True, indent=4) - def build_callbacks(self): # We use this class-based setup to be compatible with the DeterminedAI API, but no direct # dependency is required. @@ -288,18 +463,34 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} - -class RepeatedPreferenceDataset: - def __init__(self, preferences, repeat): - self.prefs = preferences + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + focus_model_training_limits = self.hps["focus_model_training_limits"] + max_train_it = self.hps["num_training_steps"] + if ( + self.focus_model is not None + and train_it >= focus_model_training_limits[0] * max_train_it + and train_it <= focus_model_training_limits[1] * max_train_it + ): + self.focus_model.update_belief(deepcopy(batch.focus_dir), deepcopy(batch.flat_rewards)) + return super().train_batch(batch, epoch_idx, batch_idx, train_it) + + def _save_state(self, it): + if self.focus_model is not None: + self.focus_model.save(pathlib.Path(self.hps["log_dir"])) + return super()._save_state(it) + + +class RepeatedCondInfoDataset: + def __init__(self, cond_info_vectors, repeat): + self.cond_info_vectors = cond_info_vectors self.repeat = repeat def __len__(self): - return len(self.prefs) * self.repeat + return len(self.cond_info_vectors) * self.repeat def __getitem__(self, idx): assert 0 <= idx < len(self) - return torch.tensor(self.prefs[int(idx // self.repeat)]) + return torch.tensor(self.cond_info_vectors[int(idx // self.repeat)]) def main(): @@ -310,8 +501,10 @@ def main(): "seed": 0, "global_batch_size": 64, "num_training_steps": 20_000, - "validate_every": 1, - "num_layers": 4, + "num_final_gen_steps": 500, + "validate_every": 10, + "num_layers": 2, + "num_emb": 256, "algo": "TB", "objectives": ["seh", "qed"], "learning_rate": 1e-4, @@ -320,13 +513,23 @@ def main(): "Z_lr_decay": 50000, "sampling_tau": 0.95, "random_action_prob": 0.1, - "num_data_loader_workers": 8, + "num_data_loader_workers": 0, "temperature_sample_dist": "constant", "temperature_dist_params": 60.0, "num_thermometer_dim": 32, - "preference_type": "dirichlet", - "n_valid_prefs": 15, - "n_valid_repeats_per_pref": 128, + "use_steer_thermometer": False, + "preference_type": None, + "focus_type": "learned-tabular", + "focus_cosim": 0.98, + "focus_limit_coef": 1e-1, + "n_valid": 15, + "n_valid_repeats": 128, + "use_replay_buffer": True, + "replay_buffer_warmup": 0, + "hindsight_ratio": 0.3, + "mp_pickle_messages": True, + "focus_model_training_limits": [0.25, 0.75], + "focus_model_state_space_res": 10, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -336,7 +539,7 @@ def main(): os.makedirs(hps["log_dir"]) trial = SEHMOOFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) - trial.verbose = True + trial.print_every = 1 trial.run() diff --git a/src/gflownet/train.py b/src/gflownet/train.py index e73cae92..324e96bf 100644 --- a/src/gflownet/train.py +++ b/src/gflownet/train.py @@ -10,10 +10,11 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet.data.replay_buffer import ReplayBuffer from gflownet.data.sampling_iterator import SamplingIterator from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.utils.misc import create_logger -from gflownet.utils.multiprocessing_proxy import wrap_model_mp +from gflownet.utils.multiprocessing_proxy import mp_object_wrapper # This type represents an unprocessed list of reward signals/conditioning information FlatRewards = NewType("FlatRewards", Tensor) # type: ignore @@ -99,6 +100,7 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): # `sampling_model` is used by the data workers to sample new objects from the model. Can be # the same as `model`. self.sampling_model: nn.Module + self.replay_buffer: ReplayBuffer self.mb_size: int self.env: GraphBuildingEnv self.ctx: GraphBuildingEnvContext @@ -115,8 +117,8 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): self.offline_ratio = self.hps.get("offline_ratio", 0.5) # idem, but from `self.test_data` during validation. self.valid_offline_ratio = 1 - # If True, print messages during training - self.verbose = False + # Print the loss every `self.print_every` iterations + self.print_every = 1000 # These hooks allow us to compute extra quantities when sampling data self.sampling_hooks: List[Callable] = [] self.valid_sampling_hooks: List[Callable] = [] @@ -136,24 +138,28 @@ def setup(self): def step(self, loss: Tensor): raise NotImplementedError() - def _wrap_model_mp(self, model): - """Wraps a nn.Module instance so that it can be shared to `DataLoader` workers.""" - model.to(self.device) - if self.num_workers > 0: - placeholder = wrap_model_mp( - model, + def _wrap_for_mp(self, obj, send_to_device=False): + """Wraps an object in a placeholder whose reference can be sent to a + data worker process (only if the number of workers is non-zero).""" + if send_to_device: + obj.to(self.device) + if self.num_workers > 0 and obj is not None: + placeholder = mp_object_wrapper( + obj, self.num_workers, cast_types=(gd.Batch, GraphActionCategorical), pickle_messages=self.pickle_messages, ) return placeholder, torch.device("cpu") - return model, self.device + else: + return obj, self.device def build_callbacks(self): return {} def build_training_data_loader(self) -> DataLoader: - model, dev = self._wrap_model_mp(self.sampling_model) + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) iterator = SamplingIterator( self.training_data, model, @@ -162,9 +168,11 @@ def build_training_data_loader(self) -> DataLoader: self.algo, self.task, dev, + replay_buffer=replay_buffer, ratio=self.offline_ratio, log_dir=os.path.join(self.hps["log_dir"], "train"), random_action_prob=self.hps.get("random_action_prob", 0.0), + hindsight_ratio=self.hps.get("hindsight_ratio", 0.0), ) for hook in self.sampling_hooks: iterator.add_log_hook(hook) @@ -179,7 +187,7 @@ def build_training_data_loader(self) -> DataLoader: ) def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_model_mp(self.model) + model, dev = self._wrap_for_mp(self.model, send_to_device=True) iterator = SamplingIterator( self.test_data, model, @@ -204,7 +212,34 @@ def build_validation_data_loader(self) -> DataLoader: prefetch_factor=1 if self.num_workers else 2, ) - def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int) -> Dict[str, Any]: + def build_final_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + iterator = SamplingIterator( + self.training_data, + model, + self.mb_size, + self.ctx, + self.algo, + self.task, + dev, + replay_buffer=None, + ratio=0.0, + log_dir=os.path.join(self.hps["log_dir"], "final"), + random_action_prob=0.0, + hindsight_ratio=0.0, + init_train_iter=self.hps["num_training_steps"], + ) + for hook in self.sampling_hooks: + iterator.add_log_hook(hook) + return torch.utils.data.DataLoader( + iterator, + batch_size=None, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + prefetch_factor=1 if self.num_workers else 2, + ) + + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: try: loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): @@ -243,15 +278,22 @@ def run(self, logger=None): ckpt_freq = self.hps.get("checkpoint_every", valid_freq) train_dl = self.build_training_data_loader() valid_dl = self.build_validation_data_loader() + if self.hps.get("num_final_gen_steps", 0) > 0: + final_dl = self.build_final_data_loader() callbacks = self.build_callbacks() start = self.hps.get("start_at_step", 0) + 1 logger.info("Starting training") for it, batch in zip(range(start, 1 + self.hps["num_training_steps"]), cycle(train_dl)): epoch_idx = it // epoch_length batch_idx = it % epoch_length - info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx) + if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: + logger.info( + f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" + ) + continue + info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) self.log(info, it, "train") - if self.verbose: + if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) if valid_freq > 0 and it % valid_freq == 0: @@ -268,6 +310,16 @@ def run(self, logger=None): self._save_state(it) self._save_state(self.hps["num_training_steps"]) + num_final_gen_steps = self.hps.get("num_final_gen_steps", 0) + if num_final_gen_steps > 0: + logger.info(f"Generating final {num_final_gen_steps} batches ...") + for it, batch in zip( + range(self.hps["num_training_steps"], self.hps["num_training_steps"] + num_final_gen_steps + 1), + cycle(final_dl), + ): + pass + logger.info("Final generation steps completed.") + def _save_state(self, it): torch.save( { diff --git a/src/gflownet/utils/focus_model.py b/src/gflownet/utils/focus_model.py new file mode 100644 index 00000000..14bf6c71 --- /dev/null +++ b/src/gflownet/utils/focus_model.py @@ -0,0 +1,117 @@ +from pathlib import Path + +import torch +import torch.nn as nn + +from gflownet.utils.metrics import get_limits_of_hypercube + + +class FocusModel: + """ + Abstract class for a belief model over focus directions for goal-conditioned GFNs. + Goal-conditioned GFNs allow for more control over the objective-space region from which + we wish to sample. However due to the growing number of emtpy regions in the objective space, + if we naively sample focus-directions from the entire objective space, we will condition + our GFN with a lot of infeasible directions which significantly harms its sample efficiency + compared to a more simple preference-conditioned model. + To alleviate this problem, we introduce a focus belief model which is used to sample + focus directions from a subset of the objective space. The belief model is + trained to predict the probability of a focus direction being feasible. The likelihood + to sample a focus direction is then proportional to its population. Directions that have never + been sampled should be given the maximum likelihood. + """ + + def __init__(self, device: torch.device, n_objectives: int, state_space_res: int) -> None: + """ + args: + device: torch device + n_objectives: number of objectives + state_space_res: resolution of the state space discretisation. The number of focus directions to consider + grows within O(state_space_res ** n_objectives) and depends on the amount of filtering we apply + (e.g. valid focus-directions should sum to 1 [dirichlet], should contain a 1 [limits], etc.) + """ + self.device = device + self.n_objectives = n_objectives + self.state_space_res = state_space_res + + self.feasible_flow = 1.0 + self.infeasible_flow = 0.1 + + def update_belief(self, focus_dirs: torch.Tensor, flat_rewards: torch.Tensor): + raise NotImplementedError + + def sample_focus_directions(self, n: int): + raise NotImplementedError + + +class TabularFocusModel(FocusModel): + """ + Tabular model of the feasibility of focus directions for goal-condtioning. + We keep a count of the number of times each focus direction has been sampled and whether + this direction succesfully lead to a sample in this region of the objective space. The (unormalized) likelihood + of a focus direction being feasible is then given by the ratio of these numbers. + If a focus direction has not been sampled yet it obtains the maximum likelihood of one. + """ + + def __init__(self, device: torch.device, n_objectives: int, state_space_res: int) -> None: + super().__init__(device, n_objectives, state_space_res) + self.n_objectives = n_objectives + self.state_space_res = state_space_res + self.focus_dir_dataset = ( + nn.functional.normalize(torch.tensor(get_limits_of_hypercube(n_objectives, state_space_res)), dim=1) + .float() + .to(self.device) + ) + self.focus_dir_count = torch.zeros(self.focus_dir_dataset.shape[0]).to(self.device) + self.focus_dir_population_count = torch.zeros(self.focus_dir_dataset.shape[0]).to(self.device) + + def update_belief(self, focus_dirs: torch.Tensor, flat_rewards: torch.Tensor): + """ + Updates the focus model with the focus directions and rewards + of the last batch. + """ + focus_dirs = nn.functional.normalize(focus_dirs, dim=1) + flat_rewards = nn.functional.normalize(flat_rewards, dim=1) + + focus_dirs_indices = torch.argmin(torch.cdist(focus_dirs, self.focus_dir_dataset), dim=1) + flat_rewards_indices = torch.argmin(torch.cdist(flat_rewards, self.focus_dir_dataset), dim=1) + + for idxs, count in zip( + [focus_dirs_indices, flat_rewards_indices], + [self.focus_dir_count, self.focus_dir_population_count], + ): + idx_increments = torch.bincount(idxs, minlength=len(count)) + count += idx_increments + + def sample_focus_directions(self, n: int): + """ + Samples n focus directions from the focus model. + """ + sampling_likelihoods = torch.zeros_like(self.focus_dir_count).float().to(self.device) + sampling_likelihoods[self.focus_dir_count == 0] = self.feasible_flow + sampling_likelihoods[ + torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0) + ] = self.feasible_flow + sampling_likelihoods[ + torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0) + ] = self.infeasible_flow + focus_dir_indices = torch.multinomial(sampling_likelihoods, n, replacement=True) + return self.focus_dir_dataset[focus_dir_indices].to("cpu") + + def save(self, path: Path): + params = { + "n_objectives": self.n_objectives, + "state_space_res": self.state_space_res, + "focus_dir_dataset": self.focus_dir_dataset.to("cpu"), + "focus_dir_count": self.focus_dir_count.to("cpu"), + "focus_dir_population_count": self.focus_dir_population_count.to("cpu"), + } + torch.save(params, open(path / "tabular_focus_model.pt", "wb")) + + def load(self, device: torch.device, path: Path): + params = torch.load(open(path / "tabular_focus_model.pt", "rb")) + self.n_objectives = params["n_objectives"] + self.state_space_res = params["state_space_res"] + self.focus_dir_dataset = params["focus_dir_dataset"].to(device) + self.focus_dir_count = params["focus_dir_count"].to(device) + self.focus_dir_population_count = params["focus_dir_population_count"].to(device) diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index abf123df..3b118b44 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.nn as nn from botorch.utils.multi_objective import infer_reference_point, pareto from botorch.utils.multi_objective.hypervolume import Hypervolume from rdkit import Chem, DataStructs @@ -11,6 +12,47 @@ from sklearn.cluster import KMeans +def compute_focus_coef( + flat_rewards: torch.Tensor, focus_dirs: torch.Tensor, focus_cosim: float, focus_limit_coef: float = 1.0 +): + """ + The focus direction is defined as a hypercone in the objective space centered around an focus_dir. + The focus coefficient (between 0 and 1) scales the reward associated to a given sample. + It should be 1 when the sample is exactly at the focus direction, equal to the focus_limit_coef + when the sample is at on the limit of the focus region and 0 when it is outside the focus region + we can use an exponential decay of the focus coefficient between the center and the limit of the focus region + i.e. cosim(sample, focus_dir) ** focus_gamma_param = focus_limit_coef + Note that we work in the positive quadrant (each reward is positive) and thus the cosine similarity is in [0, 1] + + :param focus_dirs: the focus directions, shape (batch_size, num_objectives) + :param flat_rewards: the flat rewards, shape (batch_size, num_objectives) + :param focus_cosim: the cosine similarity threshold to define the focus region + :param focus_limit_coef: the focus coefficient at the limit of the focus region + """ + assert focus_cosim >= 0.0 and focus_cosim <= 1.0, f"focus_cosim must be in [0, 1], now {focus_cosim}" + assert ( + focus_limit_coef > 0.0 and focus_limit_coef <= 1.0 + ), f"focus_limit_coef must be in (0, 1], now {focus_limit_coef}" + focus_gamma_param = np.log(focus_limit_coef) / np.log(focus_cosim) + cosim = nn.functional.cosine_similarity(flat_rewards, focus_dirs, dim=1) + in_focus_mask = cosim >= focus_cosim + focus_coef = torch.where(in_focus_mask, cosim**focus_gamma_param, 0.0) + return focus_coef, in_focus_mask + + +def get_focus_accuracy(flat_rewards, focus_dirs, focus_cosim): + _, in_focus_mask = compute_focus_coef(focus_dirs, flat_rewards, focus_cosim, focus_limit_coef=1.0) + return in_focus_mask.float().sum() / len(flat_rewards) + + +def get_limits_of_hypercube(n_dims, n_points_per_dim=10): + """Discretise the faces that are at the extremity of a unit hypercube""" + linear_spaces = [np.linspace(0.0, 1.0, n_points_per_dim) for _ in range(n_dims)] + grid = np.array(list(product(*linear_spaces))) + extreme_points = grid[np.any(grid == 1, axis=1)] + return extreme_points + + def get_IGD(samples, ref_front: np.ndarray = None): """ Computes the Inverse Generational Distance of a set of samples w.r.t a reference pareto front. @@ -28,14 +70,6 @@ def get_IGD(samples, ref_front: np.ndarray = None): Returns: float: The IGD value. """ - - def get_limits_of_hypercube(n_dims, n_points_per_dim=10): - """Discretise the faces that are at the extremity of a unit hypercube""" - linear_spaces = [np.linspace(0.0, 1.0, n_points_per_dim) for _ in range(n_dims)] - grid = np.array(list(product(*linear_spaces))) - extreme_points = grid[np.any(grid == 1, axis=1)] - return extreme_points - n_objectives = samples.shape[1] if ref_front is None: ref_front = get_limits_of_hypercube(n_dims=n_objectives) @@ -71,14 +105,6 @@ def get_PC_entropy(samples, ref_front=None): Returns: float: The IGD value. """ - - def get_limits_of_hypercube(n_dims, n_points_per_dim=10): - """Discretise the faces that are at the extremity of a unit hypercube""" - linear_spaces = [np.linspace(0.0, 1.0, n_points_per_dim) for _ in range(n_dims)] - grid = np.array(list(product(*linear_spaces))) - extreme_points = grid[np.any(grid == 1, axis=1)] - return extreme_points - n_objectives = samples.shape[1] if ref_front is None: ref_front = get_limits_of_hypercube(n_dims=n_objectives) @@ -100,6 +126,18 @@ def get_limits_of_hypercube(n_dims, n_points_per_dim=10): return float(pc_ent) +def sample_positiveQuadrant_ndim_sphere(n=10, d=2, normalisation="l2"): + points = np.random.randn(n, d) + points = np.abs(points) # positive quadrant + if normalisation == "l2": + points /= np.linalg.norm(points, axis=1, keepdims=True) + elif normalisation == "l1": + points /= np.sum(points, axis=1, keepdims=True) + else: + raise ValueError(f"Unknown normalisation {normalisation}") + return points + + def partition_hypersphere(k: int, d: int, n_samples: int = 10000, normalisation: str = "l2"): """ Partition a hypersphere into k clusters. @@ -119,18 +157,6 @@ def partition_hypersphere(k: int, d: int, n_samples: int = 10000, normalisation: v: np.ndarray Array of shape (k, d) containing the cluster centers """ - - def sample_positiveQuadrant_ndim_sphere(n=10, d=2, normalisation="l2"): - points = np.random.randn(n, d) - points = np.abs(points) # positive quadrant - if normalisation == "l2": - points /= np.linalg.norm(points, axis=1, keepdims=True) - elif normalisation == "l1": - points /= np.sum(points, axis=1, keepdims=True) - else: - raise ValueError(f"Unknown normalisation {normalisation}") - return points - points = sample_positiveQuadrant_ndim_sphere(n_samples, d, normalisation) v = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(points).cluster_centers_ if normalisation == "l2": diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 844e3174..d1ac0a2a 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -23,6 +23,8 @@ def __init__( compute_normed=False, compute_igd=False, compute_pc_entropy=False, + compute_focus_accuracy=False, + focus_cosim=None, ): # This __init__ is only called in the main process. This object is then (potentially) cloned # in pytorch data worker processed and __call__'ed from within those processes. This means @@ -36,8 +38,11 @@ def __init__( self.compute_normed = compute_normed self.compute_igd = compute_igd self.compute_pc_entropy = compute_pc_entropy + self.compute_focus_accuracy = compute_focus_accuracy + self.focus_cosim = focus_cosim self.all_flat_rewards: List[Tensor] = [] + self.all_focus_dirs: List[Tensor] = [] self.all_smi: List[str] = [] self.pareto_queue: mp.Queue = mp.Queue() self.pareto_front = None @@ -115,12 +120,15 @@ def _run_pareto_accumulation(self): def __call__(self, trajs, rewards, flat_rewards, cond_info): # locally (in-process) accumulate flat rewards to build a better pareto estimate self.all_flat_rewards = self.all_flat_rewards + list(flat_rewards) + self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) self.all_smi = self.all_smi + list([i.get("smi", None) for i in trajs]) if len(self.all_flat_rewards) > self.num_to_keep: self.all_flat_rewards = self.all_flat_rewards[-self.num_to_keep :] + self.all_focus_dirs = self.all_focus_dirs[-self.num_to_keep :] self.all_smi = self.all_smi[-self.num_to_keep :] flat_rewards = torch.stack(self.all_flat_rewards).numpy() + focus_dirs = torch.stack(self.all_focus_dirs).numpy() # collects empirical pareto front from in-process samples pareto_idces = metrics.is_pareto_efficient(-flat_rewards, return_mask=False) @@ -176,6 +184,14 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): "PCent": pc_ent, "lifetime_PCent_frontOnly": self.pareto_metrics[3], } + if self.compute_focus_accuracy: + focus_acc = metrics.get_focus_accuracy( + torch.tensor(flat_rewards), torch.tensor(focus_dirs), self.focus_cosim + ) + info = { + **info, + "focus_acc": focus_acc, + } return info diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index c49138e9..3027ff4f 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -6,9 +6,10 @@ import torch.multiprocessing as mp -class MPModelPlaceholder: - """This class can be used as a Model in a worker process, and - translates calls to queries to the main process""" +class MPObjectPlaceholder: + """This class can be used for example as a model or dataset placeholder + in a worker process, and translates calls to the object-placeholder into + queries for the main process to execute on the real object.""" def __init__(self, in_queues, out_queues, pickle_messages=False): self.qs = in_queues, out_queues @@ -34,39 +35,45 @@ def decode(self, m): return pickle.loads(m) return m - # TODO: make a generic method for this based on __getattr__ - def logZ(self, *a, **kw): - self._check_init() - self.in_queue.put(self.encode(("logZ", a, kw))) - return self.decode(self.out_queue.get()) + def __getattr__(self, name): + def method_wrapper(*a, **kw): + self._check_init() + self.in_queue.put(self.encode((name, a, kw))) + return self.decode(self.out_queue.get()) + + return method_wrapper def __call__(self, *a, **kw): self._check_init() self.in_queue.put(self.encode(("__call__", a, kw))) return self.decode(self.out_queue.get()) + def __len__(self): + self._check_init() + self.in_queue.put(("__len__", (), {})) + return self.out_queue.get() -class MPModelProxy: - """This class maintains a reference to an in-cuda-memory model, and + +class MPObjectProxy: + """This class maintains a reference to some object and creates a `placeholder` attribute which can be safely passed to multiprocessing DataLoader workers. - This placeholder model sends messages accross multiprocessing - queues, which are received by this proxy instance, which calls the - model and sends the return value back to the worker. - - Starts its own (daemon) thread. Always passes CPU tensors between - processes. + The placeholders in each process send messages accross multiprocessing + queues which are received by this proxy instance. The proxy instance then + runs the calls on our object and sends the return value back to the worker. + Starts its own (daemon) thread. + Always passes CPU tensors between processes. """ - def __init__(self, model: torch.nn.Module, num_workers: int, cast_types: tuple, pickle_messages: bool = False): - """Construct a multiprocessing model proxy for torch DataLoaders. + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False): + """Construct a multiprocessing object proxy. Parameters ---------- - model: torch.nn.Module - A torch model which lives in the main process to which method calls are passed + obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple @@ -80,9 +87,12 @@ def __init__(self, model: torch.nn.Module, num_workers: int, cast_types: tuple, self.in_queues = [mp.Queue() for i in range(num_workers)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPModelPlaceholder(self.in_queues, self.out_queues, pickle_messages) - self.model = model - self.device = next(model.parameters()).device + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages) + self.obj = obj + if hasattr(obj, "parameters"): + self.device = next(obj.parameters()).device + else: + self.device = torch.device("cpu") self.cuda_types = (torch.Tensor,) + cast_types self.stop = threading.Event() self.thread = threading.Thread(target=self.run, daemon=True) @@ -114,12 +124,26 @@ def run(self): except ConnectionError: break attr, args, kwargs = r - f = getattr(self.model, attr) + f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} result = f(*args, **kwargs) if isinstance(result, (list, tuple)): msg = [self.to_cpu(i) for i in result] + elif isinstance(result, dict): + msg = {k: self.to_cpu(i) for k, i in result.items()} + else: + msg = self.to_cpu(result) + self.out_queues[qi].put(self.encode(msg)) + + +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): + """Construct a multiprocessing object proxy for torch DataLoaders so + that it does not need to be copied in every worker's memory. For example, + this can be used to wrap a model such that only the main process makes + cuda calls by forwarding data through the model, or a replay buffer + such that the new data is pushed in from the worker processes but only the + main process has to hold the full buffer in memory. self.out_queues[qi].put(self.encode(msg)) elif isinstance(result, dict): msg = {k: self.to_cpu(i) for k, i in result.items()} @@ -128,16 +152,10 @@ def run(self): msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) - -def wrap_model_mp(model, num_workers, cast_types, pickle_messages: bool = False): - """Construct a multiprocessing model proxy for torch DataLoaders so - that only one process ends up making cuda calls and holding cuda - tensors in memory. - Parameters ---------- - model: torch.Module - A torch model which lives in the main process to which method calls are passed + obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple @@ -150,8 +168,8 @@ def wrap_model_mp(model, num_workers, cast_types, pickle_messages: bool = False) Returns ------- - placeholder: MPModelPlaceholder - A placeholder model whose method calls route arguments to the main process + placeholder: MPObjectPlaceholder + A placeholder object whose method calls route arguments to the main process """ - return MPModelProxy(model, num_workers, cast_types, pickle_messages).placeholder + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages).placeholder diff --git a/tox.ini b/tox.ini index 617b57be..85d593d5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,17 +1,16 @@ [tox] -envlist = py3{8,9}, report +envlist = py3{9}, report [testenv] commands = pytest skip_install = true depends = - report: py3{8,9} + report: py3{9} setenv = - py3{8,9,10}: COVERAGE_FILE = .coverage.{envname} + py3{9,10}: COVERAGE_FILE = .coverage.{envname} install_command = - pip install -U {opts} {packages} --find-links https://data.pyg.org/whl/torch-1.10.0+cu113.html --find-links https://data.pyg.org/whl/torch-1.10.0+cpu.html + pip install -U {opts} {packages} --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html deps = - py38: -r requirements/dev_3.8.txt py39: -r requirements/dev_3.9.txt From 7f12b919d296578f7e62c0a49b5f9a8a9999ef88 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 26 May 2023 11:55:46 -0400 Subject: [PATCH 2/9] Add flow matching objective (#62) * add random walk features to atom-env * tox * improve fragment environment for validity + include stop_mask + include original fragment stems * tox * remove temp code * Remove json prettifier pre-commit hook * fix mask to device in GraphActionCategorical * more comments and slightly better architecture adaptation * parameterized p_B implementation * modifications to fragment env * fix mask + subTB ignore pad + max_nodes hparam * remove node mask fix in frag env * test for pPB on fragments * tox * tox * rm duplicate line * add timeout to requests in bengio2021flow * doc: add a few comments in TB + rename some variables for clarity * remove unnecessary fragment files * revert .pre-commit-config.yaml to trunk * typing annotation fix * remerge graphs.py * fix merge of graph_sampling.py * fix: forgot ln_type as GraphTransformer parameter * trying things * proper pickle wrapping of MPModelProxy * add mp_pickle_messages train flag * restore files from trunk + comments * remove print + isort * restore TB * fix: pass flag to placeholder * fix prefetch factor default * tox * fix pyproject to not complain about pickle * Initial working FM implementation * comments + fix terminal loss * fix FM terminal actions loop & loss info * more comments + change default * fix comment * idempotency correction for FM * black this branch * tox --- src/gflownet/algo/flow_matching.py | 193 ++++++++++++++++++++++++ src/gflownet/algo/trajectory_balance.py | 12 +- src/gflownet/tasks/seh_frag.py | 11 +- 3 files changed, 210 insertions(+), 6 deletions(-) create mode 100644 src/gflownet/algo/flow_matching.py diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py new file mode 100644 index 00000000..4614c087 --- /dev/null +++ b/src/gflownet/algo/flow_matching.py @@ -0,0 +1,193 @@ +from typing import Any, Dict + +import networkx as nx +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from torch_scatter import scatter + +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionType, + GraphBuildingEnv, + GraphBuildingEnvContext, +) + + +def relabel(ga: GraphAction, g: Graph): + """Relabel the nodes for g to 0-N, and the graph action ga applied to g. + + This is necessary because torch_geometric and EnvironmentContext classes expect nodes to be + labeled 0-N, whereas GraphBuildingEnv.parent can return parents with e.g. a removed node that + creates a gap in 0-N, leading to a faulty encoding of the graph. + """ + rmap = dict(zip(g.nodes, range(len(g.nodes)))) + if not len(g) and ga.action == GraphActionType.AddNode: + rmap[0] = 0 # AddNode can add to the empty graph, the source is still 0 + g = nx.relabel_nodes(g, rmap) + if ga.source is not None: + ga.source = rmap[ga.source] + if ga.target is not None: + ga.target = rmap[ga.target] + return ga, g + + +class FlowMatching(TrajectoryBalance): # TODO: FM inherits from TB but we could have a generic GFNAlgorithm class + def __init__( + self, + env: GraphBuildingEnv, + ctx: GraphBuildingEnvContext, + rng: np.random.RandomState, + hps: Dict[str, Any], + max_len=None, + max_nodes=None, + ): + super().__init__(env, ctx, rng, hps, max_len=max_len, max_nodes=max_nodes) + self.fm_epsilon = torch.as_tensor(hps.get("fm_epsilon", 1e-38)).log() + # We include the "balanced loss" as a possibility to reproduce results from the FM paper, but + # in a number of settings the regular loss is more stable. + self.fm_balanced_loss = hps.get("fm_balanced_loss", False) + self.fm_leaf_coef = hps.get("fm_leaf_coef", 10) + self.correct_idempotent = self.correct_idempotent or hps.get("fm_correct_idempotent", False) + + def construct_batch(self, trajs, cond_info, log_rewards): + """Construct a batch from a list of trajectories and their information + + Parameters + ---------- + trajs: List[List[tuple[Graph, GraphAction]]] + A list of N trajectories. + cond_info: Tensor + The conditional info that is considered for each trajectory. Shape (N, n_info) + log_rewards: Tensor + The transformed reward (e.g. log(R(x) ** beta)) for each trajectory. Shape (N,) + Returns + ------- + batch: gd.Batch + A (CPU) Batch object with relevant attributes added + """ + if not self.correct_idempotent: + # For every s' (i.e. every state except the first of each trajectory), enumerate parents + parents = [[relabel(*i) for i in self.env.parents(i[0])] for tj in trajs for i in tj["traj"][1:]] + # convert parents to Data + parent_graphs = [self.ctx.graph_to_Data(pstate) for parent in parents for pact, pstate in parent] + else: + # Here we again enumerate parents + states = [i[0] for tj in trajs for i in tj["traj"][1:]] + base_parents = [[relabel(*i) for i in self.env.parents(i)] for i in states] + base_parent_graphs = [ + [self.ctx.graph_to_Data(pstate) for pact, pstate in parent_set] for parent_set in base_parents + ] + parents = [] + parent_graphs = [] + for state, parent_set, parent_set_graphs in zip(states, base_parents, base_parent_graphs): + new_parent_set = [] + new_parent_graphs = [] + # But for each parent we add all the possible (action, parent) pairs to the sets of parents + for (ga, p), pd in zip(parent_set, parent_set_graphs): + ipa = self.get_idempotent_actions(p, pd, state, ga, return_aidx=False) + new_parent_set += [(a, p) for a in ipa] + new_parent_graphs += [pd] * len(ipa) + parents.append(new_parent_set) + parent_graphs += new_parent_graphs + # Implementation Note: no further correction is required for environments where episodes + # always end in a Stop action. If this is not the case, then this implementation is + # incorrect in that it doesn't account for the multiple ways that one could reach the + # terminal state (because it assumes that a terminal state has only one parent and gives + # 100% of the reward-flow to the edge between that parent and the terminal state, which + # for stop actions is correct). Notably, this error will happen in environments where + # there are invalid states that make episodes end prematurely (when those invalid states + # have multiple possible parents). + + # convert actions to aidx + parent_actions = [pact for parent in parents for pact, pstate in parent] + parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)] + # convert state to Data + state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]] + terminal_actions = [ + self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs + ] + + # Create a batch from [*parents, *states]. This order will make it easier when computing the loss + batch = self.ctx.collate(parent_graphs + state_graphs) + batch.num_parents = torch.tensor([len(i) for i in parents]) + batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) + batch.parent_acts = torch.tensor(parent_actionidcs) + batch.terminal_acts = torch.tensor(terminal_actions) + batch.log_rewards = log_rewards + batch.cond_info = cond_info + batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() + if self.correct_idempotent: + raise ValueError("Not implemented") + return batch + + def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): + dev = batch.x.device + eps = self.fm_epsilon.to(dev) + # Compute relevant quantities + num_trajs = len(batch.log_rewards) + num_states = int(batch.num_parents.shape[0]) + total_num_parents = batch.num_parents.sum() + # Compute, for every parent, the index of the state it corresponds to (all states are + # considered numbered 0..N regardless of which trajectory they correspond to) + parents_state_idx = torch.arange(num_states, device=dev).repeat_interleave(batch.num_parents) + # Compute, for every state, the index of the trajectory it corresponds to + states_traj_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens - 1) + # Idem for parents + parents_traj_idx = states_traj_idx.repeat_interleave(batch.num_parents) + # Compute the index of the first graph of every trajectory via a cumsum of the trajectory + # lengths. This works because by design the first parent of every trajectory is s0 (i.e. s1 + # only has one parent that is s0) + num_parents_per_traj = scatter(batch.num_parents, states_traj_idx, 0, reduce="sum") + first_graph_idx = torch.cumsum( + torch.cat([torch.zeros_like(num_parents_per_traj[0])[None], num_parents_per_traj]), 0 + ) + # Similarly we want the index of the last graph of each trajectory + final_graph_idx = torch.cumsum(batch.traj_lens - 1, 0) + total_num_parents - 1 + + # Query the model for Fsa. The model will output a GraphActionCategorical, but we will + # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of + # log F(s,a) so we don't have to change anything in the sampling routines. + cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) + # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the + # parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs + # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We + # only need the parent's outputs so we specify those batch indices. + parent_log_F_sa = cat.log_prob( + batch.parent_acts, logprobs=cat.logits, batch=torch.arange(total_num_parents, device=dev) + ) + # The inflows is then simply the sum reduction of exponentiating the log edge flows. The + # indices are the state index that each parent belongs to. + log_inflows = scatter(parent_log_F_sa.exp(), parents_state_idx, 0, reduce="sum").log() + # To compute the outflows we can just logsumexp the log F(s,a) predictions. We do so for the + # entire batch, which is slightly wasteful (TODO). We only take the last outflows here, and + # later take the log outflows of s0 to estimate logZ. + all_log_outflows = cat.logsumexp() + log_outflows = all_log_outflows[total_num_parents:] + + # The loss of intermediary states is inflow - outflow. We use the log-epsilon variant (see FM paper) + intermediate_loss = (torch.logaddexp(log_inflows, eps) - torch.logaddexp(log_outflows, eps)).pow(2) + # To compute the loss of the terminal states we match F(s, a'), where a' is the action that + # terminated the trajectory, to R(s). We again use the mechanism of log_prob + log_F_s_stop = cat.log_prob(batch.terminal_acts, cat.logits, final_graph_idx) + terminal_loss = (torch.logaddexp(log_F_s_stop, eps) - torch.logaddexp(batch.log_rewards, eps)).pow(2) + + if self.fm_balanced_loss: + loss = intermediate_loss.mean() + terminal_loss.mean() * self.fm_leaf_coef + else: + loss = (intermediate_loss.sum() + terminal_loss.sum()) / ( + intermediate_loss.shape[0] + terminal_loss.shape[0] + ) + + # logZ is simply the outflow of s0, the first graph of each parent set. + logZ = all_log_outflows[first_graph_idx] + info = { + "intermediate_loss": intermediate_loss.mean().item(), + "terminal_loss": terminal_loss.mean().item(), + "loss": loss.item(), + "logZ": logZ.mean().item(), + } + return loss, info diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index cab29cb3..0e76abc9 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -151,7 +151,7 @@ def create_training_data_from_graphs(self, graphs): """ return [{"traj": generate_forward_trajectory(i)} for i in graphs] - def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction): + def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True): """Returns the list of idempotent actions for a given transition. Note, this is slow! Correcting for idempotency is needed to estimate p(x) correctly, but @@ -168,22 +168,24 @@ def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: Graph The next state's graph action: GraphAction Action leading from g to gp + return_aidx: bool + If true returns of list of action indices, else a list of GraphAction Returns ------- - actions: List[Tuple[int,int,int]] + actions: Union[List[Tuple[int,int,int]], List[GraphAction]] The list of idempotent actions that all lead from g to gp. """ iaction = self.ctx.GraphAction_to_aidx(gd, action) if action.action == GraphActionType.Stop: - return [iaction] + return [iaction if return_aidx else action] # Here we're looking for potential idempotent actions by looking at legal actions of the # same type. This assumes that this is the only way to get to a similar parent. Perhaps # there are edges cases where this is not true...? lmask = getattr(gd, action.action.mask_name) nz = lmask.nonzero() # Legal actions are those with a nonzero mask value - actions = [iaction] + actions = [iaction if return_aidx else action] for i in nz: aidx = (iaction[0], i[0].item(), i[1].item()) if aidx == iaction: @@ -191,7 +193,7 @@ def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: Graph ga = self.ctx.aidx_to_GraphAction(gd, aidx, fwd=not action.action.is_backward) child = self.env.step(g, ga) if nx.algorithms.is_isomorphic(child, gp, lambda a, b: a == b, lambda a, b: a == b): - actions.append(aidx) + actions.append(aidx if return_aidx else ga) return actions def construct_batch(self, trajs, cond_info, log_rewards): diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 14009ef1..839be788 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -18,6 +18,7 @@ from torch import Tensor from torch.utils.data import Dataset +from gflownet.algo.flow_matching import FlowMatching from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext @@ -147,10 +148,18 @@ def default_hps(self) -> Dict[str, Any]: "replay_buffer_size": 10000, "replay_buffer_warmup": 10000, "mp_pickle_messages": False, + "algo": "TB", } def setup_algo(self): - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.hps, max_nodes=self.hps["max_nodes"]) + algo = self.hps.get("algo", "TB") + if algo == "TB": + algo = TrajectoryBalance + elif algo == "FM": + algo = FlowMatching + else: + raise ValueError(algo) + self.algo = algo(self.env, self.ctx, self.rng, self.hps, max_nodes=self.hps["max_nodes"]) def setup_task(self): self.task = SEHTask( From 6f7eea719da134a84afa74cd5b3fa1127972e1fa Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 26 May 2023 12:13:05 -0400 Subject: [PATCH 3/9] Disable gradient computations in multiprocessing proxy (#78) * add no_grad in mp proxy + error management * tox --- src/gflownet/utils/multiprocessing_proxy.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 3027ff4f..cb220f4b 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -1,6 +1,7 @@ import pickle import queue import threading +import traceback import torch import torch.multiprocessing as mp @@ -32,7 +33,10 @@ def encode(self, m): def decode(self, m): if self.pickle_messages: - return pickle.loads(m) + m = pickle.loads(m) + if isinstance(m, Exception): + print("Received exception from main process, reraising.") + raise m return m def __getattr__(self, name): @@ -127,7 +131,17 @@ def run(self): f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} - result = f(*args, **kwargs) + try: + # There's no need to compute gradients, since we can't transfer them back to the worker + with torch.no_grad(): + result = f(*args, **kwargs) + except Exception as e: + result = e + exc_str = traceback.format_exc() + try: + pickle.dumps(e) + except Exception: + result = RuntimeError("Exception raised in MPModelProxy, but it cannot be pickled.\n" + exc_str) if isinstance(result, (list, tuple)): msg = [self.to_cpu(i) for i in result] elif isinstance(result, dict): From a340412ff84e9df71d4550c9ed8c18cb5b206257 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 26 May 2023 12:22:23 -0400 Subject: [PATCH 4/9] Fix create_training_data_from_graphs in TB, add missing bck_logprobs and result keys (#81) --- src/gflownet/algo/trajectory_balance.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 0e76abc9..774848ec 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -149,7 +149,15 @@ def create_training_data_from_graphs(self, graphs): trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}] A list of trajectories. """ - return [{"traj": generate_forward_trajectory(i)} for i in graphs] + trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs] + for traj in trajs: + n_back = [ + self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) + for gp, _ in traj["traj"][1:] + ] + [1] + traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) + traj["result"] = traj["traj"][-1][0] + return trajs def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True): """Returns the list of idempotent actions for a given transition. From 75a8655b4ea9ecbb4517cc970e47936c3d60dfd9 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 6 Jun 2023 14:54:25 -0400 Subject: [PATCH 5/9] Fix: special case for edge attributes in generate_forward_trajectory (#83) * special case for edge attributes in generate_forward_trajectory * PR number --- src/gflownet/envs/frag_mol_env.py | 2 ++ src/gflownet/envs/graph_building_env.py | 22 +++++++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index cb20fcd6..82818f83 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -65,6 +65,8 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu # for u and the second half for v. Each logit i in the first half for a given edge # corresponds to setting the stem atom of fragment u used to attach between u and v to be i # (named f'{u}_attach') and vice versa for the second half and v, u. + # Note to self: this choice results in a special case in generate_forward_trajectory for these + # edge attributes. See PR#83 for details. self.num_edge_attr_logits = most_stems * 2 # There are thus up to 2 edge attributes, the stem of u and the stem of v. self.num_edge_attrs = 2 diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 4ac72415..6425b0b3 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -348,14 +348,26 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G if len(i) > 1: # i is an edge e = relabeling_map.get(i[0], None), relabeling_map.get(i[1], None) if e in gn.edges: - # i exists in the new graph, that means some of its attributes need to be added - attrs = [j for j in g.edges[i] if j not in gn.edges[e]] + # i exists in the new graph, that means some of its attributes need to be added. + # + # This remap is a special case for the fragment environment, due to the (poor) design + # choice of treating directed edges as undirected edges. Until we have routines for + # directed graphs, this may need to stay. + def possibly_remap(attr): + if attr == f"{i[0]}_attach": + return f"{e[0]}_attach" + elif attr == f"{i[1]}_attach": + return f"{e[1]}_attach" + return attr + + attrs = [j for j in g.edges[i] if possibly_remap(j) not in gn.edges[e]] if len(attrs) == 0: continue # If nodes are in cycles edges leading to them get stack multiple times, disregard - attr = attrs[np.random.randint(len(attrs))] - gn.edges[e][attr] = g.edges[i][attr] + iattr = attrs[np.random.randint(len(attrs))] + eattr = possibly_remap(iattr) + gn.edges[e][eattr] = g.edges[i][iattr] act = GraphAction( - GraphActionType.SetEdgeAttr, source=e[0], target=e[1], attr=attr, value=g.edges[i][attr] + GraphActionType.SetEdgeAttr, source=e[0], target=e[1], attr=eattr, value=g.edges[i][iattr] ) else: # i doesn't exist, add the edge From 6807bb7c20bf03ad6c2ceee696453ede4f9accda Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Thu, 15 Jun 2023 15:34:50 -0400 Subject: [PATCH 6/9] handle multiple edge_index features (like remove and add) (#94) --- src/gflownet/envs/graph_building_env.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 6425b0b3..a3df87bf 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -511,9 +511,11 @@ def __init__( self.logprobs = None if deduplicate_edge_index and "edge_index" in keys: - idx = keys.index("edge_index") - self.batch[idx] = self.batch[idx][::2] - self.slice[idx] = self.slice[idx].div(2, rounding_mode="floor") + for idx, k in enumerate(keys): + if k != "edge_index": + continue + self.batch[idx] = self.batch[idx][::2] + self.slice[idx] = self.slice[idx].div(2, rounding_mode="floor") def detach(self): new = copy.copy(self) From ffabcfdf2fc1528ddf916ccba0f4d28e6f4bfff8 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 18 Jul 2023 15:26:05 -0400 Subject: [PATCH 7/9] Feature: better configs (#93) This PR replaces the "giant-dict-as-config" with a structured & typed config, which makes each module responsible for defining & documenting its own configuration. * little test * new config structure - in progress * trying config by names * further refactor progress * better pyi sort + fix moo example * tox * import fixes * add SQL + fix n_valid * tox test * fix mypy hook and convert qm9 to new cfg * better config generation * use generated config.py * use generated config.py * fix rng call types * fix test + tox * better config doc * fix deps * tox * re-fix deps * minor fixes for seh_frag_moo * tox * switch to OmegaConf * switch to omegaconf * fix pre-commit-config * add omegaconf dep * fix list defaults to fields * remove comment * version change and revert unnecessary torch.tensor call * addressing PR comments --- VERSION | 2 +- pyproject.toml | 1 + requirements/dev_3.9.txt | 2 + requirements/main_3.9.txt | 2 + src/gflownet/algo/advantage_actor_critic.py | 29 +- src/gflownet/algo/config.py | 119 ++++++ src/gflownet/algo/envelope_q_learning.py | 38 +- src/gflownet/algo/flow_matching.py | 17 +- src/gflownet/algo/multiobjective_reinforce.py | 9 +- src/gflownet/algo/soft_q_learning.py | 33 +- src/gflownet/algo/trajectory_balance.py | 97 +++-- src/gflownet/config.py | 94 +++++ src/gflownet/data/config.py | 24 ++ src/gflownet/data/qm9.py | 16 + src/gflownet/data/replay_buffer.py | 8 +- src/gflownet/data/sampling_iterator.py | 32 +- src/gflownet/models/config.py | 25 ++ src/gflownet/models/graph_transformer.py | 23 +- src/gflownet/tasks/config.py | 91 +++++ src/gflownet/tasks/qm9/qm9.py | 102 +++-- src/gflownet/tasks/seh_frag.py | 173 +++++---- src/gflownet/tasks/seh_frag_moo.py | 348 ++++++++---------- src/gflownet/train.py | 111 +++--- tests/test_frag_env.py | 7 +- 24 files changed, 861 insertions(+), 542 deletions(-) create mode 100644 src/gflownet/algo/config.py create mode 100644 src/gflownet/config.py create mode 100644 src/gflownet/data/config.py create mode 100644 src/gflownet/models/config.py create mode 100644 src/gflownet/tasks/config.py diff --git a/VERSION b/VERSION index bd263dfa..6a9e0389 100644 --- a/VERSION +++ b/VERSION @@ -1,2 +1,2 @@ MAJOR="0" -MINOR="0" +MINOR="1" diff --git a/pyproject.toml b/pyproject.toml index 4467b276..7afb21b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "botorch", "pyro-ppl", "gpytorch", + "omegaconf>=2.3", ] [project.optional-dependencies] diff --git a/requirements/dev_3.9.txt b/requirements/dev_3.9.txt index 0af2d196..4d755596 100644 --- a/requirements/dev_3.9.txt +++ b/requirements/dev_3.9.txt @@ -117,6 +117,8 @@ numpy==1.24.2 # torch-geometric oauthlib==3.2.2 # via requests-oauthlib +omegaconf==2.3.0 + # via gflownet opt-einsum==3.3.0 # via pyro-ppl packaging==23.1 diff --git a/requirements/main_3.9.txt b/requirements/main_3.9.txt index dcc19ca1..66be868a 100644 --- a/requirements/main_3.9.txt +++ b/requirements/main_3.9.txt @@ -74,6 +74,8 @@ numpy==1.24.1 # torch-geometric oauthlib==3.2.2 # via requests-oauthlib +omegaconf==2.3.0 + # via gflownet opt-einsum==3.3.0 # via pyro-ppl packaging==23.0 diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 43245286..001e19d0 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -1,11 +1,10 @@ -from typing import Any, Dict - import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd from torch import Tensor +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from .graph_sampling import GraphSampler @@ -17,9 +16,7 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """Advantage Actor-Critic implementation, see Asynchronous Methods for Deep Reinforcement Learning, @@ -38,29 +35,25 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + The experiment configuration """ self.ctx = ctx self.env = env self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.entropy_coef = hps.get("a2c_entropy", 0.01) - self.gamma = hps.get("a2c_gamma", 1) - self.invalid_penalty = hps.get("a2c_penalty", -10) + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.entropy_coef = cfg.algo.a2c.entropy + self.gamma = cfg.algo.a2c.gamma + self.invalid_penalty = cfg.algo.a2c.penalty assert self.gamma == 1 self.bootstrap_own_reward = False # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py new file mode 100644 index 00000000..121ab5f4 --- /dev/null +++ b/src/gflownet/algo/config.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class TBConfig: + """Trajectory Balance config. + + Attributes + ---------- + bootstrap_own_reward : bool + Whether to bootstrap the reward with the own reward. (deprecated) + epsilon : Optional[float] + The epsilon parameter in log-flow smoothing (see paper) + reward_loss_multiplier : float + The multiplier for the reward loss when bootstrapping the reward. (deprecated) + do_subtb : bool + Whether to use the full N^2 subTB loss + do_correct_idempotent : bool + Whether to correct for idempotent actions + do_parameterize_p_b : bool + Whether to parameterize the P_B distribution (otherwise it is uniform) + subtb_max_len : int + The maximum length trajectories, used to cache subTB computation indices + Z_learning_rate : float + The learning rate for the logZ parameter (only relevant when do_subtb is False) + Z_lr_decay : float + The learning rate decay for the logZ parameter (only relevant when do_subtb is False) + """ + + bootstrap_own_reward: bool = False + epsilon: Optional[float] = None + reward_loss_multiplier: float = 1.0 + do_subtb: bool = False + do_correct_idempotent: bool = False + do_parameterize_p_b: bool = False + subtb_max_len: int = 128 + Z_learning_rate: float = 1e-4 + Z_lr_decay: float = 50_000 + + +@dataclass +class MOQLConfig: + gamma: float = 1 + num_omega_samples: int = 32 + num_objectives: int = 2 + lambda_decay: int = 10_000 + penalty: float = -10 + + +@dataclass +class A2CConfig: + entropy: float = 0.01 + gamma: float = 1 + penalty: float = -10 + + +@dataclass +class FMConfig: + epsilon: float = 1e-38 + balanced_loss: bool = False + leaf_coef: float = 10 + correct_idempotent: bool = False + + +@dataclass +class SQLConfig: + alpha: float = 0.01 + gamma: float = 1 + penalty: float = -10 + + +@dataclass +class AlgoConfig: + """Generic configuration for algorithms + + Attributes + ---------- + method : str + The name of the algorithm to use (e.g. "TB") + global_batch_size : int + The batch size for training + max_len : int + The maximum length of a trajectory + max_nodes : int + The maximum number of nodes in a generated graph + max_edges : int + The maximum number of edges in a generated graph + illegal_action_logreward : float + The log reward an agent gets for illegal actions + offline_ratio: float + The ratio of samples drawn from `self.training_data` during training. The rest is drawn from + `self.sampling_model` + train_random_action_prob : float + The probability of taking a random action during training + valid_random_action_prob : float + The probability of taking a random action during validation + valid_sample_cond_info : bool + Whether to sample conditioning information during validation (if False, expects a validation set of cond_info) + sampling_tau : float + The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta) + """ + + method: str = "TB" + global_batch_size: int = 64 + max_len: int = 128 + max_nodes: int = 128 + max_edges: int = 128 + illegal_action_logreward: float = -100 + offline_ratio: float = 0.5 + train_random_action_prob: float = 0.0 + valid_random_action_prob: float = 0.0 + valid_sample_cond_info: bool = True + sampling_tau: float = 0.0 + tb: TBConfig = TBConfig() + moql: MOQLConfig = MOQLConfig() + a2c: A2CConfig = A2CConfig() + fm: FMConfig = FMConfig() + sql: SQLConfig = SQLConfig() diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 0f563d60..f9fba9d2 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - import numpy as np import torch import torch.nn as nn @@ -8,6 +6,7 @@ from torch import Tensor from torch_scatter import scatter +from gflownet.config import Config from gflownet.envs.graph_building_env import ( GraphActionCategorical, GraphBuildingEnv, @@ -15,6 +14,7 @@ generate_forward_trajectory, ) from gflownet.models.graph_transformer import GraphTransformer, mlp +from gflownet.train import GFNTask from .graph_sampling import GraphSampler @@ -165,10 +165,9 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, + task: GFNTask, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """Envelope Q-Learning implementation, see A Generalized Algorithm for Multi-Objective Reinforcement Learning and Policy Adaptation, @@ -187,31 +186,28 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + The experiment configuration """ self.ctx = ctx self.env = env + self.task = task self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.gamma = hps.get("moql_gamma", 1) - self.num_objectives = len(hps["objectives"]) - self.num_omega_samples = hps.get("moql_num_omega_samples", 32) - self.Lambda_decay = hps.get("moql_lambda_decay", 10_000) - self.invalid_penalty = hps.get("moql_penalty", -10) + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.gamma = cfg.algo.moql.gamma + self.num_objectives = cfg.algo.moql.num_objectives + self.num_omega_samples = cfg.algo.moql.num_omega_samples + self.lambda_decay = cfg.algo.moql.lambda_decay + self.invalid_penalty = cfg.algo.moql.penalty self._num_updates = 0 assert self.gamma == 1 self.bootstrap_own_reward = False # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float @@ -396,7 +392,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # and L_B loss_B = abs((w * y).sum(1) - (w * Q_saw).sum(1)) - Lambda = 1 - self.Lambda_decay / (self.Lambda_decay + self._num_updates) + Lambda = 1 - self.lambda_decay / (self.lambda_decay + self._num_updates) losses = (1 - Lambda) * loss_A + Lambda * loss_B self._num_updates += 1 diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index 4614c087..a1e9a393 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - import networkx as nx import numpy as np import torch @@ -8,6 +6,7 @@ from torch_scatter import scatter from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.envs.graph_building_env import ( Graph, GraphAction, @@ -41,17 +40,15 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): - super().__init__(env, ctx, rng, hps, max_len=max_len, max_nodes=max_nodes) - self.fm_epsilon = torch.as_tensor(hps.get("fm_epsilon", 1e-38)).log() + super().__init__(env, ctx, rng, cfg) + self.fm_epsilon = torch.as_tensor(cfg.algo.fm.epsilon).log() # We include the "balanced loss" as a possibility to reproduce results from the FM paper, but # in a number of settings the regular loss is more stable. - self.fm_balanced_loss = hps.get("fm_balanced_loss", False) - self.fm_leaf_coef = hps.get("fm_leaf_coef", 10) - self.correct_idempotent = self.correct_idempotent or hps.get("fm_correct_idempotent", False) + self.fm_balanced_loss = cfg.algo.fm.balanced_loss + self.fm_leaf_coef = cfg.algo.fm.leaf_coef + self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent def construct_batch(self, trajs, cond_info, log_rewards): """Construct a batch from a list of trajectories and their information diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index 7d014124..aa1feef8 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -1,11 +1,10 @@ -from typing import Any, Dict - import numpy as np import torch import torch_geometric.data as gd from torch_scatter import scatter from gflownet.algo.trajectory_balance import TrajectoryBalance, TrajectoryBalanceModel +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext @@ -19,11 +18,9 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): - super().__init__(env, ctx, rng, hps, max_len, max_nodes) + super().__init__(env, ctx, rng, cfg) def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0): """Compute multi objective REINFORCE loss over trajectories contained in the batch""" diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index 0d49a14b..1e3f1146 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -1,5 +1,3 @@ -from typing import Any, Dict - import numpy as np import torch import torch.nn as nn @@ -8,6 +6,7 @@ from torch_scatter import scatter from gflownet.algo.graph_sampling import GraphSampler +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory @@ -17,16 +16,14 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """Soft Q-Learning implementation, see - xxxxx + Haarnoja, Tuomas, Haoran Tang, Pieter Abbeel, and Sergey Levine. "Reinforcement learning with deep + energy-based policies." In International conference on machine learning, pp. 1352-1361. PMLR, 2017. Hyperparameters used: illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions - sql_alpha: float, the entropy coefficient Parameters ---------- @@ -36,27 +33,23 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + The experiment configuration """ self.ctx = ctx self.env = env self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.alpha = hps.get("sql_alpha", 0.01) - self.gamma = hps.get("sql_gamma", 1) - self.invalid_penalty = hps.get("sql_penalty", -10) + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.alpha = cfg.algo.sql.alpha + self.gamma = cfg.algo.sql.gamma + self.invalid_penalty = cfg.algo.sql.penalty self.bootstrap_own_reward = False # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 774848ec..57af6056 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Tuple import networkx as nx import numpy as np @@ -9,6 +9,7 @@ from torch_scatter import scatter, scatter_sum from gflownet.algo.graph_sampling import GraphSampler +from gflownet.config import Config from gflownet.envs.graph_building_env import ( Graph, GraphAction, @@ -18,6 +19,7 @@ GraphBuildingEnvContext, generate_forward_trajectory, ) +from gflownet.train import GFNAlgorithm class TrajectoryBalanceModel(nn.Module): @@ -28,29 +30,24 @@ def logZ(self, cond_info: Tensor) -> Tensor: raise NotImplementedError() -class TrajectoryBalance: - """ """ +class TrajectoryBalance(GFNAlgorithm): + """TB implementation, see + "Trajectory Balance: Improved Credit Assignment in GFlowNets Nikolay Malkin, Moksh Jain, + Emmanuel Bengio, Chen Sun, Yoshua Bengio" + https://arxiv.org/abs/2201.13259""" def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, rng: np.random.RandomState, - hps: Dict[str, Any], - max_len=None, - max_nodes=None, + cfg: Config, ): """TB implementation, see "Trajectory Balance: Improved Credit Assignment in GFlowNets Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio" https://arxiv.org/abs/2201.13259 - Hyperparameters used: - illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions - bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data - tb_epsilon: float, if not None, adds this epsilon in the numerator and denominator of the log-ratio - reward_loss_multiplier: float, multiplying constant for the bootstrap loss. - Parameters ---------- env: GraphBuildingEnv @@ -59,22 +56,16 @@ def __init__( A context. rng: np.random.RandomState rng used to take random actions - hps: Dict[str, Any] - Hyperparameter dictionary, see above for used keys. - max_len: int - If not None, ends trajectories of more than max_len steps. - max_nodes: int - If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). + cfg: Config + Hyperparameters """ self.ctx = ctx self.env = env self.rng = rng - self.max_len = max_len - self.max_nodes = max_nodes - self.illegal_action_logreward = hps["illegal_action_logreward"] - self.bootstrap_own_reward = hps["bootstrap_own_reward"] - self.epsilon = hps["tb_epsilon"] - self.reward_loss_multiplier = hps.get("reward_loss_multiplier", 1) + self.global_cfg = cfg + self.cfg = cfg.algo.tb + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes # Experimental flags self.reward_loss_is_mae = True self.tb_loss_is_mae = False @@ -83,22 +74,20 @@ def __init__( self.length_normalize_losses = False self.reward_normalize_losses = False self.sample_temp = 1 - self.is_doing_subTB = hps.get("tb_do_subtb", False) - self.correct_idempotent = hps.get("tb_correct_idempotent", False) - self.p_b_is_parameterized = hps.get("tb_p_b_is_parameterized", False) + self.bootstrap_own_reward = self.cfg.bootstrap_own_reward self.graph_sampler = GraphSampler( ctx, env, - max_len, - max_nodes, + cfg.algo.max_len, + cfg.algo.max_nodes, rng, self.sample_temp, - correct_idempotent=self.correct_idempotent, - pad_with_terminal_state=self.p_b_is_parameterized, + correct_idempotent=self.cfg.do_correct_idempotent, + pad_with_terminal_state=self.cfg.do_parameterize_p_b, ) - if self.is_doing_subTB: - self._subtb_max_len = hps.get("tb_subtb_max_len", max_len + 2 if max_len is not None else 128) + if self.cfg.do_subtb: + self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? def create_training_data_from_own_samples( @@ -152,7 +141,7 @@ def create_training_data_from_graphs(self, graphs): trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: n_back = [ - self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) + self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) for gp, _ in traj["traj"][1:] ] + [1] traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) @@ -228,7 +217,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) batch.log_p_B = torch.cat([i["bck_logprobs"] for i in trajs], 0) batch.actions = torch.tensor(actions) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: batch.bck_actions = torch.tensor( [ self.ctx.GraphAction_to_aidx(g, a) @@ -239,7 +228,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.log_rewards = log_rewards batch.cond_info = cond_info batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() - if self.correct_idempotent: + if self.cfg.do_correct_idempotent: # Every timestep is a (graph_a, action, graph_b) triple agraphs = [i[0] for tj in trajs for i in tj["traj"]] # Here we start at the 1th timestep and append the result @@ -251,7 +240,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): ] batch.ip_actions = torch.tensor(sum(ipa, [])) batch.ip_lens = torch.tensor([len(i) for i in ipa]) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # Here we start at the 0th timestep and prepend None (it will be unused) bgraphs = sum([[None] + [i[0] for i in tj["traj"][:-1]] for tj in trajs], []) gactions = [i for tj in trajs for i in tj["bck_a"]] @@ -264,7 +253,9 @@ def construct_batch(self, trajs, cond_info, log_rewards): return batch - def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0): + def compute_batch_losses( + self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0 # type: ignore[override] + ): """Compute the losses over trajectories contained in the batch Parameters @@ -282,7 +273,9 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n log_rewards = batch.log_rewards # Clip rewards assert log_rewards.ndim == 1 - clip_log_R = torch.maximum(log_rewards, torch.tensor(self.illegal_action_logreward, device=dev)).float() + clip_log_R = torch.maximum( + log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev) + ).float() cond_info = batch.cond_info invalid_mask = 1 - batch.is_valid @@ -295,7 +288,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx]) else: fwd_cat, per_graph_out = model(batch, cond_info[batch_idx]) @@ -306,7 +299,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # Compute trajectory balance objective log_Z = model.logZ(cond_info)[:, 0] # Compute the log prob of each action in the trajectory - if self.correct_idempotent: + if self.cfg.do_correct_idempotent: # If we want to correct for idempotent actions, we need to sum probabilities # i.e. to compute P(s' | s) = sum_{a that lead to s'} P(a|s) # here we compute the indices of the graph that each action corresponds to, ip_lens @@ -322,7 +315,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # scatter(small number) = 0 on CUDA log_p_F = p.clamp(1e-30).log() - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # Now we repeat this but for the backward policy bck_ip_batch_idces = torch.arange(batch.bck_ip_lens.shape[0], device=dev).repeat_interleave( batch.bck_ip_lens @@ -335,14 +328,14 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n else: # Else just naively take the logprob of the actions we took log_p_F = fwd_cat.log_prob(batch.actions) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: log_p_B = bck_cat.log_prob(batch.bck_actions) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # If we're modeling P_B then trajectories are padded with a virtual terminal state sF, # zero-out the logP_F of those states log_p_F[final_graph_idx] = 0 - if self.is_doing_subTB: + if self.cfg.do_subtb: # Force the pad states' F(s) prediction to be R per_graph_out[final_graph_idx, 0] = clip_log_R @@ -365,7 +358,7 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n traj_log_p_F = scatter(log_p_F, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") traj_log_p_B = scatter(log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") - if self.is_doing_subTB: + if self.cfg.do_subtb: # SubTB interprets the per_graph_out predictions to predict the state flow F(s) traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens) # The position of the first graph of each trajectory @@ -384,9 +377,9 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # (thus the `numerator - 1`). Why 1? Intuition? denominator = denominator * (1 - invalid_mask) + invalid_mask * (numerator.detach() - 1) - if self.epsilon is not None: + if self.cfg.epsilon is not None: # Numerical stability epsilon - epsilon = torch.tensor([self.epsilon], device=dev).float() + epsilon = torch.tensor([self.cfg.epsilon], device=dev).float() numerator = torch.logaddexp(numerator, epsilon) denominator = torch.logaddexp(denominator, epsilon) if self.tb_loss_is_mae: @@ -409,17 +402,17 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n # undercount (by 2N) the contribution of each loss traj_losses = factor * traj_losses * num_trajs - if self.bootstrap_own_reward: + if self.cfg.bootstrap_own_reward: num_bootstrap = num_bootstrap or len(log_rewards) if self.reward_loss_is_mae: reward_losses = abs(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap]) else: reward_losses = (log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap]).pow(2) - reward_loss = reward_losses.mean() + reward_loss = reward_losses.mean() * self.cfg.reward_loss_multiplier else: reward_loss = 0 - loss = traj_losses.mean() + reward_loss * self.reward_loss_multiplier + loss = traj_losses.mean() + reward_loss info = { "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, @@ -509,7 +502,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths): for ep in range(traj_lengths.shape[0]): offset = cumul_lens[ep] T = int(traj_lengths[ep]) - if self.p_b_is_parameterized: + if self.cfg.do_parameterize_p_b: # The length of the trajectory is the padded length, reduce by 1 T -= 1 idces, dests = self._precomp[T - 1] diff --git a/src/gflownet/config.py b/src/gflownet/config.py new file mode 100644 index 00000000..8de0ebeb --- /dev/null +++ b/src/gflownet/config.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from typing import Optional + +from omegaconf import MISSING + +from gflownet.algo.config import AlgoConfig +from gflownet.data.config import ReplayConfig +from gflownet.models.config import ModelConfig +from gflownet.tasks.config import TasksConfig + + +@dataclass +class OptimizerConfig: + """Generic configuration for optimizers + + Attributes + ---------- + opt : str + The optimizer to use (either "adam" or "sgd") + learning_rate : float + The learning rate + lr_decay : float + The learning rate decay (in steps, f = 2 ** (-steps / self.cfg.opt.lr_decay)) + weight_decay : float + The L2 weight decay + momentum : float + The momentum parameter value + clip_grad_type : str + The type of gradient clipping to use (either "norm" or "value") + clip_grad_param : float + The parameter for gradient clipping + adam_eps : float + The epsilon parameter for Adam + """ + + opt: str = "adam" + learning_rate: float = 1e-4 + lr_decay: float = 20_000 + weight_decay: float = 1e-8 + momentum: float = 0.9 + clip_grad_type: str = "norm" + clip_grad_param: float = 10.0 + adam_eps: float = 1e-8 + + +@dataclass +class Config: + """Base configuration for training + + Attributes + ---------- + log_dir : str + The directory where to store logs, checkpoints, and samples. + seed : int + The random seed + validate_every : int + The number of training steps after which to validate the model + checkpoint_every : Optional[int] + The number of training steps after which to checkpoint the model + start_at_step : int + The training step to start at (default: 0) + num_final_gen_steps : Optional[int] + After training, the number of steps to generate graphs for + num_training_steps : int + The number of training steps + num_workers : int + The number of workers to use for creating minibatches (0 = no multiprocessing) + hostname : Optional[str] + The hostname of the machine on which the experiment is run + pickle_mp_messages : bool + Whether to pickle messages sent between processes (only relevant if num_workers > 0) + git_hash : Optional[str] + The git hash of the current commit + overwrite_existing_exp : bool + Whether to overwrite the contents of the log_dir if it already exists + """ + + log_dir: str = MISSING + seed: int = 0 + validate_every: int = 1000 + checkpoint_every: Optional[int] = None + start_at_step: int = 0 + num_final_gen_steps: Optional[int] = None + num_training_steps: int = 10_000 + num_workers: int = 0 + hostname: Optional[str] = None + pickle_mp_messages: bool = False + git_hash: Optional[str] = None + overwrite_existing_exp: bool = True + algo: AlgoConfig = AlgoConfig() + model: ModelConfig = ModelConfig() + opt: OptimizerConfig = OptimizerConfig() + replay: ReplayConfig = ReplayConfig() + task: TasksConfig = TasksConfig() diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py new file mode 100644 index 00000000..fab5d036 --- /dev/null +++ b/src/gflownet/data/config.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ReplayConfig: + """Replay buffer configuration + + Attributes + ---------- + use : bool + Whether to use a replay buffer + capacity : int + The capacity of the replay buffer + warmup : int + The number of samples to collect before starting to sample from the replay buffer + hindsight_ratio : float + The ratio of hindsight samples within a batch + """ + + use: bool = False + capacity: Optional[int] = None + warmup: Optional[int] = None + hindsight_ratio: float = 0 diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index 28dc7f78..c68ead98 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -40,3 +40,19 @@ def __len__(self): def __getitem__(self, idx): return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]]) + + +def convert_h5(): + # File obtained from + # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 + # (from http://quantum-machine.org/datasets/) + f = tarfile.TarFile("qm9.xyz.tar", "r") + labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] + all_mols = [] + for pt in f: + pt = f.extractfile(pt) # type: ignore + data = pt.read().decode().splitlines() # type: ignore + all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) + df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) + store = pd.HDFStore("qm9.h5", "w") + store["df"] = df diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 2f27de0a..541656fd 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -3,11 +3,13 @@ import numpy as np import torch +from gflownet.config import Config + class ReplayBuffer(object): - def __init__(self, capacity: int = 100000, warmup: int = 0, rng: np.random.Generator = None): - self.capacity = capacity - self.warmup = warmup + def __init__(self, cfg: Config, rng: np.random.Generator = None): + self.capacity = cfg.replay.capacity + self.warmup = cfg.replay.warmup assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" self.buffer: List[tuple] = [] diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 2de3f21a..b69591cc 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -28,11 +28,12 @@ def __init__( self, dataset: Dataset, model: nn.Module, - batch_size: int, ctx, algo, task, device, + batch_size: int = 1, + illegal_action_logreward: float = -50, ratio: float = 0.5, stream: bool = True, replay_buffer: ReplayBuffer = None, @@ -49,14 +50,21 @@ def __init__( model: nn.Module The model we sample from (must be on CUDA already or share_memory() must be called so that parameters are synchronized between each worker) + ctx: + The context for the environment, e.g. a MolBuildingEnvContext instance + algo: + The training algorithm, e.g. a TrajectoryBalance instance + task: GFNTask + A Task instance, e.g. a MakeRingsTask instance + device: torch.device + The device the model is on replay_buffer: ReplayBuffer The replay buffer for training on past data batch_size: int The number of trajectories, each trajectory will be comprised of many graphs, so this is _not_ the batch size in terms of the number of graphs (that will depend on the task) - algo: - The training algorithm, e.g. a TrajectoryBalance instance - task: ConditionalTask + illegal_action_logreward: float + The logreward for invalid trajectories ratio: float The ratio of offline trajectories in the batch. stream: bool @@ -67,14 +75,18 @@ def __init__( sample_cond_info: bool If True (default), then the dataset is a dataset of points used in offline training. If False, then the dataset is a dataset of preferences (e.g. used to validate the model) - + random_action_prob: float + The probability of taking a random action, passed to the graph sampler + init_train_iter: int + The initial training iteration, incremented and passed to task.sample_conditional_information """ self.data = dataset self.model = model self.replay_buffer = replay_buffer self.batch_size = batch_size - self.offline_batch_size = int(np.ceil(batch_size * ratio)) - self.online_batch_size = int(np.floor(batch_size * (1 - ratio))) + self.illegal_action_logreward = illegal_action_logreward + self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) + self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) self.ratio = ratio self.ctx = ctx self.algo = algo @@ -92,7 +104,7 @@ def __init__( # then "offline" now refers to cond info and online to x, so no duplication and we don't end # up with 2*batch_size accidentally if not sample_cond_info: - self.offline_batch_size = self.online_batch_size = batch_size + self.offline_batch_size = self.online_batch_size = self.batch_size # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we # don't want to initialize per-worker things just yet, such as where the log the worker writes @@ -222,7 +234,7 @@ def __iter__(self): # Compute scalar rewards from conditional information & flat rewards flat_rewards = torch.stack(flat_rewards) log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) - log_rewards[torch.logical_not(is_valid)] = self.algo.illegal_action_logreward + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward # Computes some metrics if not self.sample_cond_info: @@ -293,7 +305,7 @@ def __iter__(self): cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( cond_info, log_rewards, flat_rewards, hindsight_idxs ) - log_rewards[torch.logical_not(is_valid)] = self.algo.illegal_action_logreward + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward # Construct batch batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py new file mode 100644 index 00000000..833a2bba --- /dev/null +++ b/src/gflownet/models/config.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + + +@dataclass +class GraphTransformerConfig: + num_heads: int = 2 + ln_type: str = "pre" + num_mlp_layers: int = 0 + + +@dataclass +class ModelConfig: + """Generic configuration for models + + Attributes + ---------- + num_layers : int + The number of layers in the model + num_emb : int + The number of dimensions of the embedding + """ + + num_layers: int = 3 + num_emb: int = 128 + graph_transformer: GraphTransformerConfig = GraphTransformerConfig() diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 6dea00a2..097e7e91 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -6,6 +6,7 @@ import torch_geometric.nn as gnn from torch_geometric.utils import add_self_loops +from gflownet.config import Config from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType @@ -139,8 +140,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor): class GraphTransformerGFN(nn.Module): - """GraphTransformer class for a GFlowNet which outputs a GraphActionCategorical. Meant for atom-wise - generation. + """GraphTransformer class for a GFlowNet which outputs a GraphActionCategorical. Outputs logits corresponding to the action types used by the env_ctx argument. """ @@ -148,11 +148,7 @@ class GraphTransformerGFN(nn.Module): def __init__( self, env_ctx, - num_emb=64, - num_layers=3, - num_heads=2, - num_mlp_layers=0, - ln_type="pre", + cfg: Config, num_graph_out=1, do_bck=False, ): @@ -162,11 +158,12 @@ def __init__( x_dim=env_ctx.num_node_dim, e_dim=env_ctx.num_edge_dim, g_dim=env_ctx.num_cond_dim, - num_emb=num_emb, - num_layers=num_layers, - num_heads=num_heads, - ln_type=ln_type, + num_emb=cfg.model.num_emb, + num_layers=cfg.model.num_layers, + num_heads=cfg.model.graph_transformer.num_heads, + ln_type=cfg.model.graph_transformer.ln_type, ) + num_emb = cfg.model.num_emb num_final = num_emb num_glob_final = num_emb * 2 num_edge_feat = num_emb if env_ctx.edges_are_unordered else num_emb * 2 @@ -212,14 +209,14 @@ def __init__( mlps = {} for atype in chain(env_ctx.action_type_order, env_ctx.bck_action_type_order if do_bck else []): num_in, num_out = self._action_type_to_num_inputs_outputs[atype] - mlps[atype.cname] = mlp(num_in, num_emb, num_out, num_mlp_layers) + mlps[atype.cname] = mlp(num_in, num_emb, num_out, cfg.model.graph_transformer.num_mlp_layers) self.mlps = nn.ModuleDict(mlps) self.do_bck = do_bck if do_bck: self.bck_action_type_order = env_ctx.bck_action_type_order - self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, num_mlp_layers) + self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py new file mode 100644 index 00000000..bc0e31d0 --- /dev/null +++ b/src/gflownet/tasks/config.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass, field +from typing import Any, List, Optional, Tuple + + +@dataclass +class SEHTaskConfig: + """Config for the SEHTask + + Attributes + ---------- + + temperature_sample_dist : str + The distribution to sample the inverse temperature from. Can be one of: + - "uniform": uniform distribution + - "loguniform": log-uniform distribution + - "gamma": gamma distribution + - "constant": constant temperature + temperature_dist_params : List[Any] + The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. + num_thermometer_dim : int + The number of thermometer encoding dimensions to use. + """ + + # TODO: a proper class for temperature-conditional sampling + temperature_sample_dist: str = "uniform" + temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) + num_thermometer_dim: int = 32 + + +@dataclass +class SEHMOOTaskConfig: + """Config for the SEHMOOTask + + Attributes + ---------- + use_steer_thermometer : bool + Whether to use a thermometer encoding for the steering. + preference_type : Optional[str] + The preference sampling distribution, defaults to "dirichlet". + focus_type : Union[list, str, None] + The type of focus distribtuion used, see SEHMOOTask.setup_focus_regions. + focus_cosim : float + The cosine similarity threshold for the focus distribution. + focus_limit_coef : float + The smoothing coefficient for the focus reward. + focus_model_training_limits : Optional[Tuple[int, int]] + The training limits for the focus sampling model (if used). + focus_model_state_space_res : Optional[int] + The state space resolution for the focus sampling model (if used). + max_train_it : Optional[int] + The maximum number of training iterations for the focus sampling model (if used). + n_valid : int + The number of valid cond_info tensors to sample + n_valid_repeats : int + The number of times to repeat the valid cond_info tensors + objectives : List[str] + The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. + """ + + # TODO: a proper class for temperature-conditional sampling + temperature_sample_dist: str = "uniform" + temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) + num_thermometer_dim: int = 32 + use_steer_thermometer: bool = False + preference_type: Optional[str] = "dirichlet" + focus_type: Optional[str] = None + focus_dirs_listed: Optional[List[List[float]]] = None + focus_cosim: float = 0.0 + focus_limit_coef: float = 1.0 + focus_model_training_limits: Optional[Tuple[int, int]] = None + focus_model_state_space_res: Optional[int] = None + max_train_it: Optional[int] = None + n_valid: int = 15 + n_valid_repeats: int = 128 + objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "wt"]) + + +@dataclass +class QM9TaskConfig: + # TODO: a proper class for temperature-conditional sampling + temperature_sample_dist: str = "uniform" + temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) + num_thermometer_dim: int = 32 + h5_path = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py + + +@dataclass +class TasksConfig: + qm9: QM9TaskConfig = QM9TaskConfig() + seh: SEHTaskConfig = SEHTaskConfig() + seh_moo: SEHMOOTaskConfig = SEHMOOTaskConfig() diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 39dd8309..5405d867 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,7 +1,6 @@ -import ast import copy import os -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import numpy as np import scipy.stats as stats @@ -16,6 +15,7 @@ import gflownet.models.mxmnet as mxmnet from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext @@ -31,7 +31,7 @@ def __init__( self, dataset: Dataset, temperature_distribution: str, - temperature_parameters: Tuple[float, float], + temperature_parameters: List[float], num_thermometer_dim: int, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, @@ -90,14 +90,16 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: beta = self.rng.gamma(loc, scale, n).astype(np.float32) upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) elif self.temperature_sample_dist == "uniform": - beta = self.rng.uniform(*self.temperature_dist_params, n).astype(np.float32) + a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) + beta = self.rng.uniform(a, b, n).astype(np.float32) upper_bound = self.temperature_dist_params[1] elif self.temperature_sample_dist == "loguniform": low, high = np.log(self.temperature_dist_params) beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) upper_bound = self.temperature_dist_params[1] elif self.temperature_sample_dist == "beta": - beta = self.rng.beta(*self.temperature_dist_params, n).astype(np.float32) + a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) + beta = self.rng.beta(a, b, n).astype(np.float32) upper_bound = 1 beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) @@ -127,79 +129,73 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class QM9GapTrainer(GFNTrainer): - def default_hps(self) -> Dict[str, Any]: - return { - "bootstrap_own_reward": False, - "learning_rate": 1e-4, - "global_batch_size": 64, - "num_emb": 128, - "num_layers": 4, - "tb_epsilon": None, - "illegal_action_logreward": -75, - "reward_loss_multiplier": 1, - "temperature_sample_dist": "uniform", - "temperature_dist_params": (0.5, 32.0), - "weight_decay": 1e-8, - "num_data_loader_workers": 8, - "momentum": 0.9, - "adam_eps": 1e-8, - "lr_decay": 20000, - "Z_lr_decay": 20000, - "clip_grad_type": "norm", - "clip_grad_param": 10, - "random_action_prob": 0.001, - "sampling_tau": 0.0, - "num_thermometer_dim": 32, - } + def set_default_hps(self, cfg: Config): + cfg.num_workers = 8 + cfg.num_training_steps = 100000 + cfg.opt.learning_rate = 1e-4 + cfg.opt.weight_decay = 1e-8 + cfg.opt.momentum = 0.9 + cfg.opt.adam_eps = 1e-8 + cfg.opt.lr_decay = 20000 + cfg.opt.clip_grad_type = "norm" + cfg.opt.clip_grad_param = 10 + cfg.algo.global_batch_size = 64 + cfg.algo.train_random_action_prob = 0.001 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.sampling_tau = 0.0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + cfg.task.qm9.temperature_sample_dist = "uniform" + cfg.task.qm9.temperature_dist_params = [0.5, 32.0] + cfg.task.qm9.num_thermometer_dim = 32 def setup(self): - hps = self.hps RDLogger.DisableLog("rdApp.*") self.rng = np.random.default_rng(142857) self.env = GraphBuildingEnv() self.ctx = MolBuildingEnvContext(["H", "C", "N", "F", "O"], num_cond_dim=32) - self.training_data = QM9Dataset(hps["qm9_h5_path"], train=True, target="gap") - self.test_data = QM9Dataset(hps["qm9_h5_path"], train=False, target="gap") + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") - model = GraphTransformerGFN(self.ctx, num_emb=hps["num_emb"], num_layers=hps["num_layers"]) + model = GraphTransformerGFN(self.ctx, self.cfg) self.model = model # Separate Z parameters from non-Z to allow for LR decay on the former Z_params = list(model.logZ.parameters()) non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] self.opt = torch.optim.Adam( non_Z_params, - hps["learning_rate"], - (hps["momentum"], 0.999), - weight_decay=hps["weight_decay"], - eps=hps["adam_eps"], + self.cfg.opt.learning_rate, + (self.cfg.opt.momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + self.opt_Z = torch.optim.Adam(Z_params, self.cfg.opt.learning_rate, (0.9, 0.999)) + self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay) ) - self.opt_Z = torch.optim.Adam(Z_params, hps["learning_rate"], (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / hps["lr_decay"])) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(self.opt_Z, lambda steps: 2 ** (-steps / hps["Z_lr_decay"])) - self.sampling_tau = hps["sampling_tau"] + self.sampling_tau = self.cfg.algo.sampling_tau if self.sampling_tau > 0: self.sampling_model = copy.deepcopy(model) else: self.sampling_model = self.model - eps = hps["tb_epsilon"] - hps["tb_epsilon"] = ast.literal_eval(eps) if isinstance(eps, str) else eps - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, hps, max_nodes=9) + self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg) self.task = QM9GapTask( dataset=self.training_data, - temperature_distribution=hps["temperature_sample_dist"], - temperature_parameters=hps["temperature_dist_params"], - num_thermometer_dim=hps["num_thermometer_dim"], + temperature_distribution=self.cfg.task.qm9.temperature_sample_dist, + temperature_parameters=self.cfg.task.qm9.temperature_dist_params, + num_thermometer_dim=self.cfg.task.qm9.num_thermometer_dim, wrap_model=self._wrap_for_mp, ) - self.mb_size = hps["global_batch_size"] - self.clip_grad_param = hps["clip_grad_param"] + self.mb_size = self.cfg.algo.global_batch_size + self.clip_grad_param = self.cfg.opt.clip_grad_param self.clip_grad_callback = { - "value": (lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param)), - "norm": (lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param)), - "none": (lambda x: None), - }[hps["clip_grad_type"]] + "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param), + "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param), + "none": lambda x: None, + }[self.cfg.opt.clip_grad_type] def step(self, loss: Tensor): loss.backward() diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 839be788..10fa142c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,11 +1,9 @@ -import ast import copy -import json import os import pathlib import shutil import socket -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import git import numpy as np @@ -13,13 +11,17 @@ import torch import torch.nn as nn import torch_geometric.data as gd +from omegaconf import OmegaConf from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset +from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.soft_q_learning import SoftQLearning from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv @@ -36,15 +38,13 @@ class SEHTask(GFNTask): The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. This setup essentially reproduces the results of the Trajectory Balance paper when using the TB - objective, or of the original paper when using Flow Matching (TODO: port to this repo). + objective, or of the original paper when using Flow Matching. """ def __init__( self, dataset: Dataset, - temperature_distribution: str, - temperature_parameters: Tuple[float, float], - num_thermometer_dim: int, + cfg: Config, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): @@ -52,9 +52,9 @@ def __init__( self.rng = rng self.models = self._load_task_models() self.dataset = dataset - self.temperature_sample_dist = temperature_distribution - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim + self.temperature_sample_dist = cfg.task.seh.temperature_sample_dist + self.temperature_dist_params = cfg.task.seh.temperature_dist_params + self.num_thermometer_dim = cfg.task.seh.num_thermometer_dim def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -79,14 +79,16 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten beta = self.rng.gamma(loc, scale, n).astype(np.float32) upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) elif self.temperature_sample_dist == "uniform": - beta = self.rng.uniform(*self.temperature_dist_params, n).astype(np.float32) + a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) + beta = self.rng.uniform(a, b, n).astype(np.float32) upper_bound = self.temperature_dist_params[1] elif self.temperature_sample_dist == "loguniform": low, high = np.log(self.temperature_dist_params) beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) upper_bound = self.temperature_dist_params[1] elif self.temperature_sample_dist == "beta": - beta = self.rng.beta(*self.temperature_dist_params, n).astype(np.float32) + a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) + beta = self.rng.beta(a, b, n).astype(np.float32) upper_bound = 1 beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) @@ -116,71 +118,74 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class SEHFragTrainer(GFNTrainer): - def default_hps(self) -> Dict[str, Any]: - return { - "hostname": socket.gethostname(), - "bootstrap_own_reward": False, - "learning_rate": 1e-4, - "Z_learning_rate": 1e-3, - "global_batch_size": 64, - "num_emb": 128, - "num_layers": 4, - "tb_epsilon": None, - "tb_p_b_is_parameterized": False, - "illegal_action_logreward": -75, - "reward_loss_multiplier": 1, - "temperature_sample_dist": "uniform", - "temperature_dist_params": (0.5, 32.0), - "weight_decay": 1e-8, - "num_data_loader_workers": 8, - "momentum": 0.9, - "adam_eps": 1e-8, - "lr_decay": 20000, - "Z_lr_decay": 50000, - "clip_grad_type": "norm", - "clip_grad_param": 10, - "random_action_prob": 0.01, - "valid_random_action_prob": 0.0, - "sampling_tau": 0.0, - "max_nodes": 9, - "num_thermometer_dim": 32, - "use_replay_buffer": False, - "replay_buffer_size": 10000, - "replay_buffer_warmup": 10000, - "mp_pickle_messages": False, - "algo": "TB", - } + def set_default_hps(self, cfg: Config): + cfg.hostname = socket.gethostname() + cfg.pickle_mp_messages = False + cfg.num_workers = 8 + cfg.opt.learning_rate = 1e-4 + cfg.opt.weight_decay = 1e-8 + cfg.opt.momentum = 0.9 + cfg.opt.adam_eps = 1e-8 + cfg.opt.lr_decay = 20_000 + cfg.opt.clip_grad_type = "norm" + cfg.opt.clip_grad_param = 10 + cfg.algo.global_batch_size = 64 + cfg.algo.offline_ratio = 0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + + cfg.algo.method = "TB" + cfg.algo.max_nodes = 9 + cfg.algo.sampling_tau = 0.9 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.train_random_action_prob = 0.0 + cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.tb.epsilon = None + cfg.algo.tb.bootstrap_own_reward = False + cfg.algo.tb.Z_learning_rate = 1e-3 + cfg.algo.tb.Z_lr_decay = 50_000 + cfg.algo.tb.do_parameterize_p_b = False + + cfg.replay.use = False + cfg.replay.capacity = 10_000 + cfg.replay.warmup = 1_000 def setup_algo(self): - algo = self.hps.get("algo", "TB") + algo = self.cfg.algo.method if algo == "TB": algo = TrajectoryBalance elif algo == "FM": algo = FlowMatching + elif algo == "A2C": + algo = A2C + elif algo == "SQL": + algo = SoftQLearning else: raise ValueError(algo) - self.algo = algo(self.env, self.ctx, self.rng, self.hps, max_nodes=self.hps["max_nodes"]) + self.algo = algo(self.env, self.ctx, self.rng, self.cfg) def setup_task(self): self.task = SEHTask( dataset=self.training_data, - temperature_distribution=self.hps["temperature_sample_dist"], - temperature_parameters=self.hps["temperature_dist_params"], + cfg=self.cfg, rng=self.rng, - num_thermometer_dim=self.hps["num_thermometer_dim"], wrap_model=self._wrap_for_mp, ) def setup_model(self): - self.model = GraphTransformerGFN(self.ctx, num_emb=self.hps["num_emb"], num_layers=self.hps["num_layers"]) + model = GraphTransformerGFN( + self.ctx, + self.cfg, + do_bck=self.cfg.algo.tb.do_parameterize_p_b, + ) + self.model = model def setup_env_context(self): self.ctx = FragMolBuildingEnvContext( - max_frags=self.hps["max_nodes"], num_cond_dim=self.hps["num_thermometer_dim"] + max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.cfg.task.seh.num_thermometer_dim ) def setup(self): - hps = self.hps RDLogger.DisableLog("rdApp.*") self.rng = np.random.default_rng(142857) self.env = GraphBuildingEnv() @@ -188,11 +193,7 @@ def setup(self): self.test_data = [] self.offline_ratio = 0 self.valid_offline_ratio = 0 - self.replay_buffer = ( - ReplayBuffer(self.hps["replay_buffer_size"], self.hps["replay_buffer_warmup"], self.rng) - if self.hps["use_replay_buffer"] - else None - ) + self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None self.setup_env_context() self.setup_algo() self.setup_task() @@ -203,40 +204,40 @@ def setup(self): non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] self.opt = torch.optim.Adam( non_Z_params, - hps["learning_rate"], - (hps["momentum"], 0.999), - weight_decay=hps["weight_decay"], - eps=hps["adam_eps"], + self.cfg.opt.learning_rate, + (self.cfg.opt.momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) + self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) ) - self.opt_Z = torch.optim.Adam(Z_params, hps["Z_learning_rate"], (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / hps["lr_decay"])) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(self.opt_Z, lambda steps: 2 ** (-steps / hps["Z_lr_decay"])) - self.sampling_tau = hps["sampling_tau"] + self.sampling_tau = self.cfg.algo.sampling_tau if self.sampling_tau > 0: self.sampling_model = copy.deepcopy(self.model) else: self.sampling_model = self.model - eps = hps["tb_epsilon"] - hps["tb_epsilon"] = ast.literal_eval(eps) if isinstance(eps, str) else eps - self.mb_size = hps["global_batch_size"] - self.clip_grad_param = hps["clip_grad_param"] + self.mb_size = self.cfg.algo.global_batch_size self.clip_grad_callback = { - "value": (lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param)), - "norm": (lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param)), - "none": (lambda x: None), - }[hps["clip_grad_type"]] + "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), + "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), + "none": lambda x: None, + }[self.cfg.opt.clip_grad_type] # saving hyperparameters git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] - self.hps["gflownet_git_hash"] = git_hash + self.cfg.git_hash = git_hash - os.makedirs(self.hps["log_dir"], exist_ok=True) - fmt_hps = "\n".join([f"{f'{k}':40}:\t{f'({type(v).__name__})':10}\t{v}" for k, v in sorted(self.hps.items())]) - print(f"\n\nHyperparameters:\n{'-'*50}\n{fmt_hps}\n{'-'*50}\n\n") - with open(pathlib.Path(self.hps["log_dir"]) / "hps.json", "w") as f: - json.dump(self.hps, f) + os.makedirs(self.cfg.log_dir, exist_ok=True) + print("\n\nHyperparameters:\n") + yaml = OmegaConf.to_yaml(self.cfg) + print(yaml) + with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + f.write(yaml) def step(self, loss: Tensor): loss.backward() @@ -258,13 +259,11 @@ def main(): hps = { "log_dir": "./logs/debug_run", "overwrite_existing_exp": True, - "qm9_h5_path": "/data/chem/qm9/qm9.h5", "num_training_steps": 10_000, - "validate_every": 1, - "lr_decay": 20000, - "sampling_tau": 0.99, - "num_data_loader_workers": 8, - "temperature_dist_params": (0.0, 64.0), + "num_workers": 8, + "opt.lr_decay": 20000, + "algo.sampling_tau": 0.99, + "task.seh.temperature_dist_params": (0.0, 64.0), } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 1373d3d8..c515bc6e 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,9 +1,8 @@ -import json import os import pathlib import shutil from copy import deepcopy -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -15,14 +14,11 @@ from torch.distributions.dirichlet import Dirichlet from torch.utils.data import Dataset -from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce -from gflownet.algo.soft_q_learning import SoftQLearning -from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow -from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.train import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore @@ -43,45 +39,69 @@ class SEHMOOTask(SEHTask): def __init__( self, - objectives: List[str], dataset: Dataset, - temperature_sample_dist: str, - temperature_parameters: Tuple[float, float], - num_thermometer_dim: int, - use_steer_thermometer: bool = False, - preference_type: str = None, - focus_type: Union[list, str] = None, - focus_cosim: float = 0.0, - focus_limit_coef: float = 1.0, - fixed_focus_dirs: torch.Tensor = None, - illegal_action_logreward: float = None, - focus_model: FocusModel = None, - focus_model_training_limits: Tuple[int, int] = None, - max_train_it: int = None, + cfg: Config, rng: np.random.Generator = None, + focus_model: Optional[FocusModel] = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): self._wrap_model = wrap_model + self.cfg = cfg + mcfg = self.cfg.task.seh_moo self.rng = rng self.models = self._load_task_models() - self.objectives = objectives + self.objectives = cfg.task.seh_moo.objectives self.dataset = dataset - self.temperature_sample_dist = temperature_sample_dist - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim - self.use_steer_thermometer = use_steer_thermometer - self.preference_type = preference_type + self.temperature_sample_dist = mcfg.temperature_sample_dist + self.temperature_dist_params = mcfg.temperature_dist_params + self.num_thermometer_dim = mcfg.num_thermometer_dim + self.use_steer_thermometer = mcfg.use_steer_thermometer + self.preference_type = mcfg.preference_type self.seeded_preference = None self.experimental_dirichlet = False - self.focus_type = focus_type - self.focus_cosim = focus_cosim - self.focus_limit_coef = focus_limit_coef - self.fixed_focus_dirs = fixed_focus_dirs - self.illegal_action_logreward = illegal_action_logreward + self.focus_type = mcfg.focus_type + self.focus_cosim = mcfg.focus_cosim + self.focus_limit_coef = mcfg.focus_limit_coef self.focus_model = focus_model - self.focus_model_training_limits = focus_model_training_limits - self.max_train_it = max_train_it - assert set(objectives) <= {"seh", "qed", "sa", "mw"} and len(objectives) == len(set(objectives)) + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.focus_model_training_limits = mcfg.focus_model_training_limits + self.max_train_it = mcfg.max_train_it + self.setup_focus_regions() + assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) + + def setup_focus_regions(self): + mcfg = self.cfg.task.seh_moo + n_valid = mcfg.n_valid + n_obj = len(self.objectives) + # focus regions + if mcfg.focus_type is None: + valid_focus_dirs = np.zeros((n_valid, n_obj)) + self.fixed_focus_dirs = valid_focus_dirs + elif mcfg.focus_type == "centered": + valid_focus_dirs = np.ones((n_valid, n_obj)) + self.fixed_focus_dirs = valid_focus_dirs + elif mcfg.focus_type == "partitioned": + valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") + self.fixed_focus_dirs = valid_focus_dirs + elif mcfg.focus_type in ["dirichlet", "learned-gfn"]: + valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") + self.fixed_focus_dirs = None + elif mcfg.focus_type in ["hyperspherical", "learned-tabular"]: + valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") + self.fixed_focus_dirs = None + elif mcfg.focus_type == "listed": + if len(mcfg.focus_type) == 1: + valid_focus_dirs = np.array([mcfg.focus_dirs_listed[0]] * n_valid) + self.fixed_focus_dirs = valid_focus_dirs + else: + valid_focus_dirs = np.array(mcfg.focus_dirs_listed) + self.fixed_focus_dirs = valid_focus_dirs + else: + raise NotImplementedError( + f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " + f"focus_type, but here: {mcfg.focus_type}" + ) + self.valid_focus_dirs = valid_focus_dirs def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y)) @@ -262,44 +282,29 @@ def safe(f, x, default): class SEHMOOFragTrainer(SEHFragTrainer): - def default_hps(self) -> Dict[str, Any]: - return { - **super().default_hps(), - "use_fixed_weight": False, - "objectives": ["seh", "qed", "sa", "mw"], - "sampling_tau": 0.95, - "valid_sample_cond_info": False, - "n_valid": 15, - "n_valid_repeats": 128, - "preference_type": "dirichlet", - "focus_type": None, - "focus_cosim": None, - "focus_limit_coef": 1.0, - "hindsight_ratio": 0.0, - "focus_model_training_limits": None, - "focus_model_state_space_res": None, - "use_steer_thermometer": False, - } + task: SEHMOOTask + + def set_default_hps(self, cfg: Config): + super().set_default_hps(cfg) + cfg.algo.sampling_tau = 0.95 + cfg.algo.valid_sample_cond_info = False def setup_algo(self): - hps = self.hps - if hps["algo"] == "TB": - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "SQL": - self.algo = SoftQLearning(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "A2C": - self.algo = A2C(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "MOREINFORCE": - self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - elif hps["algo"] == "MOQL": - self.algo = EnvelopeQLearning(self.env, self.ctx, self.rng, hps, max_nodes=self.hps["max_nodes"]) - - if hps["focus_type"] is not None and "learned" in hps["focus_type"]: - if hps["focus_type"] == "learned-tabular": + algo = self.cfg.algo.method + if algo == "MOREINFORCE": + self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, self.cfg) + elif algo == "MOQL": + self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) + else: + super().setup_algo() + + focus_type = self.cfg.task.seh_moo.focus_type + if focus_type is not None and "learned" in focus_type: + if focus_type == "learned-tabular": self.focus_model = TabularFocusModel( device=self.device, - n_objectives=len(hps["objectives"]), - state_space_res=hps["focus_model_state_space_res"], + n_objectives=len(self.cfg.task.seh_moo.objectives), + state_space_res=self.cfg.task.seh_moo.focus_model_state_space_res, ) else: raise NotImplementedError("Unknown focus model type {self.focus_type}") @@ -308,143 +313,95 @@ def setup_algo(self): def setup_task(self): self.task = SEHMOOTask( - objectives=self.hps["objectives"], dataset=self.training_data, - temperature_sample_dist=self.hps["temperature_sample_dist"], - temperature_parameters=self.hps["temperature_dist_params"], - num_thermometer_dim=self.hps["num_thermometer_dim"], - use_steer_thermometer=self.hps["use_steer_thermometer"], - preference_type=self.hps["preference_type"], - focus_type=self.hps["focus_type"], - focus_cosim=self.hps["focus_cosim"], - focus_limit_coef=self.hps["focus_limit_coef"], - illegal_action_logreward=self.hps["illegal_action_logreward"], + cfg=self.cfg, focus_model=self.focus_model, - focus_model_training_limits=self.hps["focus_model_training_limits"], - max_train_it=self.hps["num_training_steps"], rng=self.rng, wrap_model=self._wrap_for_mp, ) - def setup_model(self): - if self.hps["algo"] == "MOQL": - model = GraphTransformerFragEnvelopeQL( - self.ctx, - num_emb=self.hps["num_emb"], - num_layers=self.hps["num_layers"], - num_objectives=len(self.hps["objectives"]), - ) + def setup_env_context(self): + if self.cfg.task.seh_moo.use_steer_thermometer: + ncd = self.cfg.task.seh_moo.num_thermometer_dim * (1 + 2 * len(self.cfg.task.seh_moo.objectives)) else: - model = GraphTransformerGFN( + # 1 for prefs and 1 for focus region + ncd = self.cfg.task.seh_moo.num_thermometer_dim + 2 * len(self.cfg.task.seh_moo.objectives) + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=ncd) + + def setup_model(self): + if self.cfg.algo.method == "MOQL": + self.model = GraphTransformerFragEnvelopeQL( self.ctx, - num_emb=self.hps["num_emb"], - num_layers=self.hps["num_layers"], - do_bck=self.hps["tb_p_b_is_parameterized"], + num_emb=self.cfg.model.num_emb, + num_layers=self.cfg.model.num_layers, + num_heads=self.cfg.model.graph_transformer.num_heads, + num_objectives=len(self.cfg.task.seh_moo.objectives), ) - - if self.hps["algo"] in ["A2C", "MOQL"]: - model.do_mask = False - self.model = model - - def setup_env_context(self): - if self.hps.get("use_steer_thermometer", False): - ncd = self.hps["num_thermometer_dim"] * (1 + 2 * len(self.hps["objectives"])) else: - ncd = self.hps["num_thermometer_dim"] + 2 * len( - self.hps["objectives"] - ) # 1 for prefs and 1 for focus region - self.ctx = FragMolBuildingEnvContext(max_frags=self.hps["max_nodes"], num_cond_dim=ncd) + super().setup_model() def setup(self): super().setup() self.sampling_hooks.append( MultiObjectiveStatsHook( 256, - self.hps["log_dir"], + self.cfg.log_dir, compute_igd=True, compute_pc_entropy=True, - compute_focus_accuracy=True if self.hps["focus_type"] is not None else False, - focus_cosim=self.hps["focus_cosim"], + compute_focus_accuracy=True if self.cfg.task.seh_moo.focus_type is not None else False, + focus_cosim=self.cfg.task.seh_moo.focus_cosim, ) ) # instantiate preference and focus conditioning vectors for validation - n_obj = len(self.hps["objectives"]) - n_valid = self.hps["n_valid"] + tcfg = self.cfg.task.seh_moo + n_obj = len(tcfg.objectives) # making sure hyperparameters for preferences and focus regions are consistent if not ( - self.hps["focus_type"] is None - or self.hps["focus_type"] == "centered" - or (type(self.hps["focus_type"]) is list and len(self.hps["focus_type"]) == 1) + tcfg.focus_type is None + or tcfg.focus_type == "centered" + or (type(tcfg.focus_type) is list and len(tcfg.focus_type) == 1) ): - assert self.hps["preference_type"] is None, ( - f"Cannot use preferences with multiple focus regions, here focus_type={self.hps['focus_type']} " - f"and preference_type={self.hps['preference_type']}" + assert tcfg.preference_type is None, ( + f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} " + f"and preference_type={tcfg.preference_type}" ) - if type(self.hps["focus_type"]) is list and len(self.hps["focus_type"]) > 1: - n_valid = len(self.hps["focus_type"]) + if type(tcfg.focus_type) is list and len(tcfg.focus_type) > 1: + n_valid = len(tcfg.focus_type) + else: + n_valid = tcfg.n_valid # preference vectors - if self.hps["preference_type"] is None: + if tcfg.preference_type is None: valid_preferences = np.ones((n_valid, n_obj)) - elif self.hps["preference_type"] == "dirichlet": + elif tcfg.preference_type == "dirichlet": valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") - elif self.hps["preference_type"] == "seeded_single": - seeded_prefs = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet([1] * n_obj, n_valid) + elif tcfg.preference_type == "seeded_single": + seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) valid_preferences = seeded_prefs[0].reshape((1, n_obj)) self.task.seeded_preference = valid_preferences[0] - elif self.hps["preference_type"] == "seeded_many": - valid_preferences = np.random.default_rng(142857 + int(self.hps["seed"])).dirichlet([1] * n_obj, n_valid) + elif tcfg.preference_type == "seeded_many": + valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) else: - raise NotImplementedError(f"Unknown preference type {self.hps['preference_type']}") + raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}") - # focus regions - if self.hps["focus_type"] is None: - valid_focus_dirs = np.zeros((n_valid, n_obj)) - self.task.fixed_focus_dirs = valid_focus_dirs - elif self.hps["focus_type"] == "centered": - valid_focus_dirs = np.ones((n_valid, n_obj)) - self.task.fixed_focus_dirs = valid_focus_dirs - elif self.hps["focus_type"] == "partitioned": - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") - self.task.fixed_focus_dirs = valid_focus_dirs - elif self.hps["focus_type"] in ["dirichlet", "learned-gfn"]: - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") - self.task.fixed_focus_dirs = None - elif self.hps["focus_type"] in ["hyperspherical", "learned-tabular"]: - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") - self.task.fixed_focus_dirs = None - elif type(self.hps["focus_type"]) is list: - if len(self.hps["focus_type"]) == 1: - valid_focus_dirs = np.array([self.hps["focus_type"][0]] * n_valid) - self.task.fixed_focus_dirs = valid_focus_dirs - else: - valid_focus_dirs = np.array(self.hps["focus_type"]) - self.task.fixed_focus_dirs = valid_focus_dirs - else: - raise NotImplementedError( - f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " - f"focus_type, but here: {self.hps['focus_type']}" - ) - - self.hps["fixed_focus_dirs"] = ( - np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None - ) - with open(pathlib.Path(self.hps["log_dir"]) / "hps.json", "w") as f: - json.dump(self.hps, f) - assert valid_focus_dirs.shape == ( + # TODO: this was previously reported, would be nice to serialize it + # hps["fixed_focus_dirs"] = ( + # np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None + # ) + assert self.task.valid_focus_dirs.shape == ( n_valid, n_obj, - ), f"Invalid shape for valid_preferences, {valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + ), f"Invalid shape for valid_preferences, {self.task.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" # combine preferences and focus directions (fixed focus cosim) since they could be used together (not either/or) # TODO: this relies on positional assumptions, should have something cleaner - valid_cond_vector = np.concatenate([valid_preferences, valid_focus_dirs], axis=1) + valid_cond_vector = np.concatenate([valid_preferences, self.task.valid_focus_dirs], axis=1) - self._top_k_hook = TopKHook(10, self.hps["n_valid_repeats"], len(valid_cond_vector)) - self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.hps["n_valid_repeats"]) + self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, len(valid_cond_vector)) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -464,8 +421,8 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: - focus_model_training_limits = self.hps["focus_model_training_limits"] - max_train_it = self.hps["num_training_steps"] + focus_model_training_limits = self.cfg.task.seh_moo.focus_model_training_limits + max_train_it = self.cfg.num_training_steps if ( self.focus_model is not None and train_it >= focus_model_training_limits[0] * max_train_it @@ -476,7 +433,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: def _save_state(self, it): if self.focus_model is not None: - self.focus_model.save(pathlib.Path(self.hps["log_dir"])) + self.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) @@ -497,39 +454,40 @@ def main(): """Example of how this model can be run outside of Determined""" hps = { "log_dir": "./logs/debug_run", + "pickle_mp_messages": True, "overwrite_existing_exp": True, "seed": 0, - "global_batch_size": 64, "num_training_steps": 20_000, "num_final_gen_steps": 500, - "validate_every": 10, - "num_layers": 2, - "num_emb": 256, - "algo": "TB", - "objectives": ["seh", "qed"], - "learning_rate": 1e-4, - "Z_learning_rate": 1e-3, - "lr_decay": 20000, - "Z_lr_decay": 50000, - "sampling_tau": 0.95, - "random_action_prob": 0.1, - "num_data_loader_workers": 0, - "temperature_sample_dist": "constant", - "temperature_dist_params": 60.0, - "num_thermometer_dim": 32, - "use_steer_thermometer": False, - "preference_type": None, - "focus_type": "learned-tabular", - "focus_cosim": 0.98, - "focus_limit_coef": 1e-1, - "n_valid": 15, - "n_valid_repeats": 128, - "use_replay_buffer": True, - "replay_buffer_warmup": 0, - "hindsight_ratio": 0.3, - "mp_pickle_messages": True, - "focus_model_training_limits": [0.25, 0.75], - "focus_model_state_space_res": 10, + "validate_every": 500, + "num_workers": 0, + "algo.global_batch_size": 64, + "algo.method": "TB", + "model.num_layers": 2, + "model.num_emb": 256, + "task.seh_moo.objectives": ["seh", "qed"], + "opt.learning_rate": 1e-4, + "algo.tb.Z_learning_rate": 1e-3, + "opt.lr_decay": 20000, + "algo.tb.Z_lr_decay": 50000, + "algo.sampling_tau": 0.95, + "algo.train_random_action_prob": 0.01, + "task.seh_moo.temperature_sample_dist": "constant", + "task.seh_moo.temperature_dist_params": 60.0, + "task.seh_moo.num_thermometer_dim": 32, + "task.seh_moo.use_steer_thermometer": False, + "task.seh_moo.preference_type": None, + "task.seh_moo.focus_type": "learned-tabular", + "task.seh_moo.focus_cosim": 0.98, + "task.seh_moo.focus_limit_coef": 1e-1, + "task.seh_moo.n_valid": 15, + "task.seh_moo.n_valid_repeats": 128, + "replay.use": True, + "replay.warmup": 1000, + "replay.hindsight_ratio": 0.3, + "task.seh_moo.focus_model_training_limits": [0.25, 0.75], + "task.seh_moo.focus_model_state_space_res": 30, + "task.seh_moo.max_train_it": 20_000, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: diff --git a/src/gflownet/train.py b/src/gflownet/train.py index 324e96bf..bb9c6ef3 100644 --- a/src/gflownet/train.py +++ b/src/gflownet/train.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd +from omegaconf import OmegaConf from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -16,6 +17,8 @@ from gflownet.utils.misc import create_logger from gflownet.utils.multiprocessing_proxy import mp_object_wrapper +from .config import Config + # This type represents an unprocessed list of reward signals/conditioning information FlatRewards = NewType("FlatRewards", Tensor) # type: ignore @@ -29,6 +32,7 @@ def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: """Computes the loss for a batch of data, and proves logging informations + Parameters ---------- model: nn.Module @@ -37,6 +41,7 @@ def compute_batch_losses( A batch of graphs num_bootstrap: Optional[int] The number of trajectories with reward targets in the batch (if applicable). + Returns ------- loss: Tensor @@ -89,7 +94,7 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): Parameters ---------- hps: Dict[str, Any] - A dictionary of hyperparameters. These override default values obtained by the `default_hps` method. + A dictionary of hyperparameters. These override default values obtained by the `set_default_hps` method. device: torch.device The torch device of the main worker. """ @@ -100,21 +105,23 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): # `sampling_model` is used by the data workers to sample new objects from the model. Can be # the same as `model`. self.sampling_model: nn.Module - self.replay_buffer: ReplayBuffer + self.replay_buffer: Optional[ReplayBuffer] self.mb_size: int self.env: GraphBuildingEnv self.ctx: GraphBuildingEnvContext self.task: GFNTask self.algo: GFNAlgorithm - # Override default hyperparameters with the constructor arguments - self.hps = {**self.default_hps(), **hps} + # There are three sources of config values + # - The default values specified in individual config classes + # - The default values specified in the `default_hps` method, typically what is defined by a task + # - The values passed in the constructor, typically what is called by the user + # The final config is obtained by merging the three sources + self.cfg: Config = OmegaConf.structured(Config()) + self.set_default_hps(self.cfg) + OmegaConf.merge(self.cfg, hps) + self.device = device - # The number of processes spawned to sample object and do CPU work - self.num_workers: int = self.hps.get("num_data_loader_workers", 0) - # The ratio of samples drawn from `self.training_data` during training. The rest is drawn from - # `self.sampling_model`. - self.offline_ratio = self.hps.get("offline_ratio", 0.5) # idem, but from `self.test_data` during validation. self.valid_offline_ratio = 1 # Print the loss every `self.print_every` iterations @@ -124,12 +131,10 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): self.valid_sampling_hooks: List[Callable] = [] # Will check if parameters are finite at every iteration (can be costly) self._validate_parameters = False - # Pickle messages to reduce load on shared memory (conversely, increases load on CPU) - self.pickle_messages = hps.get("mp_pickle_messages", False) self.setup() - def default_hps(self) -> Dict[str, Any]: + def set_default_hps(self, base: Config): raise NotImplementedError() def setup(self): @@ -143,12 +148,12 @@ def _wrap_for_mp(self, obj, send_to_device=False): data worker process (only if the number of workers is non-zero).""" if send_to_device: obj.to(self.device) - if self.num_workers > 0 and obj is not None: + if self.cfg.num_workers > 0 and obj is not None: placeholder = mp_object_wrapper( obj, - self.num_workers, + self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical), - pickle_messages=self.pickle_messages, + pickle_messages=self.cfg.pickle_mp_messages, ) return placeholder, torch.device("cpu") else: @@ -163,27 +168,28 @@ def build_training_data_loader(self) -> DataLoader: iterator = SamplingIterator( self.training_data, model, - self.mb_size, self.ctx, self.algo, self.task, dev, + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, replay_buffer=replay_buffer, - ratio=self.offline_ratio, - log_dir=os.path.join(self.hps["log_dir"], "train"), - random_action_prob=self.hps.get("random_action_prob", 0.0), - hindsight_ratio=self.hps.get("hindsight_ratio", 0.0), + ratio=self.cfg.algo.offline_ratio, + log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), + random_action_prob=self.cfg.algo.train_random_action_prob, + hindsight_ratio=self.cfg.replay.hindsight_ratio, ) for hook in self.sampling_hooks: iterator.add_log_hook(hook) return torch.utils.data.DataLoader( iterator, batch_size=None, - num_workers=self.num_workers, - persistent_workers=self.num_workers > 0, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, # The 2 here is an odd quirk of torch 1.10, it is fixed and # replaced by None in torch 2. - prefetch_factor=1 if self.num_workers else 2, + prefetch_factor=1 if self.cfg.num_workers else 2, ) def build_validation_data_loader(self) -> DataLoader: @@ -191,25 +197,26 @@ def build_validation_data_loader(self) -> DataLoader: iterator = SamplingIterator( self.test_data, model, - self.mb_size, self.ctx, self.algo, self.task, dev, + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, ratio=self.valid_offline_ratio, - log_dir=os.path.join(self.hps["log_dir"], "valid"), - sample_cond_info=self.hps.get("valid_sample_cond_info", True), + log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), + sample_cond_info=self.cfg.algo.valid_sample_cond_info, stream=False, - random_action_prob=self.hps.get("valid_random_action_prob", 0.0), + random_action_prob=self.cfg.algo.valid_random_action_prob, ) for hook in self.valid_sampling_hooks: iterator.add_log_hook(hook) return torch.utils.data.DataLoader( iterator, batch_size=None, - num_workers=self.num_workers, - persistent_workers=self.num_workers > 0, - prefetch_factor=1 if self.num_workers else 2, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + prefetch_factor=1 if self.cfg.num_workers else 2, ) def build_final_data_loader(self) -> DataLoader: @@ -217,26 +224,27 @@ def build_final_data_loader(self) -> DataLoader: iterator = SamplingIterator( self.training_data, model, - self.mb_size, self.ctx, self.algo, self.task, dev, + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, replay_buffer=None, ratio=0.0, - log_dir=os.path.join(self.hps["log_dir"], "final"), + log_dir=os.path.join(self.cfg.log_dir, "final"), random_action_prob=0.0, hindsight_ratio=0.0, - init_train_iter=self.hps["num_training_steps"], + init_train_iter=self.cfg.num_training_steps, ) for hook in self.sampling_hooks: iterator.add_log_hook(hook) return torch.utils.data.DataLoader( iterator, batch_size=None, - num_workers=self.num_workers, - persistent_workers=self.num_workers > 0, - prefetch_factor=1 if self.num_workers else 2, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + prefetch_factor=1 if self.cfg.num_workers else 2, ) def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: @@ -248,8 +256,8 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") except ValueError as e: - os.makedirs(self.hps["log_dir"], exist_ok=True) - torch.save([self.model.state_dict(), batch, loss, info], open(self.hps["log_dir"] + "/dump.pkl", "wb")) + os.makedirs(self.cfg.log_dir, exist_ok=True) + torch.save([self.model.state_dict(), batch, loss, info], open(self.cfg.log_dir + "/dump.pkl", "wb")) raise e if step_info is not None: @@ -269,21 +277,22 @@ def run(self, logger=None): validation every `validate_every` minibatches. """ if logger is None: - logger = create_logger(logfile=self.hps["log_dir"] + "/train.log") + logger = create_logger(logfile=self.cfg.log_dir + "/train.log") self.model.to(self.device) self.sampling_model.to(self.device) epoch_length = max(len(self.training_data), 1) - valid_freq = self.hps.get("validate_every", 0) + valid_freq = self.cfg.validate_every # If checkpoint_every is not specified, checkpoint at every validation epoch - ckpt_freq = self.hps.get("checkpoint_every", valid_freq) + ckpt_freq = self.cfg.checkpoint_every if self.cfg.checkpoint_every is not None else valid_freq train_dl = self.build_training_data_loader() valid_dl = self.build_validation_data_loader() - if self.hps.get("num_final_gen_steps", 0) > 0: + if self.cfg.num_final_gen_steps: final_dl = self.build_final_data_loader() callbacks = self.build_callbacks() - start = self.hps.get("start_at_step", 0) + 1 + start = self.cfg.start_at_step + 1 + num_training_steps = self.cfg.num_training_steps logger.info("Starting training") - for it, batch in zip(range(start, 1 + self.hps["num_training_steps"]), cycle(train_dl)): + for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -308,13 +317,13 @@ def run(self, logger=None): self.log(end_metrics, it, "valid_end") if ckpt_freq > 0 and it % ckpt_freq == 0: self._save_state(it) - self._save_state(self.hps["num_training_steps"]) + self._save_state(num_training_steps) - num_final_gen_steps = self.hps.get("num_final_gen_steps", 0) - if num_final_gen_steps > 0: + num_final_gen_steps = self.cfg.num_final_gen_steps + if num_final_gen_steps: logger.info(f"Generating final {num_final_gen_steps} batches ...") for it, batch in zip( - range(self.hps["num_training_steps"], self.hps["num_training_steps"] + num_final_gen_steps + 1), + range(num_training_steps, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): pass @@ -324,15 +333,15 @@ def _save_state(self, it): torch.save( { "models_state_dict": [self.model.state_dict()], - "hps": self.hps, + "cfg": self.cfg, "step": it, }, - open(pathlib.Path(self.hps["log_dir"]) / "model_state.pt", "wb"), + open(pathlib.Path(self.cfg.log_dir) / "model_state.pt", "wb"), ) def log(self, info, index, key): if not hasattr(self, "_summary_writer"): - self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.hps["log_dir"]) + self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) diff --git a/tests/test_frag_env.py b/tests/test_frag_env.py index 414eaad8..deda34e2 100644 --- a/tests/test_frag_env.py +++ b/tests/test_frag_env.py @@ -1,11 +1,12 @@ import base64 import pickle -from collections import defaultdict import networkx as nx import pytest +from omegaconf import OmegaConf from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv @@ -112,7 +113,9 @@ def test_backwards_mask_equivalence_ipa(two_node_states): """ env = GraphBuildingEnv() ctx = FragMolBuildingEnvContext(max_frags=2) - algo = TrajectoryBalance(env, ctx, None, defaultdict(int), max_nodes=2) + cfg = OmegaConf.structured(Config) + cfg.algo.max_nodes = 2 + algo = TrajectoryBalance(env, ctx, None, cfg) for i in range(1, len(two_node_states)): g = two_node_states[i] n = env.count_backward_transitions(g, check_idempotent=True) From 152b18f11460d7b97bf9118060548f2fa9858a91 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 1 Aug 2023 12:02:05 -0400 Subject: [PATCH 8/9] Refactor tasks (#98) This PR refactors tasks into reusable components that will make it easier to define new tasks and methods. The PR is pretty big, but a lot of it is moving code around (or deleting repeated code): - Creates a notion of `Conditional`, implements `TemperatureConditional`, `MultiObjectiveWeightedPreferences`, `FocusRegionConditional` - Separates the commonalities between `seh_frag` and `seh_frag_moo` into a more generic `StandardOnlineTrainer` that is meant to be easily subclassed for new tasks. - Adds some implementation notes and comments - Adds a `validate_batch` routine that's useful for debugging, e.g. new environments and datasets Also fixes some bugs: - Makes `valid_offline_ratio` a flag and sets it explicitly in tasks where it wasn't properly set - `first_graph_idx` was incorrectly calculated in SubTB (affected logging of logZ values) - QM9Dataset was returning the wrong shape for its rewards - Adds a `allow_5_valence_nitrogen` flag to `MolBuildingEnvContext`, this is needed in some cases, see `tasks/qm9.py`. - Adds an explicit `stop_mask` to `MolBuildingEnvContext.graph_to_Data` - Fixes incorrect default objective name in `seh_frag_moo` - Fixes the default configurations in the tasks' `main` that hadn't been updated - Fixes a number of routines where `focus_cond` was assumed to exist (but we can now turn it off). commits: * little test * new config structure - in progress * trying config by names * further refactor progress * better pyi sort + fix moo example * tox * import fixes * add SQL + fix n_valid * tox test * fix mypy hook and convert qm9 to new cfg * better config generation * use generated config.py * use generated config.py * fix rng call types * fix test + tox * better config doc * fix deps * tox * re-fix deps * minor fixes for seh_frag_moo * tox * beginning of refactor + impl notes * multiobject weighted prefs * focus conditional in progress * switch to OmegaConf * switch to omegaconf * fix pre-commit-config * add omegaconf dep * fix list defaults to fields * remove comment * finish focus conditional * various fixes + switch to new config * tox * update README * make string configs into Literals * switch task construction order * OmegaConf does not support Literal :( * tox * remove dead code + guard against no focus used * fix for no replay * refactor qm9 + remove unused configs * many fixes to QM9, some debugging code and other various fixes * explicit valid_offline_ratio flag * do_validate_batch off by default * addressing PR comments * made device configurable --- README.md | 11 +- docs/implementation_notes.md | 18 ++ src/gflownet/algo/config.py | 3 + src/gflownet/algo/envelope_q_learning.py | 2 +- src/gflownet/algo/trajectory_balance.py | 4 +- src/gflownet/config.py | 8 + src/gflownet/data/qm9.py | 6 +- src/gflownet/data/sampling_iterator.py | 35 ++- src/gflownet/envs/mol_building_env.py | 13 +- src/gflownet/models/graph_transformer.py | 45 +-- src/gflownet/online_trainer.py | 107 +++++++ src/gflownet/tasks/config.py | 37 +-- src/gflownet/tasks/qm9/qm9.py | 146 +++------ src/gflownet/tasks/qm9/qm9.yaml | 11 +- src/gflownet/tasks/seh_frag.py | 177 ++--------- src/gflownet/tasks/seh_frag_moo.py | 332 ++++++++------------- src/gflownet/{train.py => trainer.py} | 39 ++- src/gflownet/utils/conditioning.py | 246 +++++++++++++++ src/gflownet/utils/config.py | 77 +++++ src/gflownet/utils/metrics.py | 2 +- src/gflownet/utils/multiobjective_hooks.py | 6 +- 21 files changed, 782 insertions(+), 543 deletions(-) create mode 100644 docs/implementation_notes.md create mode 100644 src/gflownet/online_trainer.py rename src/gflownet/{train.py => trainer.py} (93%) create mode 100644 src/gflownet/utils/conditioning.py create mode 100644 src/gflownet/utils/config.py diff --git a/README.md b/README.md index 5f3dc65c..3b7a143d 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera ## Repo overview -- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations (only [Trajectory Balance](https://arxiv.org/abs/2201.13259) for now), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. +- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. - [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities. - [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data. - [examples](docs/examples), contains simple example implementations of GFlowNet. @@ -30,8 +30,11 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera - [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward. - [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein - [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives). -- [utils](src/gflownet/utils), contains utilities (multiprocessing). -- [`train.py`](src/gflownet/train.py), defines a general harness for training GFlowNet models. +- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning). +- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models. +- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop. + +See [implementation notes](docs/implementation_notes.md) for more. ## Getting started @@ -57,6 +60,8 @@ To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository- pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-links ... ``` +If package dependencies seem not to work, you may need to install the exact frozen versions listed `requirements/`, i.e. `pip install -r requirements/main_3.9.txt`. + ## Developing & Contributing TODO: Write Contributing.md. diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md new file mode 100644 index 00000000..598e570c --- /dev/null +++ b/docs/implementation_notes.md @@ -0,0 +1,18 @@ +# Implementation notes + +This repo is centered around training GFlowNets that produce graphs. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments. + +## Environment, Context, Task, Trainers + +We separate experiment concerns in four categories: +- The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP. +- The Context provides an interface between the agent and the environment, it + - maps graphs to torch_geometric `Data` + instances + - maps GraphActions to action indices + - produces action masks + - communicates to the model what inputs it should expect +- The Task class is responsible for computing the reward of a state, and for sampling conditioning information +- The Trainer class is responsible for instanciating everything, and running the training & testing loop + +Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`. diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 121ab5f4..bd0ce3de 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -91,6 +91,8 @@ class AlgoConfig: offline_ratio: float The ratio of samples drawn from `self.training_data` during training. The rest is drawn from `self.sampling_model` + valid_offline_ratio: float + Idem but for validation, and `self.test_data`. train_random_action_prob : float The probability of taking a random action during training valid_random_action_prob : float @@ -108,6 +110,7 @@ class AlgoConfig: max_edges: int = 128 illegal_action_logreward: float = -100 offline_ratio: float = 0.5 + valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 valid_random_action_prob: float = 0.0 valid_sample_cond_info: bool = True diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index f9fba9d2..4d694ae2 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -14,7 +14,7 @@ generate_forward_trajectory, ) from gflownet.models.graph_transformer import GraphTransformer, mlp -from gflownet.train import GFNTask +from gflownet.trainer import GFNTask from .graph_sampling import GraphSampler diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 57af6056..22fe655e 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -19,7 +19,7 @@ GraphBuildingEnvContext, generate_forward_trajectory, ) -from gflownet.train import GFNAlgorithm +from gflownet.trainer import GFNAlgorithm class TrajectoryBalanceModel(nn.Module): @@ -363,7 +363,7 @@ def compute_batch_losses( traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens) # The position of the first graph of each trajectory first_graph_idx = torch.zeros_like(batch.traj_lens) - first_graph_idx = torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) log_Z = per_graph_out[first_graph_idx, 0] else: # Compute log numerator and denominator of the TB objective diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 8de0ebeb..be4fa879 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -7,6 +7,7 @@ from gflownet.data.config import ReplayConfig from gflownet.models.config import ModelConfig from gflownet.tasks.config import TasksConfig +from gflownet.utils.config import ConditionalsConfig @dataclass @@ -51,12 +52,16 @@ class Config: ---------- log_dir : str The directory where to store logs, checkpoints, and samples. + device : str + The device to use for training (either "cpu" or "cuda[:]") seed : int The random seed validate_every : int The number of training steps after which to validate the model checkpoint_every : Optional[int] The number of training steps after which to checkpoint the model + print_every : int + The number of training steps after which to print the training loss start_at_step : int The training step to start at (default: 0) num_final_gen_steps : Optional[int] @@ -76,9 +81,11 @@ class Config: """ log_dir: str = MISSING + device: str = "cuda" seed: int = 0 validate_every: int = 1000 checkpoint_every: Optional[int] = None + print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None num_training_steps: int = 10_000 @@ -92,3 +99,4 @@ class Config: opt: OptimizerConfig = OptimizerConfig() replay: ReplayConfig = ReplayConfig() task: TasksConfig = TasksConfig() + cond: ConditionalsConfig = ConditionalsConfig() diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index c68ead98..b26c29d2 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import rdkit.Chem as Chem +import torch from torch.utils.data import Dataset @@ -39,7 +40,10 @@ def __len__(self): return len(self.idcs) def __getitem__(self, idx): - return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]]) + return ( + Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), + torch.tensor([self.df[self.target][self.idcs[idx]]]).float(), + ) def convert_h5(): diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index b69591cc..90b8b4db 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -12,6 +12,7 @@ from torch.utils.data import Dataset, IterableDataset from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.envs.graph_building_env import GraphActionCategorical class SamplingIterator(IterableDataset): @@ -98,6 +99,7 @@ def __init__( self.random_action_prob = random_action_prob self.hindsight_ratio = hindsight_ratio self.train_it = init_train_iter + self.do_validate_batch = False # Turn this on for debugging self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) @@ -130,6 +132,9 @@ def _idx_iterator(self): if n == 0: yield np.arange(0, 0) return + assert ( + self.offline_batch_size > 0 + ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" if worker_info is None: # no multi-processing start, end, wid = 0, n, -1 else: # split the data into chunks (per-worker) @@ -237,6 +242,7 @@ def __iter__(self): log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward # Computes some metrics + extra_info = {} if not self.sample_cond_info: # If we're using a dataset of preferences, the user may want to know the id of the preference for i, j in zip(trajs, idcs): @@ -253,7 +259,6 @@ def __iter__(self): {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) if num_online > 0: - extra_info = {} for hook in self.log_hooks: extra_info.update( hook( @@ -318,9 +323,37 @@ def __iter__(self): # TODO: we could very well just pass the cond_info dict to construct_batch above, # and the algo can decide what it wants to put in the batch object + # Only activate for debugging your environment or dataset (e.g. the dataset could be + # generating trajectories with illegal actions) + if self.do_validate_batch: + self.validate_batch(batch, trajs) + self.train_it += worker_info.num_workers if worker_info is not None else 1 yield batch + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [self.model._action_type_to_mask(t, batch) for t in atypes], + [self.model._action_type_to_key[t] for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) + def log_generated(self, trajs, rewards, flat_rewards, cond_info): if self.log_molecule_smis: mols = [ diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 44c99e4a..925e3456 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -28,6 +28,7 @@ def __init__( charges=[0, 1, -1], expl_H_range=[0, 1], allow_explicitly_aromatic=False, + allow_5_valence_nitrogen=False, num_rw_feat=8, max_nodes=None, max_edges=None, @@ -118,9 +119,11 @@ def __init__( BondType.AROMATIC: 1.5, } pt = Chem.GetPeriodicTable() + self.allow_5_valence_nitrogen = allow_5_valence_nitrogen self._max_atom_valence = { **{a: max(pt.GetValenceList(a)) for a in atoms}, - "N": 3, # We'll handle nitrogen valence later explicitly in graph_to_Data + # We'll handle nitrogen valence later explicitly in graph_to_Data + "N": 3 if not allow_5_valence_nitrogen else 5, "*": 0, # wildcard atoms have 0 valence until filled in } @@ -231,9 +234,10 @@ def graph_to_Data(self, g: Graph) -> gd.Data: max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]] # Special rule for Nitrogen if ad["v"] == "N" and ad.get("charge", 0) == 1: - # This is definitely a heuristic, but to keep things simple we'll limit Nitrogen's valence to 3 (as + # This is definitely a heuristic, but to keep things simple we'll limit* Nitrogen's valence to 3 (as # per self._max_atom_valence) unless it is charged, then we make it 5. # This keeps RDKit happy (and is probably a good idea anyway). + # (* unless allow_5_valence_nitrogen is True, then it's just always 5) max_atom_valence = 5 max_valence[n] = max_atom_valence - abs(ad.get("charge", 0)) - ad.get("expl_H", 0) # Compute explicitly defined valence: @@ -281,14 +285,15 @@ def is_ok_non_edge(e): non_edge_index = torch.zeros((2, 0), dtype=torch.long) else: gc = nx.complement(g) - non_edge_index = torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).T.reshape( - (2, -1) + non_edge_index = ( + torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).reshape((-1, 2)).T ) data = gd.Data( x, edge_index, edge_attr, non_edge_index=non_edge_index, + stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1), add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 097e7e91..05f9b0e4 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -145,6 +145,27 @@ class GraphTransformerGFN(nn.Module): Outputs logits corresponding to the action types used by the env_ctx argument. """ + # The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the + # embeddings to the right MLP + _action_type_to_graph_part = { + GraphActionType.Stop: "graph", + GraphActionType.AddNode: "node", + GraphActionType.SetNodeAttr: "node", + GraphActionType.AddEdge: "non_edge", + GraphActionType.SetEdgeAttr: "edge", + GraphActionType.RemoveNode: "node", + GraphActionType.RemoveNodeAttr: "node", + GraphActionType.RemoveEdge: "edge", + GraphActionType.RemoveEdgeAttr: "edge", + } + # The torch_geometric batch key each graph part corresponds to + _graph_part_to_key = { + "graph": None, + "node": "x", + "non_edge": "non_edge_index", + "edge": "edge_index", + } + def __init__( self, env_ctx, @@ -184,25 +205,8 @@ def __init__( GraphActionType.RemoveEdge: (num_edge_feat, 1), GraphActionType.RemoveEdgeAttr: (num_edge_feat, env_ctx.num_edge_attrs), } - # The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the - # embeddings to the right MLP - self._action_type_to_graph_part = { - GraphActionType.Stop: "graph", - GraphActionType.AddNode: "node", - GraphActionType.SetNodeAttr: "node", - GraphActionType.AddEdge: "non_edge", - GraphActionType.SetEdgeAttr: "edge", - GraphActionType.RemoveNode: "node", - GraphActionType.RemoveNodeAttr: "node", - GraphActionType.RemoveEdge: "edge", - GraphActionType.RemoveEdgeAttr: "edge", - } - # The torch_geometric batch key each graph part corresponds to - self._graph_part_to_key = { - "graph": None, - "node": "x", - "non_edge": "non_edge_index", - "edge": "edge_index", + self._action_type_to_key = { + at: self._graph_part_to_key[self._action_type_to_graph_part[at]] for at in self._action_type_to_graph_part } # Here we create only the embedding -> logit mapping MLPs that are required by the environment @@ -229,13 +233,14 @@ def _action_type_to_logit(self, t, emb, g): def _mask(self, x, m): # mask logit vector x with binary mask m, -1000 is a tiny log-value + # Note to self: we can't use torch.inf here, because inf * 0 is nan (but also see issue #99) return x * m + -1000 * (1 - m) def _make_cat(self, g, emb, action_types): return GraphActionCategorical( g, logits=[self._action_type_to_logit(t, emb, g) for t in action_types], - keys=[self._graph_part_to_key[self._action_type_to_graph_part[t]] for t in action_types], + keys=[self._action_type_to_key[t] for t in action_types], masks=[self._action_type_to_mask(t, g) for t in action_types], types=action_types, ) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py new file mode 100644 index 00000000..98791be5 --- /dev/null +++ b/src/gflownet/online_trainer.py @@ -0,0 +1,107 @@ +import copy +import os +import pathlib + +import git +import torch +from omegaconf import OmegaConf +from torch import Tensor + +from gflownet.algo.advantage_actor_critic import A2C +from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.soft_q_learning import SoftQLearning +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.models.graph_transformer import GraphTransformerGFN + +from .trainer import GFNTrainer + + +class StandardOnlineTrainer(GFNTrainer): + def setup_model(self): + self.model = GraphTransformerGFN( + self.ctx, + self.cfg, + do_bck=self.cfg.algo.tb.do_parameterize_p_b, + ) + + def setup_algo(self): + algo = self.cfg.algo.method + if algo == "TB": + algo = TrajectoryBalance + elif algo == "FM": + algo = FlowMatching + elif algo == "A2C": + algo = A2C + elif algo == "SQL": + algo = SoftQLearning + else: + raise ValueError(algo) + self.algo = algo(self.env, self.ctx, self.rng, self.cfg) + + def setup_data(self): + self.training_data = [] + self.test_data = [] + + def setup(self): + super().setup() + self.offline_ratio = 0 + self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None + + # Separate Z parameters from non-Z to allow for LR decay on the former + if hasattr(self.model, "logZ"): + Z_params = list(self.model.logZ.parameters()) + non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] + else: + Z_params = [] + non_Z_params = list(self.model.parameters()) + self.opt = torch.optim.Adam( + non_Z_params, + self.cfg.opt.learning_rate, + (self.cfg.opt.momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) + self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + + self.sampling_tau = self.cfg.algo.sampling_tau + if self.sampling_tau > 0: + self.sampling_model = copy.deepcopy(self.model) + else: + self.sampling_model = self.model + + self.mb_size = self.cfg.algo.global_batch_size + self.clip_grad_callback = { + "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), + "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), + "none": lambda x: None, + }[self.cfg.opt.clip_grad_type] + + # saving hyperparameters + git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] + self.cfg.git_hash = git_hash + + os.makedirs(self.cfg.log_dir, exist_ok=True) + print("\n\nHyperparameters:\n") + yaml = OmegaConf.to_yaml(self.cfg) + print(yaml) + with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + f.write(yaml) + + def step(self, loss: Tensor): + loss.backward() + for i in self.model.parameters(): + self.clip_grad_callback(i) + self.opt.step() + self.opt.zero_grad() + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched.step() + self.lr_sched_Z.step() + if self.sampling_tau > 0: + for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): + b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index bc0e31d0..a9f6ac3f 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -1,30 +1,10 @@ from dataclasses import dataclass, field -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple @dataclass class SEHTaskConfig: - """Config for the SEHTask - - Attributes - ---------- - - temperature_sample_dist : str - The distribution to sample the inverse temperature from. Can be one of: - - "uniform": uniform distribution - - "loguniform": log-uniform distribution - - "gamma": gamma distribution - - "constant": constant temperature - temperature_dist_params : List[Any] - The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. - num_thermometer_dim : int - The number of thermometer encoding dimensions to use. - """ - - # TODO: a proper class for temperature-conditional sampling - temperature_sample_dist: str = "uniform" - temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) - num_thermometer_dim: int = 32 + pass # SEH just uses a temperature conditional @dataclass @@ -57,10 +37,6 @@ class SEHMOOTaskConfig: The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. """ - # TODO: a proper class for temperature-conditional sampling - temperature_sample_dist: str = "uniform" - temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) - num_thermometer_dim: int = 32 use_steer_thermometer: bool = False preference_type: Optional[str] = "dirichlet" focus_type: Optional[str] = None @@ -72,16 +48,13 @@ class SEHMOOTaskConfig: max_train_it: Optional[int] = None n_valid: int = 15 n_valid_repeats: int = 128 - objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "wt"]) + objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) @dataclass class QM9TaskConfig: - # TODO: a proper class for temperature-conditional sampling - temperature_sample_dist: str = "uniform" - temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) - num_thermometer_dim: int = 32 - h5_path = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py + h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py + model_path: str = "./data/qm9/qm9_model.pt" @dataclass diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 5405d867..e5b1d29a 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,27 +1,22 @@ -import copy import os from typing import Callable, Dict, List, Tuple, Union import numpy as np -import scipy.stats as stats import torch import torch.nn as nn import torch_geometric.data as gd -from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet -from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset -from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar -from gflownet.utils.transforms import thermometer +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils.conditioning import TemperatureConditional class QM9GapTask(GFNTask): @@ -30,19 +25,15 @@ class QM9GapTask(GFNTask): def __init__( self, dataset: Dataset, - temperature_distribution: str, - temperature_parameters: List[float], - num_thermometer_dim: int, + cfg: Config, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models() + self.models = self.load_task_models(cfg.task.qm9.model_path) self.dataset = dataset - self.temperature_sample_dist = temperature_distribution - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim + self.temperature_conditional = TemperatureConditional(cfg, rng) # TODO: fix interface self._min, self._max, self._percentile_95 = self.dataset.get_stats(percentile=0.05) # type: ignore self._width = self._max - self._min @@ -69,51 +60,20 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self): + def load_task_models(self, path): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? - state_dict = torch.load("/data/chem/qm9/mxmnet_gap_model.pt") + state_dict = torch.load(path) gap_model.load_state_dict(state_dict) gap_model.cuda() gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) return {"mxmnet_gap": gap_model} - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: - beta = None - if self.temperature_sample_dist == "constant": - assert type(self.temperature_dist_params) in [float, int] - beta = np.array(self.temperature_dist_params).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, self.num_thermometer_dim)) - else: - if self.temperature_sample_dist == "gamma": - loc, scale = self.temperature_dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) - upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) - elif self.temperature_sample_dist == "uniform": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.uniform(a, b, n).astype(np.float32) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "loguniform": - low, high = np.log(self.temperature_dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "beta": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.beta(a, b, n).astype(np.float32) - upper_bound = 1 - beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) - - assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": torch.tensor(beta), "encoding": beta_enc} + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - if isinstance(flat_reward, list): - flat_reward = torch.tensor(flat_reward) - scalar_logreward = flat_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] @@ -128,7 +88,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class QM9GapTrainer(GFNTrainer): +class QM9GapTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.num_workers = 8 cfg.num_training_steps = 100000 @@ -139,86 +99,48 @@ def set_default_hps(self, cfg: Config): cfg.opt.lr_decay = 20000 cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 + cfg.algo.max_nodes = 9 cfg.algo.global_batch_size = 64 cfg.algo.train_random_action_prob = 0.001 cfg.algo.illegal_action_logreward = -75 cfg.algo.sampling_tau = 0.0 cfg.model.num_emb = 128 cfg.model.num_layers = 4 - cfg.task.qm9.temperature_sample_dist = "uniform" - cfg.task.qm9.temperature_dist_params = [0.5, 32.0] - cfg.task.qm9.num_thermometer_dim = 32 - - def setup(self): - RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) - self.env = GraphBuildingEnv() - self.ctx = MolBuildingEnvContext(["H", "C", "N", "F", "O"], num_cond_dim=32) - self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") - self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") + cfg.cond.temperature.sample_dist = "uniform" + cfg.cond.temperature.dist_params = [0.5, 32.0] + cfg.cond.temperature.num_thermometer_dim = 32 - model = GraphTransformerGFN(self.ctx, self.cfg) - self.model = model - # Separate Z parameters from non-Z to allow for LR decay on the former - Z_params = list(model.logZ.parameters()) - non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] - self.opt = torch.optim.Adam( - non_Z_params, - self.cfg.opt.learning_rate, - (self.cfg.opt.momentum, 0.999), - weight_decay=self.cfg.opt.weight_decay, - eps=self.cfg.opt.adam_eps, - ) - self.opt_Z = torch.optim.Adam(Z_params, self.cfg.opt.learning_rate, (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay) + def setup_env_context(self): + self.ctx = MolBuildingEnvContext( + ["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True ) + # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories + # from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action + # for setting the explicit hydrogen is used before the positive charge is set, it will be considered + # an invalid action. However, generate_forward_trajectory does not consider this implementation detail, + # it assumes that attribute-setting will always be valid. For the molecular environment, as of writing + # (PR #98) this edge case is the only case where the ordering in which attributes are set can matter. + + def setup_data(self): + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") - self.sampling_tau = self.cfg.algo.sampling_tau - if self.sampling_tau > 0: - self.sampling_model = copy.deepcopy(model) - else: - self.sampling_model = self.model - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg) - + def setup_task(self): self.task = QM9GapTask( dataset=self.training_data, - temperature_distribution=self.cfg.task.qm9.temperature_sample_dist, - temperature_parameters=self.cfg.task.qm9.temperature_dist_params, - num_thermometer_dim=self.cfg.task.qm9.num_thermometer_dim, + cfg=self.cfg, + rng=self.rng, wrap_model=self._wrap_for_mp, ) - self.mb_size = self.cfg.algo.global_batch_size - self.clip_grad_param = self.cfg.opt.clip_grad_param - self.clip_grad_callback = { - "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param), - "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param), - "none": lambda x: None, - }[self.cfg.opt.clip_grad_type] - - def step(self, loss: Tensor): - loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) - self.opt.step() - self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() - self.lr_sched.step() - self.lr_sched_Z.step() - if self.sampling_tau > 0: - for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): - b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) def main(): - """Example of how this model can be run outside of Determined""" + """Example of how this model can be run.""" yaml = YAML(typ="safe", pure=True) config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") with open(config_file, "r") as f: hps = yaml.load(f) - trial = QM9GapTrainer(hps, torch.device("cpu")) + trial = QM9GapTrainer(hps) trial.run() diff --git a/src/gflownet/tasks/qm9/qm9.yaml b/src/gflownet/tasks/qm9/qm9.yaml index 01ea17f4..19701fac 100644 --- a/src/gflownet/tasks/qm9/qm9.yaml +++ b/src/gflownet/tasks/qm9/qm9.yaml @@ -1,5 +1,10 @@ -lr_decay: 10000 -qm9_h5_path: /data/chem/qm9/qm9.h5 -log_dir: /scratch/logs/qm9_gap_mxmnet +opt: + lr_decay: 10000 +task: + qm9: + h5_path: /rxrx/data/chem/qm9/qm9.h5 + model_path: /rxrx/data/chem/qm9/mxmnet_gap_model.pt num_training_steps: 100000 validate_every: 100 +log_dir: ./logs/debug_qm9 +num_workers: 0 diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 10fa142c..4d0cc624 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,34 +1,22 @@ -import copy import os -import pathlib import shutil import socket from typing import Callable, Dict, List, Tuple, Union -import git import numpy as np -import scipy.stats as stats import torch import torch.nn as nn import torch_geometric.data as gd -from omegaconf import OmegaConf -from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset -from gflownet.algo.advantage_actor_critic import A2C -from gflownet.algo.flow_matching import FlowMatching -from gflownet.algo.soft_q_learning import SoftQLearning -from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config -from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.models import bengio2021flow -from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar -from gflownet.utils.transforms import thermometer +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils.conditioning import TemperatureConditional class SEHTask(GFNTask): @@ -52,9 +40,8 @@ def __init__( self.rng = rng self.models = self._load_task_models() self.dataset = dataset - self.temperature_sample_dist = cfg.task.seh.temperature_sample_dist - self.temperature_dist_params = cfg.task.seh.temperature_dist_params - self.num_thermometer_dim = cfg.task.seh.num_thermometer_dim + self.temperature_conditional = TemperatureConditional(cfg, rng) + self.num_cond_dim = self.temperature_conditional.encoding_size() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -68,41 +55,10 @@ def _load_task_models(self): return {"seh": model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: - beta = None - if self.temperature_sample_dist == "constant": - assert type(self.temperature_dist_params) is float - beta = np.array(self.temperature_dist_params).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, self.num_thermometer_dim)) - else: - if self.temperature_sample_dist == "gamma": - loc, scale = self.temperature_dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) - upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) - elif self.temperature_sample_dist == "uniform": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.uniform(a, b, n).astype(np.float32) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "loguniform": - low, high = np.log(self.temperature_dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "beta": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.beta(a, b, n).astype(np.float32) - upper_bound = 1 - beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) - - assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": torch.tensor(beta), "encoding": beta_enc} + return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - if isinstance(flat_reward, list): - flat_reward = torch.tensor(flat_reward) - scalar_logreward = flat_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -117,7 +73,9 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class SEHFragTrainer(GFNTrainer): +class SEHFragTrainer(StandardOnlineTrainer): + task: SEHTask + def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False @@ -140,6 +98,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.illegal_action_logreward = -75 cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False cfg.algo.tb.Z_learning_rate = 1e-3 @@ -150,20 +109,6 @@ def set_default_hps(self, cfg: Config): cfg.replay.capacity = 10_000 cfg.replay.warmup = 1_000 - def setup_algo(self): - algo = self.cfg.algo.method - if algo == "TB": - algo = TrajectoryBalance - elif algo == "FM": - algo = FlowMatching - elif algo == "A2C": - algo = A2C - elif algo == "SQL": - algo = SoftQLearning - else: - raise ValueError(algo) - self.algo = algo(self.env, self.ctx, self.rng, self.cfg) - def setup_task(self): self.task = SEHTask( dataset=self.training_data, @@ -172,98 +117,30 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) - def setup_model(self): - model = GraphTransformerGFN( - self.ctx, - self.cfg, - do_bck=self.cfg.algo.tb.do_parameterize_p_b, - ) - self.model = model - def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext( - max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.cfg.task.seh.num_thermometer_dim - ) - - def setup(self): - RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) - self.env = GraphBuildingEnv() - self.training_data = [] - self.test_data = [] - self.offline_ratio = 0 - self.valid_offline_ratio = 0 - self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None - self.setup_env_context() - self.setup_algo() - self.setup_task() - self.setup_model() - - # Separate Z parameters from non-Z to allow for LR decay on the former - Z_params = list(self.model.logZ.parameters()) - non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] - self.opt = torch.optim.Adam( - non_Z_params, - self.cfg.opt.learning_rate, - (self.cfg.opt.momentum, 0.999), - weight_decay=self.cfg.opt.weight_decay, - eps=self.cfg.opt.adam_eps, - ) - self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) - ) - - self.sampling_tau = self.cfg.algo.sampling_tau - if self.sampling_tau > 0: - self.sampling_model = copy.deepcopy(self.model) - else: - self.sampling_model = self.model - - self.mb_size = self.cfg.algo.global_batch_size - self.clip_grad_callback = { - "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), - "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), - "none": lambda x: None, - }[self.cfg.opt.clip_grad_type] - - # saving hyperparameters - git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] - self.cfg.git_hash = git_hash - - os.makedirs(self.cfg.log_dir, exist_ok=True) - print("\n\nHyperparameters:\n") - yaml = OmegaConf.to_yaml(self.cfg) - print(yaml) - with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: - f.write(yaml) - - def step(self, loss: Tensor): - loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) - self.opt.step() - self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() - self.lr_sched.step() - self.lr_sched_Z.step() - if self.sampling_tau > 0: - for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): - b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) def main(): """Example of how this model can be run outside of Determined""" hps = { - "log_dir": "./logs/debug_run", + "log_dir": "./logs/debug_run_seh_frag", + "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), "overwrite_existing_exp": True, "num_training_steps": 10_000, "num_workers": 8, - "opt.lr_decay": 20000, - "algo.sampling_tau": 0.99, - "task.seh.temperature_dist_params": (0.0, 64.0), + "opt": { + "lr_decay": 20000, + }, + "algo": { + "sampling_tau": 0.99, + }, + "cond": { + "temperature": { + "sample_dist": "uniform", + "dist_params": [0, 64.0], + } + }, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -272,7 +149,7 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + trial = SEHFragTrainer(hps) trial.print_every = 1 trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index c515bc6e..8ad73320 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,8 +1,7 @@ import os import pathlib import shutil -from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np import torch @@ -11,7 +10,6 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.distributions.dirichlet import Dirichlet from torch.utils.data import Dataset from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL @@ -20,11 +18,10 @@ from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask -from gflownet.train import FlatRewards, RewardScalar +from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore -from gflownet.utils.focus_model import FocusModel, TabularFocusModel +from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook -from gflownet.utils.transforms import thermometer class SEHMOOTask(SEHTask): @@ -42,66 +39,27 @@ def __init__( dataset: Dataset, cfg: Config, rng: np.random.Generator = None, - focus_model: Optional[FocusModel] = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): - self._wrap_model = wrap_model + super().__init__(dataset, cfg, rng, wrap_model) self.cfg = cfg mcfg = self.cfg.task.seh_moo - self.rng = rng - self.models = self._load_task_models() self.objectives = cfg.task.seh_moo.objectives self.dataset = dataset - self.temperature_sample_dist = mcfg.temperature_sample_dist - self.temperature_dist_params = mcfg.temperature_dist_params - self.num_thermometer_dim = mcfg.num_thermometer_dim - self.use_steer_thermometer = mcfg.use_steer_thermometer - self.preference_type = mcfg.preference_type - self.seeded_preference = None - self.experimental_dirichlet = False - self.focus_type = mcfg.focus_type - self.focus_cosim = mcfg.focus_cosim - self.focus_limit_coef = mcfg.focus_limit_coef - self.focus_model = focus_model - self.illegal_action_logreward = cfg.algo.illegal_action_logreward - self.focus_model_training_limits = mcfg.focus_model_training_limits - self.max_train_it = mcfg.max_train_it - self.setup_focus_regions() - assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) - - def setup_focus_regions(self): - mcfg = self.cfg.task.seh_moo - n_valid = mcfg.n_valid - n_obj = len(self.objectives) - # focus regions - if mcfg.focus_type is None: - valid_focus_dirs = np.zeros((n_valid, n_obj)) - self.fixed_focus_dirs = valid_focus_dirs - elif mcfg.focus_type == "centered": - valid_focus_dirs = np.ones((n_valid, n_obj)) - self.fixed_focus_dirs = valid_focus_dirs - elif mcfg.focus_type == "partitioned": - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") - self.fixed_focus_dirs = valid_focus_dirs - elif mcfg.focus_type in ["dirichlet", "learned-gfn"]: - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") - self.fixed_focus_dirs = None - elif mcfg.focus_type in ["hyperspherical", "learned-tabular"]: - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") - self.fixed_focus_dirs = None - elif mcfg.focus_type == "listed": - if len(mcfg.focus_type) == 1: - valid_focus_dirs = np.array([mcfg.focus_dirs_listed[0]] * n_valid) - self.fixed_focus_dirs = valid_focus_dirs - else: - valid_focus_dirs = np.array(mcfg.focus_dirs_listed) - self.fixed_focus_dirs = valid_focus_dirs + if self.cfg.cond.focus_region.focus_type is not None: + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) else: - raise NotImplementedError( - f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " - f"focus_type, but here: {mcfg.focus_type}" - ) - self.valid_focus_dirs = valid_focus_dirs + self.focus_cond = None + self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) + self.temperature_sample_dist = cfg.cond.temperature.sample_dist + self.temperature_dist_params = cfg.cond.temperature.dist_params + self.num_thermometer_dim = cfg.cond.temperature.num_thermometer_dim + self.num_cond_dim = ( + self.temperature_conditional.encoding_size() + + self.pref_cond.encoding_size() + + (self.focus_cond.encoding_size() if self.focus_cond is not None else 0) + ) + assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y)) @@ -109,62 +67,18 @@ def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: def inverse_flat_reward_transform(self, rp): return rp - def _load_task_models(self): - model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model, send_to_device=True) - return {"seh": model} - - def get_steer_encodings(self, preferences, focus_dirs): - n = len(preferences) - if self.use_steer_thermometer: - pref_enc = thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1) - focus_enc = thermometer(focus_dirs, self.num_thermometer_dim, 0, 1).reshape(n, -1) - else: - pref_enc = preferences - focus_enc = focus_dirs - return pref_enc, focus_enc - def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: cond_info = super().sample_conditional_information(n, train_it) - - if self.preference_type is None: - preferences = torch.ones((n, len(self.objectives))) - else: - if self.seeded_preference is not None: - preferences = torch.tensor([self.seeded_preference] * n).float() - elif self.experimental_dirichlet: - a = np.random.dirichlet([1] * len(self.objectives), n) - b = np.random.exponential(1, n)[:, None] - preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() - else: - m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) - preferences = m.sample([n]) - - if self.fixed_focus_dirs is not None: - focus_dir = torch.tensor( - np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) - ) - elif self.focus_type == "dirichlet": - m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) - focus_dir = m.sample([n]) - elif self.focus_type == "hyperspherical": - focus_dir = torch.tensor( - metrics.sample_positiveQuadrant_ndim_sphere(n, len(self.objectives), normalisation="l2") - ).float() - elif self.focus_type is not None and "learned" in self.focus_type: - if self.focus_model is not None and train_it >= self.focus_model_training_limits[0] * self.max_train_it: - focus_dir = self.focus_model.sample_focus_directions(n) - else: - focus_dir = torch.tensor( - metrics.sample_positiveQuadrant_ndim_sphere(n, len(self.objectives), normalisation="l2") - ).float() - else: - raise NotImplementedError(f"Unsupported focus_type={type(self.focus_type)}") - - preferences_enc, focus_enc = self.get_steer_encodings(preferences, focus_dir) - cond_info["encoding"] = torch.cat([cond_info["encoding"], preferences_enc, focus_enc], 1) - cond_info["preferences"] = preferences - cond_info["focus_dir"] = focus_dir + pref_ci = self.pref_cond.sample(n) + focus_ci = ( + self.focus_cond.sample(n, train_it) if self.focus_cond is not None else {"encoding": torch.zeros(n, 0)} + ) + cond_info = { + **cond_info, + **pref_ci, + **focus_ci, + "encoding": torch.cat([cond_info["encoding"], pref_ci["encoding"], focus_ci["encoding"]], dim=1), + } return cond_info def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor]: @@ -179,7 +93,7 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor """ n = len(steer_info) if self.temperature_sample_dist == "constant": - beta = torch.ones(n) * self.temperature_dist_params + beta = torch.ones(n) * self.temperature_dist_params[0] beta_enc = torch.zeros((n, self.num_thermometer_dim)) else: beta = torch.ones(n) * self.temperature_dist_params[-1] @@ -191,9 +105,12 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor preferences = steer_info[:, : len(self.objectives)].float() focus_dir = steer_info[:, len(self.objectives) :].float() - preferences_enc, focus_enc = self.get_steer_encodings(preferences, focus_dir) - encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() - + preferences_enc = self.pref_cond.encode(preferences) + if self.focus_cond is not None: + focus_enc = self.focus_cond.encode(focus_dir) + encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() + else: + encoding = torch.cat([beta_enc, preferences_enc], 1).float() return { "beta": beta, "encoding": encoding, @@ -204,17 +121,23 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor def relabel_condinfo_and_logrewards( self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor ): - if self.focus_type is None: + # TODO: we seem to be relabeling tensors in place, could that cause a problem? + if self.focus_cond is None: + raise NotImplementedError("Hindsight relabeling only implemented for focus conditioning") + if self.focus_cond.cfg.focus_type is None: return cond_info, log_rewards # only keep hindsight_idxs that actually correspond to a violated constraint - _, in_focus_mask = metrics.compute_focus_coef(flat_rewards, cond_info["focus_dir"], self.focus_cosim) + _, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim + ) out_focus_mask = torch.logical_not(in_focus_mask) hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] # relabels the focus_dirs and log_rewards cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) - preferences_enc, focus_enc = self.get_steer_encodings(cond_info["preferences"], cond_info["focus_dir"]) + preferences_enc = self.pref_cond.encode(cond_info["preferences"]) + focus_enc = self.focus_cond.encode(cond_info["focus_dir"]) cond_info["encoding"] = torch.cat( [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 ) @@ -228,19 +151,15 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat flat_reward = torch.stack(flat_reward) else: flat_reward = torch.tensor(flat_reward) - scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - - if self.focus_type is not None: - focus_coef, in_focus_mask = metrics.compute_focus_coef( - flat_reward, cond_info["focus_dir"], self.focus_cosim, self.focus_limit_coef - ) - scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) - scalar_logreward[~in_focus_mask] = self.illegal_action_logreward - return RewardScalar(scalar_logreward * cond_info["beta"]) + scalarized_reward = self.pref_cond.transform(cond_info, flat_reward) + focused_reward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_reward) + if self.focus_cond is not None + else scalarized_reward + ) + tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward) + return RewardScalar(tempered_reward) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -283,11 +202,15 @@ def safe(f, x, default): class SEHMOOFragTrainer(SEHFragTrainer): task: SEHMOOTask + ctx: FragMolBuildingEnvContext def set_default_hps(self, cfg: Config): super().set_default_hps(cfg) cfg.algo.sampling_tau = 0.95 + # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) + # sampling and set the offline ratio to 1 cfg.algo.valid_sample_cond_info = False + cfg.algo.valid_offline_ratio = 1 def setup_algo(self): algo = self.cfg.algo.method @@ -298,35 +221,16 @@ def setup_algo(self): else: super().setup_algo() - focus_type = self.cfg.task.seh_moo.focus_type - if focus_type is not None and "learned" in focus_type: - if focus_type == "learned-tabular": - self.focus_model = TabularFocusModel( - device=self.device, - n_objectives=len(self.cfg.task.seh_moo.objectives), - state_space_res=self.cfg.task.seh_moo.focus_model_state_space_res, - ) - else: - raise NotImplementedError("Unknown focus model type {self.focus_type}") - else: - self.focus_model = None - def setup_task(self): self.task = SEHMOOTask( dataset=self.training_data, cfg=self.cfg, - focus_model=self.focus_model, rng=self.rng, wrap_model=self._wrap_for_mp, ) def setup_env_context(self): - if self.cfg.task.seh_moo.use_steer_thermometer: - ncd = self.cfg.task.seh_moo.num_thermometer_dim * (1 + 2 * len(self.cfg.task.seh_moo.objectives)) - else: - # 1 for prefs and 1 for focus region - ncd = self.cfg.task.seh_moo.num_thermometer_dim + 2 * len(self.cfg.task.seh_moo.objectives) - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=ncd) + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) def setup_model(self): if self.cfg.algo.method == "MOQL": @@ -391,16 +295,22 @@ def setup(self): # hps["fixed_focus_dirs"] = ( # np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None # ) - assert self.task.valid_focus_dirs.shape == ( - n_valid, - n_obj, - ), f"Invalid shape for valid_preferences, {self.task.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + if self.task.focus_cond is not None: + assert self.task.focus_cond.valid_focus_dirs.shape == ( + n_valid, + n_obj, + ), ( + "Invalid shape for valid_preferences, " + f"{self.task.focus_cond.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + ) - # combine preferences and focus directions (fixed focus cosim) since they could be used together (not either/or) - # TODO: this relies on positional assumptions, should have something cleaner - valid_cond_vector = np.concatenate([valid_preferences, self.task.valid_focus_dirs], axis=1) + # combine preferences and focus directions (fixed focus cosim) since they could be used together + # (not either/or). TODO: this relies on positional assumptions, should have something cleaner + valid_cond_vector = np.concatenate([valid_preferences, self.task.focus_cond.valid_focus_dirs], axis=1) + else: + valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, len(valid_cond_vector)) + self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid) self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) @@ -421,19 +331,13 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: - focus_model_training_limits = self.cfg.task.seh_moo.focus_model_training_limits - max_train_it = self.cfg.num_training_steps - if ( - self.focus_model is not None - and train_it >= focus_model_training_limits[0] * max_train_it - and train_it <= focus_model_training_limits[1] * max_train_it - ): - self.focus_model.update_belief(deepcopy(batch.focus_dir), deepcopy(batch.flat_rewards)) + if self.task.focus_cond is not None: + self.task.focus_cond.step_focus_model(batch, train_it) return super().train_batch(batch, epoch_idx, batch_idx, train_it) def _save_state(self, it): - if self.focus_model is not None: - self.focus_model.save(pathlib.Path(self.cfg.log_dir)) + if self.task.focus_cond is not None and self.task.focus_cond.focus_model is not None: + self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) @@ -451,43 +355,65 @@ def __getitem__(self, idx): def main(): - """Example of how this model can be run outside of Determined""" + """Example of how this model can be run.""" hps = { - "log_dir": "./logs/debug_run", + "log_dir": "./logs/debug_run_sfm", + "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), "pickle_mp_messages": True, "overwrite_existing_exp": True, "seed": 0, - "num_training_steps": 20_000, - "num_final_gen_steps": 500, - "validate_every": 500, + "num_training_steps": 500, + "num_final_gen_steps": 50, + "validate_every": 100, "num_workers": 0, - "algo.global_batch_size": 64, - "algo.method": "TB", - "model.num_layers": 2, - "model.num_emb": 256, - "task.seh_moo.objectives": ["seh", "qed"], - "opt.learning_rate": 1e-4, - "algo.tb.Z_learning_rate": 1e-3, - "opt.lr_decay": 20000, - "algo.tb.Z_lr_decay": 50000, - "algo.sampling_tau": 0.95, - "algo.train_random_action_prob": 0.01, - "task.seh_moo.temperature_sample_dist": "constant", - "task.seh_moo.temperature_dist_params": 60.0, - "task.seh_moo.num_thermometer_dim": 32, - "task.seh_moo.use_steer_thermometer": False, - "task.seh_moo.preference_type": None, - "task.seh_moo.focus_type": "learned-tabular", - "task.seh_moo.focus_cosim": 0.98, - "task.seh_moo.focus_limit_coef": 1e-1, - "task.seh_moo.n_valid": 15, - "task.seh_moo.n_valid_repeats": 128, - "replay.use": True, - "replay.warmup": 1000, - "replay.hindsight_ratio": 0.3, - "task.seh_moo.focus_model_training_limits": [0.25, 0.75], - "task.seh_moo.focus_model_state_space_res": 30, - "task.seh_moo.max_train_it": 20_000, + "algo": { + "global_batch_size": 64, + "method": "TB", + "sampling_tau": 0.95, + "train_random_action_prob": 0.01, + "tb": { + "Z_learning_rate": 1e-3, + "Z_lr_decay": 50000, + }, + }, + "model": { + "num_layers": 2, + "num_emb": 256, + }, + "task": { + "seh_moo": { + "objectives": ["seh", "qed"], + "n_valid": 15, + "n_valid_repeats": 128, + }, + }, + "opt": { + "learning_rate": 1e-4, + "lr_decay": 20000, + }, + "cond": { + "temperature": { + "sample_dist": "constant", + "dist_params": [60.0], + "num_thermometer_dim": 32, + }, + "weighted_prefs": { + "preference_type": "dirichlet", + }, + "focus_region": { + "focus_type": None, # "learned-tabular", + "focus_cosim": 0.98, + "focus_limit_coef": 1e-1, + "focus_model_training_limits": (0.25, 0.75), + "focus_model_state_space_res": 30, + "max_train_it": 5_000, + }, + }, + "replay": { + "use": False, + "warmup": 1000, + "hindsight_ratio": 0.0, + }, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -496,7 +422,7 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHMOOFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + trial = SEHMOOFragTrainer(hps) trial.print_every = 1 trial.run() diff --git a/src/gflownet/train.py b/src/gflownet/trainer.py similarity index 93% rename from src/gflownet/train.py rename to src/gflownet/trainer.py index bb9c6ef3..93e0e0a5 100644 --- a/src/gflownet/train.py +++ b/src/gflownet/trainer.py @@ -2,11 +2,13 @@ import pathlib from typing import Any, Callable, Dict, List, NewType, Optional, Tuple +import numpy as np import torch import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd from omegaconf import OmegaConf +from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -88,7 +90,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class GFNTrainer: - def __init__(self, hps: Dict[str, Any], device: torch.device): + def __init__(self, hps: Dict[str, Any]): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters @@ -119,13 +121,12 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): # The final config is obtained by merging the three sources self.cfg: Config = OmegaConf.structured(Config()) self.set_default_hps(self.cfg) - OmegaConf.merge(self.cfg, hps) + # OmegaConf returns a fancy object but we can still pretend it's a Config instance + self.cfg = OmegaConf.merge(self.cfg, hps) # type: ignore - self.device = device - # idem, but from `self.test_data` during validation. - self.valid_offline_ratio = 1 + self.device = torch.device(self.cfg.device) # Print the loss every `self.print_every` iterations - self.print_every = 1000 + self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data self.sampling_hooks: List[Callable] = [] self.valid_sampling_hooks: List[Callable] = [] @@ -137,12 +138,34 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): def set_default_hps(self, base: Config): raise NotImplementedError() - def setup(self): + def setup_env_context(self): + raise NotImplementedError() + + def setup_task(self): + raise NotImplementedError() + + def setup_model(self): raise NotImplementedError() + def setup_algo(self): + raise NotImplementedError() + + def setup_data(self): + pass + def step(self, loss: Tensor): raise NotImplementedError() + def setup(self): + RDLogger.DisableLog("rdApp.*") + self.rng = np.random.default_rng(142857) + self.env = GraphBuildingEnv() + self.setup_data() + self.setup_task() + self.setup_env_context() + self.setup_algo() + self.setup_model() + def _wrap_for_mp(self, obj, send_to_device=False): """Wraps an object in a placeholder whose reference can be sent to a data worker process (only if the number of workers is non-zero).""" @@ -203,7 +226,7 @@ def build_validation_data_loader(self) -> DataLoader: dev, batch_size=self.cfg.algo.global_batch_size, illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - ratio=self.valid_offline_ratio, + ratio=self.cfg.algo.valid_offline_ratio, log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), sample_cond_info=self.cfg.algo.valid_sample_cond_info, stream=False, diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py new file mode 100644 index 00000000..5893ef3b --- /dev/null +++ b/src/gflownet/utils/conditioning.py @@ -0,0 +1,246 @@ +import abc +from copy import deepcopy +from typing import Dict + +import numpy as np +import torch +from scipy import stats +from torch import Tensor +from torch.distributions.dirichlet import Dirichlet +from torch_geometric import data as gd + +from gflownet.config import Config +from gflownet.utils import metrics +from gflownet.utils.focus_model import TabularFocusModel +from gflownet.utils.transforms import thermometer + + +class Conditional(abc.ABC): + def sample(self, n): + raise NotImplementedError() + + @abc.abstractmethod + def transform(self, cond_info: Dict[str, Tensor], properties: Tensor) -> Tensor: + raise NotImplementedError() + + def encoding_size(self): + raise NotImplementedError() + + def encode(self, conditional: Tensor) -> Tensor: + raise NotImplementedError() + + +class TemperatureConditional(Conditional): + def __init__(self, cfg: Config, rng: np.random.Generator): + self.cfg = cfg + tmp_cfg = self.cfg.cond.temperature + self.rng = rng + self.upper_bound = 1024 + if tmp_cfg.sample_dist == "gamma": + loc, scale = tmp_cfg.dist_params + self.upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) + elif tmp_cfg.sample_dist == "uniform": + self.upper_bound = tmp_cfg.dist_params[1] + elif tmp_cfg.sample_dist == "loguniform": + self.upper_bound = tmp_cfg.dist_params[1] + elif tmp_cfg.sample_dist == "beta": + self.upper_bound = 1 + + def encoding_size(self): + return self.cfg.cond.temperature.num_thermometer_dim + + def sample(self, n): + cfg = self.cfg.cond.temperature + beta = None + if cfg.sample_dist == "constant": + assert type(cfg.dist_params[0]) is float + beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32) + beta_enc = torch.zeros((n, cfg.num_thermometer_dim)) + else: + if cfg.sample_dist == "gamma": + loc, scale = cfg.dist_params + beta = self.rng.gamma(loc, scale, n).astype(np.float32) + elif cfg.sample_dist == "uniform": + a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) + beta = self.rng.uniform(a, b, n).astype(np.float32) + elif cfg.sample_dist == "loguniform": + low, high = np.log(cfg.dist_params) + beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) + elif cfg.sample_dist == "beta": + a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) + beta = self.rng.beta(a, b, n).astype(np.float32) + beta_enc = thermometer(torch.tensor(beta), cfg.num_thermometer_dim, 0, self.upper_bound) + + assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" + return {"beta": torch.tensor(beta), "encoding": beta_enc} + + def transform(self, cond_info: Dict[str, Tensor], linear_reward: Tensor) -> Tensor: + scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log() + assert len(scalar_logreward.shape) == len( + cond_info["beta"].shape + ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" + return scalar_logreward * cond_info["beta"] + + def encode(self, conditional: Tensor) -> Tensor: + cfg = self.cfg.cond.temperature + if cfg.sample_dist == "constant": + return torch.zeros((conditional.shape[0], cfg.num_thermometer_dim)) + return thermometer(torch.tensor(conditional), cfg.num_thermometer_dim, 0, self.upper_bound) + + +class MultiObjectiveWeightedPreferences(Conditional): + def __init__(self, cfg: Config): + self.cfg = cfg.cond.weighted_prefs + self.num_objectives = cfg.cond.moo.num_objectives + self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim + if self.cfg.preference_type == "seeded": + self.seeded_prefs = np.random.default_rng(142857 + int(cfg.seed)).dirichlet([1] * self.num_objectives) + + def sample(self, n): + if self.cfg.preference_type is None: + preferences = torch.ones((n, self.num_objectives)) + elif self.cfg.preference_type == "seeded": + preferences = torch.tensor(self.seeded_prefs).float().repeat(n, 1) + elif self.cfg.preference_type == "dirichlet_exponential": + a = np.random.dirichlet([1] * self.num_objectives, n) + b = np.random.exponential(1, n)[:, None] + preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() + elif self.cfg.preference_type == "dirichlet": + m = Dirichlet(torch.FloatTensor([1.0] * self.num_objectives)) + preferences = m.sample([n]) + else: + raise ValueError(f"Unknown preference type {self.cfg.preference_type}") + preferences = torch.as_tensor(preferences).float() + return {"preferences": preferences, "encoding": self.encode(preferences)} + + def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor: + scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() + assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}" + return scalar_logreward + + def encoding_size(self): + return max(1, self.num_thermometer_dim * self.num_objectives) + + def encode(self, conditional: Tensor) -> Tensor: + if self.num_thermometer_dim > 0: + return thermometer(conditional, self.num_thermometer_dim, 0, 1).reshape(conditional.shape[0], -1) + else: + return conditional.unsqueeze(1) + + +class FocusRegionConditional(Conditional): + def __init__(self, cfg: Config, n_valid: int, rng: np.random.Generator): + self.cfg = cfg.cond.focus_region + self.n_valid = n_valid + self.n_objectives = cfg.cond.moo.num_objectives + self.ocfg = cfg + self.rng = rng + self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim if self.cfg.use_steer_thermomether else 0 + + focus_type = self.cfg.focus_type + if focus_type is not None and "learned" in focus_type: + if focus_type == "learned-tabular": + self.focus_model = TabularFocusModel( + # TODO: proper device propagation + device=torch.device("cpu"), + n_objectives=cfg.cond.moo.num_objectives, + state_space_res=self.cfg.focus_model_state_space_res, + ) + else: + raise NotImplementedError("Unknown focus model type {self.focus_type}") + else: + self.focus_model = None + self.setup_focus_regions() + + def encoding_size(self): + if self.num_thermometer_dim > 0: + return self.num_thermometer_dim * self.n_objectives + return self.n_objectives + + def setup_focus_regions(self): + # focus regions + if self.cfg.focus_type is None: + valid_focus_dirs = np.zeros((self.n_valid, self.n_objectives)) + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type == "centered": + valid_focus_dirs = np.ones((self.n_valid, self.n_objectives)) + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type == "partitioned": + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l2") + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type in ["dirichlet", "learned-gfn"]: + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l1") + self.fixed_focus_dirs = None + elif self.cfg.focus_type in ["hyperspherical", "learned-tabular"]: + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l2") + self.fixed_focus_dirs = None + elif type(self.cfg.focus_type) is list: + if len(self.cfg.focus_type) == 1: + valid_focus_dirs = np.array([self.cfg.focus_type[0]] * self.n_valid) + self.fixed_focus_dirs = valid_focus_dirs + else: + valid_focus_dirs = np.array(self.cfg.focus_type) + self.fixed_focus_dirs = valid_focus_dirs + else: + raise NotImplementedError( + f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " + f"focus_type, but here: {self.cfg.focus_type}" + ) + self.valid_focus_dirs = valid_focus_dirs + + def sample(self, n: int, train_it: int = None): + train_it = train_it or 0 + if self.fixed_focus_dirs is not None: + focus_dir = torch.tensor( + np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) + ) + elif self.cfg.focus_type == "dirichlet": + m = Dirichlet(torch.FloatTensor([1.0] * self.n_objectives)) + focus_dir = m.sample([n]) + elif self.cfg.focus_type == "hyperspherical": + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") + ).float() + elif self.cfg.focus_type is not None and "learned" in self.cfg.focus_type: + if ( + self.focus_model is not None + and train_it >= self.cfg.focus_model_training_limits[0] * self.cfg.max_train_it + ): + focus_dir = self.focus_model.sample_focus_directions(n) + else: + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") + ).float() + else: + raise NotImplementedError(f"Unsupported focus_type={type(self.cfg.focus_type)}") + + return {"focus_dir": focus_dir, "encoding": self.encode(focus_dir)} + + def encode(self, conditional: Tensor) -> Tensor: + return ( + thermometer(conditional, self.ocfg.cond.moo.num_thermometer_dim, 0, 1).reshape(conditional.shape[0], -1) + if self.cfg.use_steer_thermomether + else conditional + ) + + def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_logreward: Tensor = None) -> Tensor: + focus_coef, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef + ) + if scalar_logreward is None: + scalar_logreward = torch.log(focus_coef) + else: + scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) + scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward + + return scalar_logreward + + def step_focus_model(self, batch: gd.Batch, train_it: int): + focus_model_training_limits = self.cfg.focus_model_training_limits + max_train_it = self.ocfg.num_training_steps + if ( + self.focus_model is not None + and train_it >= focus_model_training_limits[0] * max_train_it + and train_it <= focus_model_training_limits[1] * max_train_it + ): + self.focus_model.update_belief(deepcopy(batch.focus_dir), deepcopy(batch.flat_rewards)) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py new file mode 100644 index 00000000..db3d3905 --- /dev/null +++ b/src/gflownet/utils/config.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass, field +from typing import Any, List, Optional + + +@dataclass +class TempCondConfig: + """Config for the temperature conditional. + + Attributes + ---------- + + sample_dist : str + The distribution to sample the inverse temperature from. Can be one of: + - "uniform": uniform distribution + - "loguniform": log-uniform distribution + - "gamma": gamma distribution + - "constant": constant temperature + - "beta": beta distribution + dist_params : List[Any] + The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. + num_thermometer_dim : int + The number of thermometer encoding dimensions to use. + """ + + sample_dist: str = "uniform" + dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) + num_thermometer_dim: int = 32 + + +@dataclass +class MultiObjectiveConfig: + num_objectives: int = 2 + num_thermometer_dim: int = 16 + + +@dataclass +class WeightedPreferencesConfig: + """Config for the weighted preferences conditional. + + Attributes + ---------- + preference_type : str + The preference sampling distribution, defaults to "dirichlet". Can be one of: + - "dirichlet": Dirichlet distribution + - "dirichlet_exponential": Dirichlet distribution with exponential temperature + - "seeded": Enumerated preferences + - None: All rewards equally weighted""" + + preference_type: Optional[str] = "dirichlet" + + +@dataclass +class FocusRegionConfig: + """Config for the focus region conditional. + + Attributes + ---------- + focus_type : str + The type of focus distribtuion used, see FocusRegionConditon.setup_focus_regions. Can be one of: + [None, "centered", "partitioned", "dirichlet", "hyperspherical", "learned-gfn", "learned-tabular"] + """ + + focus_type: Optional[str] = "learned-tabular" + use_steer_thermomether: bool = False + focus_cosim: float = 0.98 + focus_limit_coef: float = 0.1 + focus_model_training_limits: tuple[float, float] = (0.25, 0.75) + focus_model_state_space_res: int = 30 + max_train_it: int = 20_000 + + +@dataclass +class ConditionalsConfig: + temperature: TempCondConfig = TempCondConfig() + moo: MultiObjectiveConfig = MultiObjectiveConfig() + weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig() + focus_region: FocusRegionConfig = FocusRegionConfig() diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index 3b118b44..cc37c127 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -33,7 +33,7 @@ def compute_focus_coef( assert ( focus_limit_coef > 0.0 and focus_limit_coef <= 1.0 ), f"focus_limit_coef must be in (0, 1], now {focus_limit_coef}" - focus_gamma_param = np.log(focus_limit_coef) / np.log(focus_cosim) + focus_gamma_param = torch.tensor(np.log(focus_limit_coef) / np.log(focus_cosim)).float() cosim = nn.functional.cosine_similarity(flat_rewards, focus_dirs, dim=1) in_focus_mask = cosim >= focus_cosim focus_coef = torch.where(in_focus_mask, cosim**focus_gamma_param, 0.0) diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index d1ac0a2a..4862c6c7 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -120,7 +120,8 @@ def _run_pareto_accumulation(self): def __call__(self, trajs, rewards, flat_rewards, cond_info): # locally (in-process) accumulate flat rewards to build a better pareto estimate self.all_flat_rewards = self.all_flat_rewards + list(flat_rewards) - self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) + if self.compute_focus_accuracy: + self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) self.all_smi = self.all_smi + list([i.get("smi", None) for i in trajs]) if len(self.all_flat_rewards) > self.num_to_keep: self.all_flat_rewards = self.all_flat_rewards[-self.num_to_keep :] @@ -128,7 +129,8 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): self.all_smi = self.all_smi[-self.num_to_keep :] flat_rewards = torch.stack(self.all_flat_rewards).numpy() - focus_dirs = torch.stack(self.all_focus_dirs).numpy() + if self.compute_focus_accuracy: + focus_dirs = torch.stack(self.all_focus_dirs).numpy() # collects empirical pareto front from in-process samples pareto_idces = metrics.is_pareto_efficient(-flat_rewards, return_mask=False) From ec857a5e2389eca511e0fc4264b1a701024db5dd Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 4 Aug 2023 12:41:04 -0400 Subject: [PATCH 9/9] Add support for backwards actions in MolBuildingEnvContext (#100) This PR: - adds support for backward actions in `MolBuildingEnvContext`, - adds tests for `MolBuildingEnvContext` that check that backwards masks are correct, - adds a toy atom environment where the reward is simply the number of rings in a carbon-only molecule of up to 6 atoms. - fixes a bug in `GraphBuildingEnv.parent`, whereby 1-node graphs with attributes would have their parents misenumerated (which hadn't manifested itself because until now we didn't have contexts with more than 1 node attribute). * add backward mask support + small ring task * more comments and implementation notes --- docs/implementation_notes.md | 18 ++++ src/gflownet/envs/graph_building_env.py | 4 +- src/gflownet/envs/mol_building_env.py | 117 +++++++++++++++++++---- src/gflownet/tasks/make_rings.py | 90 +++++++++++++++++ tests/{test_frag_env.py => test_envs.py} | 51 ++++++++-- 5 files changed, 254 insertions(+), 26 deletions(-) create mode 100644 src/gflownet/tasks/make_rings.py rename tests/{test_frag_env.py => test_envs.py} (76%) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 598e570c..6930728d 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -16,3 +16,21 @@ We separate experiment concerns in four categories: - The Trainer class is responsible for instanciating everything, and running the training & testing loop Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`. + + +## Graphs + +This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations. + +Some notes: +- graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs. +- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding. + + +### Graph policies & graph action categoricals + +The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits accross a minibatch. + +Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. + +The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. \ No newline at end of file diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index a3df87bf..fa7b284b 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -271,9 +271,9 @@ def add_parent(a, new_g): GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]), new_g, ) - if len(g.nodes) == 1: + if len(g.nodes) == 1 and len(g.nodes[i]) == 1: # The final node is degree 0, need this special case to remove it - # and end up with S0, the empty graph root + # and end up with S0, the empty graph root (but only if it has no attrs except 'v') add_parent( GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]), graph_without_node(g, i), diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 925e3456..13535bbf 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -8,7 +8,13 @@ from rdkit.Chem import Mol from rdkit.Chem.rdchem import BondType, ChiralType -from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionType, + GraphBuildingEnvContext, + graph_without_edge, +) from gflownet.utils.graphs import random_walk_probs DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW] @@ -77,19 +83,22 @@ def __init__( # The size of the input vector for each atom self.atom_attr_size = sum(len(i) for i in self.atom_attr_values.values()) self.atom_attrs = sorted(self.atom_attr_values.keys()) + # 'v' is set separately when creating the node, so there's no point in having a SetNodeAttr logit for it + self.settable_atom_attrs = [i for i in self.atom_attrs if i != "v"] # The beginning position within the input vector of each attribute self.atom_attr_slice = [0] + list(np.cumsum([len(self.atom_attr_values[i]) for i in self.atom_attrs])) # The beginning position within the logit vector of each attribute - num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.atom_attrs] + num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.settable_atom_attrs] self.atom_attr_logit_slice = { k: (s, e) - for k, s, e in zip(self.atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits)) + for k, s, e in zip( + self.settable_atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits) + ) } # The attribute and value each logit dimension maps back to self.atom_attr_logit_map = [ (k, v) - for k in self.atom_attrs - if k != "v" + for k in self.settable_atom_attrs # index 0 is skipped because it is the default value for v in self.atom_attr_values[k][1:] ] @@ -147,12 +156,21 @@ def __init__( GraphActionType.AddEdge, GraphActionType.SetEdgeAttr, ] + self.bck_action_type_order = [ + GraphActionType.RemoveNode, + GraphActionType.RemoveNodeAttr, + GraphActionType.RemoveEdge, + GraphActionType.RemoveEdgeAttr, + ] self.device = torch.device("cpu") def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" act_type, act_row, act_col = [int(i) for i in action_idx] - t = self.action_type_order[act_type] + if fwd: + t = self.action_type_order[act_type] + else: + t = self.bck_action_type_order[act_type] if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: @@ -164,12 +182,34 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: a, b = g.non_edge_index[:, act_row] return GraphAction(t, source=a.item(), target=b.item()) elif t is GraphActionType.SetEdgeAttr: - a, b = g.edge_index[:, act_row * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits + # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. + # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one + # to another we can safely divide or multiply by two. + a, b = g.edge_index[:, act_row * 2] attr, val = self.bond_attr_logit_map[act_col] return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val) + elif t is GraphActionType.RemoveNode: + return GraphAction(t, source=act_row) + elif t is GraphActionType.RemoveNodeAttr: + attr = self.settable_atom_attrs[act_col] + return GraphAction(t, source=act_row, attr=attr) + elif t is GraphActionType.RemoveEdge: + a, b = g.edge_index[:, act_row * 2] # see note above about edge_index + return GraphAction(t, source=a.item(), target=b.item()) + elif t is GraphActionType.RemoveEdgeAttr: + a, b = g.edge_index[:, act_row * 2] # see note above about edge_index + attr = self.bond_attrs[act_col] + return GraphAction(t, source=a.item(), target=b.item(), attr=attr) def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: """Translate a GraphAction to an index tuple""" + for u in [self.action_type_order, self.bck_action_type_order]: + if action.action in u: + type_idx = u.index(action.action) + break + else: + raise ValueError(f"Unknown action type {action.action}") + if action.action is GraphActionType.Stop: row = col = 0 elif action.action is GraphActionType.AddNode: @@ -191,17 +231,33 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int ).argmax() col = 0 elif action.action is GraphActionType.SetEdgeAttr: - # Here the edges are duplicated, both (i,j) and (j,i) are in edge_index - # so no need for a double check. - # row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) + - # (g.edge_index.T == torch.tensor([(action.target, action.source)])).prod(1)).argmax() + # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. + # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one + # to another we can safely divide or multiply by two. row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax() - # Because edges are duplicated but logits aren't, divide by two row = row.div(2, rounding_mode="floor") # type: ignore col = ( self.bond_attr_values[action.attr].index(action.value) - 1 + self.bond_attr_logit_slice[action.attr][0] ) - type_idx = self.action_type_order.index(action.action) + elif action.action is GraphActionType.RemoveNode: + row = action.source + col = 0 + elif action.action is GraphActionType.RemoveNodeAttr: + row = action.source + col = self.settable_atom_attrs.index(action.attr) + elif action.action is GraphActionType.RemoveEdge: + row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax() + # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. + # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one + # to another we can safely divide or multiply by two. + row = int(row) // 2 + col = 0 + elif action.action is GraphActionType.RemoveEdgeAttr: + row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax() + row = row.div(2, rounding_mode="floor") # type: ignore + col = self.bond_attrs.index(action.attr) + else: + raise ValueError(f"Unknown action type {action.action}") return (type_idx, int(row), int(col)) def graph_to_Data(self, g: Graph) -> gd.Data: @@ -211,6 +267,9 @@ def graph_to_Data(self, g: Graph) -> gd.Data: add_node_mask = torch.ones((x.shape[0], self.num_new_node_values)) if self.max_nodes is not None and len(g.nodes) >= self.max_nodes: add_node_mask *= 0 + remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0) + remove_node_attr_mask = torch.zeros((x.shape[0], len(self.settable_atom_attrs))) + explicit_valence = {} max_valence = {} set_node_attr_mask = torch.ones((x.shape[0], self.num_node_attr_logits)) @@ -218,18 +277,33 @@ def graph_to_Data(self, g: Graph) -> gd.Data: set_node_attr_mask *= 0 for i, n in enumerate(g.nodes): ad = g.nodes[n] + if g.degree(n) <= 1 and len(ad) == 1 and all([len(g[n][neigh]) == 0 for neigh in g.neighbors(n)]): + # If there's only the 'v' key left and the node is a leaf, and the edge that connect to the node have + # no attributes set, we can remove it + remove_node_mask[i] = 1 for k, sl in zip(self.atom_attrs, self.atom_attr_slice): + # idx > 0 means that the attribute is not the default value idx = self.atom_attr_values[k].index(ad[k]) if k in ad else 0 x[i, sl + idx] = 1 - # If the attribute is already there, mask out logits - # (or if the attribute is a negative attribute and has been filled) + if k == "v": + continue + # If the attribute + # - is already there (idx > 0), + # - or the attribute is a negative attribute and has been filled + # - or the attribute is a negative attribute and is not fillable (i.e. not a key of ad) + # then mask forward logits. + # For backward logits, positively mask if the attribute is there (idx > 0). if k in self.negative_attrs: if k in ad and idx > 0 or k not in ad: s, e = self.atom_attr_logit_slice[k] set_node_attr_mask[i, s:e] = 0 + # We don't want to make the attribute removable if it's not fillable (i.e. not a key of ad) + if k in ad: + remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1 elif k in ad: s, e = self.atom_attr_logit_slice[k] set_node_attr_mask[i, s:e] = 0 + remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1 # Account for charge and explicit Hs in atom as limiting the total valence max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]] # Special rule for Nitrogen @@ -256,8 +330,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data: s, e = self.atom_attr_logit_slice["expl_H"] set_node_attr_mask[i, s:e] = 0 + remove_edge_mask = torch.zeros((len(g.edges), 1)) + for i, (u, v) in enumerate(g.edges): + if g.degree(u) > 1 and g.degree(v) > 1: + if nx.algorithms.is_connected(graph_without_edge(g, (u, v))): + remove_edge_mask[i] = 1 edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim)) set_edge_attr_mask = torch.zeros((len(g.edges), self.num_edge_attr_logits)) + remove_edge_attr_mask = torch.zeros((len(g.edges), len(self.bond_attrs))) for i, e in enumerate(g.edges): ad = g.edges[e] for k, sl in zip(self.bond_attrs, self.bond_attr_slice): @@ -267,6 +347,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data: if k in ad: # If the attribute is already there, mask out logits s, e = self.bond_attr_logit_slice[k] set_edge_attr_mask[i, s:e] = 0 + remove_edge_attr_mask[i, self.bond_attrs.index(k)] = 1 # Check which bonds don't bust the valence of their atoms if "type" not in ad: # Only if type isn't already set sl, _ = self.bond_attr_logit_slice["type"] @@ -293,11 +374,15 @@ def is_ok_non_edge(e): edge_index, edge_attr, non_edge_index=non_edge_index, - stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1), + stop_mask=torch.ones((1, 1)) * (len(g.nodes) > 0), # Can only stop if there's at least a node add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge set_edge_attr_mask=set_edge_attr_mask, + remove_node_mask=remove_node_mask, + remove_node_attr_mask=remove_node_attr_mask, + remove_edge_mask=remove_edge_mask, + remove_edge_attr_mask=remove_edge_attr_mask, ) if self.num_rw_feat > 0: data.x = torch.cat([data.x, random_walk_probs(data, self.num_rw_feat, skip_odd=True)], 1) diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py new file mode 100644 index 00000000..c3e8d0f9 --- /dev/null +++ b/src/gflownet/tasks/make_rings.py @@ -0,0 +1,90 @@ +import os +import socket +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from rdkit import Chem +from rdkit.Chem.rdchem import Mol as RDMol +from torch import Tensor + +from gflownet.config import Config +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar + + +class MakeRingsTask(GFNTask): + """A toy task where the reward is the number of rings in the molecule.""" + + def __init__( + self, + rng: np.random.Generator, + ): + self.rng = rng + + def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: + return FlatRewards(y) + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)} + + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log() + return RewardScalar(scalar_logreward.flatten()) + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float() + return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool() + + +class MakeRingsTrainer(StandardOnlineTrainer): + def set_default_hps(self, cfg: Config): + cfg.hostname = socket.gethostname() + cfg.num_workers = 8 + cfg.algo.global_batch_size = 64 + cfg.algo.offline_ratio = 0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + + cfg.algo.method = "TB" + cfg.algo.max_nodes = 6 + cfg.algo.sampling_tau = 0.9 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.train_random_action_prob = 0.0 + cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.tb.do_parameterize_p_b = True + + cfg.replay.use = False + + def setup_task(self): + self.task = MakeRingsTask(rng=self.rng) + + def setup_env_context(self): + self.ctx = MolBuildingEnvContext( + ["C"], + charges=[0], # disable charge + chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED], # disable chirality + num_rw_feat=0, + max_nodes=self.cfg.algo.max_nodes, + num_cond_dim=1, + ) + + +def main(): + hps = { + "log_dir": "./logs/debug_run_mr4", + "device": "cuda", + "num_training_steps": 10_000, + "num_workers": 8, + "algo": {"tb": {"do_parameterize_p_b": True}}, + } + os.makedirs(hps["log_dir"], exist_ok=True) + + trial = MakeRingsTrainer(hps) + trial.print_every = 1 + trial.run() + + +if __name__ == "__main__": + main() diff --git a/tests/test_frag_env.py b/tests/test_envs.py similarity index 76% rename from tests/test_frag_env.py rename to tests/test_envs.py index deda34e2..204a17cb 100644 --- a/tests/test_frag_env.py +++ b/tests/test_envs.py @@ -9,9 +9,11 @@ from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.models import bengio2021flow -def build_two_node_states(): +def build_two_node_states(ctx): # TODO: This is actually fairly generic code that will probably be reused by other tests in the future. # Having a proper class to handle graph-indexed hash maps would probably be good. graph_cache = {} @@ -21,7 +23,6 @@ def build_two_node_states(): # We're enumerating all states of length two, but we could've just as well randomly sampled # some states. env = GraphBuildingEnv() - ctx = FragMolBuildingEnvContext(max_frags=2) def g2h(g): gc = g.to_directed() @@ -73,11 +74,19 @@ def expand(s, idx): return [graph_by_idx[i] for i in list(nx.topological_sort(mdp_graph))] +def get_frag_env_ctx(): + return FragMolBuildingEnvContext(max_frags=2, fragments=bengio2021flow.FRAGMENTS[:20]) + + +def get_atom_env_ctx(): + return MolBuildingEnvContext(atoms=["C", "N"], expl_H_range=[0], charges=[0], max_nodes=2) + + @pytest.fixture -def two_node_states(request): +def two_node_states_frags(request): data = request.config.cache.get("frag_env/two_node_states", None) if data is None: - data = build_two_node_states() + data = build_two_node_states(get_frag_env_ctx()) # pytest caches through JSON so we have to make a clean enough string request.config.cache.set("frag_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode()) else: @@ -85,13 +94,24 @@ def two_node_states(request): return data -def test_backwards_mask_equivalence(two_node_states): +@pytest.fixture +def two_node_states_atoms(request): + data = request.config.cache.get("atom_env/two_node_states", None) + if data is None: + data = build_two_node_states(get_atom_env_ctx()) + # pytest caches through JSON so we have to make a clean enough string + request.config.cache.set("atom_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode()) + else: + data = pickle.loads(base64.b64decode(data)) + return data + + +def _test_backwards_mask_equivalence(two_node_states, ctx): """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. """ env = GraphBuildingEnv() - ctx = FragMolBuildingEnvContext(max_frags=2) for i in range(1, len(two_node_states)): g = two_node_states[i] n = env.count_backward_transitions(g, check_idempotent=False) @@ -104,7 +124,7 @@ def test_backwards_mask_equivalence(two_node_states): raise ValueError() -def test_backwards_mask_equivalence_ipa(two_node_states): +def _test_backwards_mask_equivalence_ipa(two_node_states, ctx): """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. @@ -112,7 +132,6 @@ def test_backwards_mask_equivalence_ipa(two_node_states): This test also accounts for idempotent actions. """ env = GraphBuildingEnv() - ctx = FragMolBuildingEnvContext(max_frags=2) cfg = OmegaConf.structured(Config) cfg.algo.max_nodes = 2 algo = TrajectoryBalance(env, ctx, None, cfg) @@ -141,3 +160,19 @@ def test_backwards_mask_equivalence_ipa(two_node_states): equivalence_classes.append(ipa) if n != len(equivalence_classes): raise ValueError() + + +def test_backwards_mask_equivalence_frag(two_node_states_frags): + _test_backwards_mask_equivalence(two_node_states_frags, get_frag_env_ctx()) + + +def test_backwards_mask_equivalence_ipa_frag(two_node_states_frags): + _test_backwards_mask_equivalence_ipa(two_node_states_frags, get_frag_env_ctx()) + + +def test_backwards_mask_equivalence_atom(two_node_states_atoms): + _test_backwards_mask_equivalence(two_node_states_atoms, get_atom_env_ctx()) + + +def test_backwards_mask_equivalence_ipa_atom(two_node_states_atoms): + _test_backwards_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx())