This PR replaces the "giant-dict-as-config" with a structured & typed config, which makes each module responsible for defining & documenting its own configuration.

Squashed commits:

* little test
* new config structure - in progress
* trying config by names
* further refactor progress
* better pyi sort + fix moo example
* tox
* import fixes
* add SQL + fix n_valid
* tox test
* fix mypy hook and convert qm9 to new cfg
* better config generation
* use generated config.py
* use generated config.py
* fix rng call types
* fix test + tox
* better config doc
* fix deps
* tox
* re-fix deps
* minor fixes for seh_frag_moo
* tox
* switch to OmegaConf
* switch to omegaconf
* fix pre-commit-config
* add omegaconf dep
* fix list defaults to fields
* remove comment
* version change and revert unnecessary torch.tensor call
* addressing PR comments
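Since the squashed commits mention a switch to OmegaConf, here is a minimal sketch of how a structured, typed config like the AlgoConfig added further down in this diff is typically built and overridden. The module path gflownet.algo.config and the override keys are assumptions for illustration, not taken from this diff.

from omegaconf import OmegaConf

from gflownet.algo.config import AlgoConfig  # assumed module path, not shown in this diff

# Build a typed config tree from the dataclass defaults.
cfg = OmegaConf.structured(AlgoConfig)

# Overrides are validated against the declared fields and types,
# unlike the previous "giant dict", where any key was accepted silently.
overrides = OmegaConf.from_dotlist(["method=TB", "tb.do_subtb=true"])
cfg = OmegaConf.merge(cfg, overrides)

assert cfg.tb.do_subtb is True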
Showing 24 changed files with 861 additions and 542 deletions.
@@ -1,2 +1,2 @@
 MAJOR="0"
-MINOR="0"
+MINOR="1"
@@ -0,0 +1,119 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class TBConfig:
    """Trajectory Balance config.

    Attributes
    ----------
    bootstrap_own_reward : bool
        Whether to bootstrap the reward with the model's own reward prediction. (deprecated)
    epsilon : Optional[float]
        The epsilon parameter in log-flow smoothing (see paper)
    reward_loss_multiplier : float
        The multiplier for the reward loss when bootstrapping the reward. (deprecated)
    do_subtb : bool
        Whether to use the full N^2 subTB loss
    do_correct_idempotent : bool
        Whether to correct for idempotent actions
    do_parameterize_p_b : bool
        Whether to parameterize the P_B distribution (otherwise it is uniform)
    subtb_max_len : int
        The maximum trajectory length, used to cache subTB computation indices
    Z_learning_rate : float
        The learning rate for the logZ parameter (only relevant when do_subtb is False)
    Z_lr_decay : float
        The learning rate decay for the logZ parameter (only relevant when do_subtb is False)
    """

    bootstrap_own_reward: bool = False
    epsilon: Optional[float] = None
    reward_loss_multiplier: float = 1.0
    do_subtb: bool = False
    do_correct_idempotent: bool = False
    do_parameterize_p_b: bool = False
    subtb_max_len: int = 128
    Z_learning_rate: float = 1e-4
    Z_lr_decay: float = 50_000


@dataclass
class MOQLConfig:
    """Multi-Objective Q-Learning (MOQL) config."""

    gamma: float = 1
    num_omega_samples: int = 32
    num_objectives: int = 2
    lambda_decay: int = 10_000
    penalty: float = -10


@dataclass
class A2CConfig:
    """Advantage Actor-Critic (A2C) config."""

    entropy: float = 0.01
    gamma: float = 1
    penalty: float = -10


@dataclass
class FMConfig:
    """Flow Matching (FM) config."""

    epsilon: float = 1e-38
    balanced_loss: bool = False
    leaf_coef: float = 10
    correct_idempotent: bool = False


@dataclass
class SQLConfig:
    """Soft Q-Learning (SQL) config."""

    alpha: float = 0.01
    gamma: float = 1
    penalty: float = -10


@dataclass
class AlgoConfig:
    """Generic configuration for algorithms.

    Attributes
    ----------
    method : str
        The name of the algorithm to use (e.g. "TB")
    global_batch_size : int
        The batch size for training
    max_len : int
        The maximum length of a trajectory
    max_nodes : int
        The maximum number of nodes in a generated graph
    max_edges : int
        The maximum number of edges in a generated graph
    illegal_action_logreward : float
        The log reward an agent receives for illegal actions
    offline_ratio : float
        The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
        `self.sampling_model`
    train_random_action_prob : float
        The probability of taking a random action during training
    valid_random_action_prob : float
        The probability of taking a random action during validation
    valid_sample_cond_info : bool
        Whether to sample conditioning information during validation (if False, expects a validation set of cond_info)
    sampling_tau : float
        The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta)
    """

    method: str = "TB"
    global_batch_size: int = 64
    max_len: int = 128
    max_nodes: int = 128
    max_edges: int = 128
    illegal_action_logreward: float = -100
    offline_ratio: float = 0.5
    train_random_action_prob: float = 0.0
    valid_random_action_prob: float = 0.0
    valid_sample_cond_info: bool = True
    sampling_tau: float = 0.0
    tb: TBConfig = TBConfig()
    moql: MOQLConfig = MOQLConfig()
    a2c: A2CConfig = A2CConfig()
    fm: FMConfig = FMConfig()
    sql: SQLConfig = SQLConfig()
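For a quick sense of how these dataclasses compose, here is a small usage sketch built directly on the classes defined above; the construction and the way the fields are read at the end are illustrative, not code from this PR.

# Nested, typed construction: each algorithm keeps its options in its own sub-config.
cfg = AlgoConfig(
    method="TB",
    global_batch_size=128,
    tb=TBConfig(do_subtb=True, subtb_max_len=200),
)

# Unknown keys or wrong types fail immediately, e.g.
# AlgoConfig(globl_batch_size=128) raises a TypeError, unlike a silent dict lookup.

# An algorithm implementation would then read only its own section:
if cfg.method == "TB" and cfg.tb.do_subtb:
    print(f"Using subTB, caching indices for trajectories up to length {cfg.tb.subtb_max_len}")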