Feature: better configs (#93)
This PR replaces the "giant-dict-as-config" with a structured & typed config, which makes each module responsible for defining & documenting its own configuration.
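
As a rough before/after sketch of the pattern (the Demo* dataclasses below are illustrative stand-ins, not the actual gflownet classes; the real ones added by this PR are in files such as src/gflownet/algo/config.py, shown further down):

from dataclasses import dataclass, field

# Illustrative stand-ins only: each module declares and documents its own
# config section as a dataclass, and the sections are nested into one tree.
@dataclass
class DemoA2CConfig:
    entropy: float = 0.01
    gamma: float = 1.0

@dataclass
class DemoAlgoConfig:
    illegal_action_logreward: float = -100.0
    a2c: DemoA2CConfig = field(default_factory=DemoA2CConfig)

@dataclass
class DemoConfig:
    algo: DemoAlgoConfig = field(default_factory=DemoAlgoConfig)

cfg = DemoConfig()

# Before: modules pulled untyped values out of a shared dict, e.g.
#   hps["illegal_action_logreward"], hps.get("a2c_entropy", 0.01)
# After: typed attribute access against a documented schema:
print(cfg.algo.illegal_action_logreward)  # -100.0
print(cfg.algo.a2c.entropy)               # 0.01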

* little test

* new config structure - in progress

* trying config by names

* further refactor progress

* better pyi sort + fix moo example

* tox

* import fixes

* add SQL + fix n_valid

* tox test

* fix mypy hook and convert qm9 to new cfg

* better config generation

* use generated config.py

* use generated config.py

* fix rng call types

* fix test + tox

* better config doc

* fix deps

* tox

* re-fix deps

* minor fixes for seh_frag_moo

* tox

* switch to OmegaConf

* switch to omegaconf

* fix pre-commit-config

* add omegaconf dep

* fix list defaults to fields

* remove comment

* version change and revert unnecessary torch.tensor call

* addressing PR comments
bengioe authored Jul 18, 2023
1 parent 6807bb7 commit ffabcfd
Showing 24 changed files with 861 additions and 542 deletions.
VERSION (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
 MAJOR="0"
-MINOR="0"
+MINOR="1"
pyproject.toml (1 change: 1 addition & 0 deletions)
@@ -75,6 +75,7 @@ dependencies = [
     "botorch",
     "pyro-ppl",
     "gpytorch",
+    "omegaconf>=2.3",
 ]
 
 [project.optional-dependencies]
requirements/dev_3.9.txt (2 changes: 2 additions & 0 deletions)
@@ -117,6 +117,8 @@ numpy==1.24.2
     # torch-geometric
 oauthlib==3.2.2
     # via requests-oauthlib
+omegaconf==2.3.0
+    # via gflownet
 opt-einsum==3.3.0
     # via pyro-ppl
 packaging==23.1
requirements/main_3.9.txt (2 changes: 2 additions & 0 deletions)
@@ -74,6 +74,8 @@ numpy==1.24.1
     # torch-geometric
 oauthlib==3.2.2
     # via requests-oauthlib
+omegaconf==2.3.0
+    # via gflownet
 opt-einsum==3.3.0
     # via pyro-ppl
 packaging==23.0
src/gflownet/algo/advantage_actor_critic.py (29 changes: 11 additions & 18 deletions)
@@ -1,11 +1,10 @@
-from typing import Any, Dict
-
 import numpy as np
 import torch
 import torch.nn as nn
 import torch_geometric.data as gd
 from torch import Tensor
 
+from gflownet.config import Config
 from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory
 
 from .graph_sampling import GraphSampler
@@ -17,9 +16,7 @@ def __init__(
         env: GraphBuildingEnv,
         ctx: GraphBuildingEnvContext,
         rng: np.random.RandomState,
-        hps: Dict[str, Any],
-        max_len=None,
-        max_nodes=None,
+        cfg: Config,
     ):
         """Advantage Actor-Critic implementation, see
         Asynchronous Methods for Deep Reinforcement Learning,
@@ -38,29 +35,25 @@ def __init__(
             A context.
         rng: np.random.RandomState
             rng used to take random actions
-        hps: Dict[str, Any]
-            Hyperparameter dictionary, see above for used keys.
-        max_len: int
-            If not None, ends trajectories of more than max_len steps.
-        max_nodes: int
-            If not None, ends trajectories of graphs with more than max_nodes steps (illegal action).
+        cfg: Config
+            The experiment configuration
         """
         self.ctx = ctx
         self.env = env
         self.rng = rng
-        self.max_len = max_len
-        self.max_nodes = max_nodes
-        self.illegal_action_logreward = hps["illegal_action_logreward"]
-        self.entropy_coef = hps.get("a2c_entropy", 0.01)
-        self.gamma = hps.get("a2c_gamma", 1)
-        self.invalid_penalty = hps.get("a2c_penalty", -10)
+        self.max_len = cfg.algo.max_len
+        self.max_nodes = cfg.algo.max_nodes
+        self.illegal_action_logreward = cfg.algo.illegal_action_logreward
+        self.entropy_coef = cfg.algo.a2c.entropy
+        self.gamma = cfg.algo.a2c.gamma
+        self.invalid_penalty = cfg.algo.a2c.penalty
         assert self.gamma == 1
         self.bootstrap_own_reward = False
         # Experimental flags
         self.sample_temp = 1
         self.do_q_prime_correction = False
-        self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp)
+        self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp)
 
     def create_training_data_from_own_samples(
         self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float
src/gflownet/algo/config.py (119 changes: 119 additions & 0 deletions)
@@ -0,0 +1,119 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class TBConfig:
"""Trajectory Balance config.
Attributes
----------
bootstrap_own_reward : bool
Whether to bootstrap the reward with the own reward. (deprecated)
epsilon : Optional[float]
The epsilon parameter in log-flow smoothing (see paper)
reward_loss_multiplier : float
The multiplier for the reward loss when bootstrapping the reward. (deprecated)
do_subtb : bool
Whether to use the full N^2 subTB loss
do_correct_idempotent : bool
Whether to correct for idempotent actions
do_parameterize_p_b : bool
Whether to parameterize the P_B distribution (otherwise it is uniform)
subtb_max_len : int
The maximum length trajectories, used to cache subTB computation indices
Z_learning_rate : float
The learning rate for the logZ parameter (only relevant when do_subtb is False)
Z_lr_decay : float
The learning rate decay for the logZ parameter (only relevant when do_subtb is False)
"""

bootstrap_own_reward: bool = False
epsilon: Optional[float] = None
reward_loss_multiplier: float = 1.0
do_subtb: bool = False
do_correct_idempotent: bool = False
do_parameterize_p_b: bool = False
subtb_max_len: int = 128
Z_learning_rate: float = 1e-4
Z_lr_decay: float = 50_000


@dataclass
class MOQLConfig:
gamma: float = 1
num_omega_samples: int = 32
num_objectives: int = 2
lambda_decay: int = 10_000
penalty: float = -10


@dataclass
class A2CConfig:
entropy: float = 0.01
gamma: float = 1
penalty: float = -10


@dataclass
class FMConfig:
epsilon: float = 1e-38
balanced_loss: bool = False
leaf_coef: float = 10
correct_idempotent: bool = False


@dataclass
class SQLConfig:
alpha: float = 0.01
gamma: float = 1
penalty: float = -10


@dataclass
class AlgoConfig:
"""Generic configuration for algorithms
Attributes
----------
method : str
The name of the algorithm to use (e.g. "TB")
global_batch_size : int
The batch size for training
max_len : int
The maximum length of a trajectory
max_nodes : int
The maximum number of nodes in a generated graph
max_edges : int
The maximum number of edges in a generated graph
illegal_action_logreward : float
The log reward an agent gets for illegal actions
offline_ratio: float
The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
`self.sampling_model`
train_random_action_prob : float
The probability of taking a random action during training
valid_random_action_prob : float
The probability of taking a random action during validation
valid_sample_cond_info : bool
Whether to sample conditioning information during validation (if False, expects a validation set of cond_info)
sampling_tau : float
The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta)
"""

method: str = "TB"
global_batch_size: int = 64
max_len: int = 128
max_nodes: int = 128
max_edges: int = 128
illegal_action_logreward: float = -100
offline_ratio: float = 0.5
train_random_action_prob: float = 0.0
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
sampling_tau: float = 0.0
tb: TBConfig = TBConfig()
moql: MOQLConfig = MOQLConfig()
a2c: A2CConfig = A2CConfig()
fm: FMConfig = FMConfig()
sql: SQLConfig = SQLConfig()
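
For reference, the structured config composes naturally with OmegaConf, which this PR adds as a dependency (omegaconf>=2.3). A minimal usage sketch, assuming the dataclasses above are importable as gflownet.algo.config (how the trainer actually wires this up is not shown in this excerpt):

from omegaconf import OmegaConf

from gflownet.algo.config import AlgoConfig  # defined above in this commit

# Build a typed config tree from the dataclass defaults.
base = OmegaConf.structured(AlgoConfig)

# Apply overrides (e.g. parsed from a command line); keys and value types are
# validated against the dataclass schema, so typos and type errors fail early.
overrides = OmegaConf.from_dotlist(["method=TB", "max_len=64", "tb.do_subtb=true"])
cfg = OmegaConf.merge(base, overrides)

print(cfg.tb.do_subtb)         # True
print(OmegaConf.to_yaml(cfg))  # dump the full resolved config
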
src/gflownet/algo/envelope_q_learning.py (38 changes: 17 additions & 21 deletions)
@@ -1,5 +1,3 @@
-from typing import Any, Dict
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -8,13 +6,15 @@
 from torch import Tensor
 from torch_scatter import scatter
 
+from gflownet.config import Config
 from gflownet.envs.graph_building_env import (
     GraphActionCategorical,
     GraphBuildingEnv,
     GraphBuildingEnvContext,
     generate_forward_trajectory,
 )
 from gflownet.models.graph_transformer import GraphTransformer, mlp
+from gflownet.train import GFNTask
 
 from .graph_sampling import GraphSampler
 
@@ -165,10 +165,9 @@ def __init__(
         self,
         env: GraphBuildingEnv,
         ctx: GraphBuildingEnvContext,
+        task: GFNTask,
         rng: np.random.RandomState,
-        hps: Dict[str, Any],
-        max_len=None,
-        max_nodes=None,
+        cfg: Config,
     ):
         """Envelope Q-Learning implementation, see
         A Generalized Algorithm for Multi-Objective Reinforcement Learning and Policy Adaptation,
@@ -187,31 +186,28 @@ def __init__(
             A context.
         rng: np.random.RandomState
             rng used to take random actions
-        hps: Dict[str, Any]
-            Hyperparameter dictionary, see above for used keys.
-        max_len: int
-            If not None, ends trajectories of more than max_len steps.
-        max_nodes: int
-            If not None, ends trajectories of graphs with more than max_nodes steps (illegal action).
+        cfg: Config
+            The experiment configuration
         """
         self.ctx = ctx
         self.env = env
+        self.task = task
         self.rng = rng
-        self.max_len = max_len
-        self.max_nodes = max_nodes
-        self.illegal_action_logreward = hps["illegal_action_logreward"]
-        self.gamma = hps.get("moql_gamma", 1)
-        self.num_objectives = len(hps["objectives"])
-        self.num_omega_samples = hps.get("moql_num_omega_samples", 32)
-        self.Lambda_decay = hps.get("moql_lambda_decay", 10_000)
-        self.invalid_penalty = hps.get("moql_penalty", -10)
+        self.max_len = cfg.algo.max_len
+        self.max_nodes = cfg.algo.max_nodes
+        self.illegal_action_logreward = cfg.algo.illegal_action_logreward
+        self.gamma = cfg.algo.moql.gamma
+        self.num_objectives = cfg.algo.moql.num_objectives
+        self.num_omega_samples = cfg.algo.moql.num_omega_samples
+        self.lambda_decay = cfg.algo.moql.lambda_decay
+        self.invalid_penalty = cfg.algo.moql.penalty
         self._num_updates = 0
         assert self.gamma == 1
         self.bootstrap_own_reward = False
         # Experimental flags
         self.sample_temp = 1
         self.do_q_prime_correction = False
-        self.graph_sampler = GraphSampler(ctx, env, max_len, max_nodes, rng, self.sample_temp)
+        self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp)
 
     def create_training_data_from_own_samples(
         self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float
@@ -396,7 +392,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
         # and L_B
         loss_B = abs((w * y).sum(1) - (w * Q_saw).sum(1))
 
-        Lambda = 1 - self.Lambda_decay / (self.Lambda_decay + self._num_updates)
+        Lambda = 1 - self.lambda_decay / (self.lambda_decay + self._num_updates)
         losses = (1 - Lambda) * loss_A + Lambda * loss_B
         self._num_updates += 1
 
src/gflownet/algo/flow_matching.py (17 changes: 7 additions & 10 deletions)
@@ -1,5 +1,3 @@
-from typing import Any, Dict
-
 import networkx as nx
 import numpy as np
 import torch
@@ -8,6 +6,7 @@
 from torch_scatter import scatter
 
 from gflownet.algo.trajectory_balance import TrajectoryBalance
+from gflownet.config import Config
 from gflownet.envs.graph_building_env import (
     Graph,
     GraphAction,
@@ -41,17 +40,15 @@ def __init__(
         env: GraphBuildingEnv,
         ctx: GraphBuildingEnvContext,
         rng: np.random.RandomState,
-        hps: Dict[str, Any],
-        max_len=None,
-        max_nodes=None,
+        cfg: Config,
     ):
-        super().__init__(env, ctx, rng, hps, max_len=max_len, max_nodes=max_nodes)
-        self.fm_epsilon = torch.as_tensor(hps.get("fm_epsilon", 1e-38)).log()
+        super().__init__(env, ctx, rng, cfg)
+        self.fm_epsilon = torch.as_tensor(cfg.algo.fm.epsilon).log()
         # We include the "balanced loss" as a possibility to reproduce results from the FM paper, but
         # in a number of settings the regular loss is more stable.
-        self.fm_balanced_loss = hps.get("fm_balanced_loss", False)
-        self.fm_leaf_coef = hps.get("fm_leaf_coef", 10)
-        self.correct_idempotent = self.correct_idempotent or hps.get("fm_correct_idempotent", False)
+        self.fm_balanced_loss = cfg.algo.fm.balanced_loss
+        self.fm_leaf_coef = cfg.algo.fm.leaf_coef
+        self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent
 
     def construct_batch(self, trajs, cond_info, log_rewards):
         """Construct a batch from a list of trajectories and their information
src/gflownet/algo/multiobjective_reinforce.py (9 changes: 3 additions & 6 deletions)
@@ -1,11 +1,10 @@
-from typing import Any, Dict
-
 import numpy as np
 import torch
 import torch_geometric.data as gd
 from torch_scatter import scatter
 
 from gflownet.algo.trajectory_balance import TrajectoryBalance, TrajectoryBalanceModel
+from gflownet.config import Config
 from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext
 
 
@@ -19,11 +18,9 @@ def __init__(
         env: GraphBuildingEnv,
         ctx: GraphBuildingEnvContext,
         rng: np.random.RandomState,
-        hps: Dict[str, Any],
-        max_len=None,
-        max_nodes=None,
+        cfg: Config,
     ):
-        super().__init__(env, ctx, rng, hps, max_len, max_nodes)
+        super().__init__(env, ctx, rng, cfg)
 
     def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0):
         """Compute multi objective REINFORCE loss over trajectories contained in the batch"""
[… diffs for the remaining 15 changed files are not shown …]
