From b37b01e2bf76e97ed6ad56db5d04b3c60786a94f Mon Sep 17 00:00:00 2001
From: Emmanuel Bengio
Date: Wed, 16 Aug 2023 11:20:07 -0400
Subject: [PATCH] fix broken merge

---
 src/gflownet/envs/basic_graph_ctx.py   |  3 ++-
 src/gflownet/tasks/basic_graph_task.py | 23 +++++++++++++++++++++--
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/src/gflownet/envs/basic_graph_ctx.py b/src/gflownet/envs/basic_graph_ctx.py
index c862a394..362efb42 100644
--- a/src/gflownet/envs/basic_graph_ctx.py
+++ b/src/gflownet/envs/basic_graph_ctx.py
@@ -43,6 +43,7 @@ def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=Fals
             "v": [0, 1],  # Imagine this is as colors
         }
         self._num_rw_feat = 8
+        self.not_a_molecule_env = True

         self.num_new_node_values = len(self.node_attr_values["v"])
         self.num_node_attr_logits = None
@@ -160,7 +161,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int

     def graph_to_Data(self, g: Graph) -> gd.Data:
         """Convert a networkx Graph to a torch geometric Data instance"""
-        if self.graph_data is not None:
+        if self.graph_data is not None and False:
             # This caching achieves two things, first we'll speed things up
             gidx = self.get_graph_idx(g)
             if gidx in self._cache:
diff --git a/src/gflownet/tasks/basic_graph_task.py b/src/gflownet/tasks/basic_graph_task.py
index 4f961f69..849162f4 100644
--- a/src/gflownet/tasks/basic_graph_task.py
+++ b/src/gflownet/tasks/basic_graph_task.py
@@ -915,10 +915,29 @@ def main():
         "algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4},
         "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}},  #
     }
+
+    hps = {
+        "num_training_steps": 20000,
+        "validate_every": 100,
+        "num_workers": 0,
+        "log_dir": "./logs/basic_graphs/run_6n_pb2",
+        "model": {"num_layers": 2, "num_emb": 256},
+        # WARNING: SubTB is detached targets! -- not
+        "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
+        # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
+        # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
+        "algo": {
+            "global_batch_size": 512,
+            "tb": {"do_subtb": True, "do_parameterize_p_b": False},
+            "max_nodes": 6,
+            "offline_ratio": 0 / 4,
+        },
+        "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": False, "train_ratio": 1}},  #
+    }
     if hps["task"]["basic_graph"]["do_supervised"]:
-        trial = BGSupervisedTrainer(hps, torch.device("cuda"))
+        trial = BGSupervisedTrainer(hps)
     else:
-        trial = BasicGraphTaskTrainer(hps, torch.device("cuda"))
+        trial = BasicGraphTaskTrainer(hps)
     torch.set_num_threads(1)
     trial.verbose = True
     trial.print_every = 1
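
Usage note: after this patch, both trainer constructors take only the hyperparameter
dict; the explicit torch.device("cuda") argument is gone. Below is a minimal driver
sketch of the post-patch main() flow. It assumes the single-argument constructors
shown in the diff and a trial.run() entry point; run() does not appear in this diff
and is an assumption about the trainer API here.

# Minimal sketch, not the verbatim main(): constructor names, hps keys, and the
# import path come from the diff above; `run()` is an assumed entry point.
import torch

from gflownet.tasks.basic_graph_task import BasicGraphTaskTrainer, BGSupervisedTrainer

hps = {
    "num_training_steps": 20000,
    "validate_every": 100,
    "num_workers": 0,
    "log_dir": "./logs/basic_graphs/run_6n_pb2",
    "model": {"num_layers": 2, "num_emb": 256},
    "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
    "algo": {
        "global_batch_size": 512,
        "tb": {"do_subtb": True, "do_parameterize_p_b": False},
        "max_nodes": 6,
        "offline_ratio": 0 / 4,
    },
    "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": False, "train_ratio": 1}},
}

if __name__ == "__main__":
    # Constructors now take only the hps dict; device selection happens internally.
    if hps["task"]["basic_graph"]["do_supervised"]:
        trial = BGSupervisedTrainer(hps)
    else:
        trial = BasicGraphTaskTrainer(hps)
    torch.set_num_threads(1)
    trial.verbose = True
    trial.print_every = 1
    trial.run()  # assumed trainer entry point, not shown in this diff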