Skip to content

Commit

Permalink
fix broken merge
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Aug 16, 2023
1 parent 2040356 commit b37b01e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/gflownet/envs/basic_graph_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=Fals
"v": [0, 1], # Imagine this is as colors
}
self._num_rw_feat = 8
self.not_a_molecule_env = True

self.num_new_node_values = len(self.node_attr_values["v"])
self.num_node_attr_logits = None
Expand Down Expand Up @@ -160,7 +161,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int

def graph_to_Data(self, g: Graph) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance"""
if self.graph_data is not None:
if self.graph_data is not None and False:
# This caching achieves two things, first we'll speed things up
gidx = self.get_graph_idx(g)
if gidx in self._cache:
Expand Down
23 changes: 21 additions & 2 deletions src/gflownet/tasks/basic_graph_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,29 @@ def main():
"algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4},
"task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}}, #
}

hps = {
"num_training_steps": 20000,
"validate_every": 100,
"num_workers": 0,
"log_dir": "./logs/basic_graphs/run_6n_pb2",
"model": {"num_layers": 2, "num_emb": 256},
# WARNING: SubTB is detached targets! -- not
"opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
# "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
# "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
"algo": {
"global_batch_size": 512,
"tb": {"do_subtb": True, "do_parameterize_p_b": False},
"max_nodes": 6,
"offline_ratio": 0 / 4,
},
"task": {"basic_graph": {"do_supervised": False, "do_tabular_model": False, "train_ratio": 1}}, #
}
if hps["task"]["basic_graph"]["do_supervised"]:
trial = BGSupervisedTrainer(hps, torch.device("cuda"))
trial = BGSupervisedTrainer(hps)
else:
trial = BasicGraphTaskTrainer(hps, torch.device("cuda"))
trial = BasicGraphTaskTrainer(hps)
torch.set_num_threads(1)
trial.verbose = True
trial.print_every = 1
Expand Down

0 comments on commit b37b01e

Please sign in to comment.