diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py
index 2035081a..593c7b3b 100644
--- a/src/gflownet/algo/trajectory_balance.py
+++ b/src/gflownet/algo/trajectory_balance.py
@@ -577,7 +577,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
             P_F_sums = scatter_sum(P_F[idces + offset], dests)
             P_B_sums = scatter_sum(P_B[idces + offset], dests)
             F_start = F[offset : offset + T].repeat_interleave(T - ar[:T])
-            F_end = F_and_R[fidces]  # .detach()
+            F_end = F_and_R[fidces]
             total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
         return total_loss

diff --git a/src/gflownet/envs/basic_graph_ctx.py b/src/gflownet/envs/basic_graph_ctx.py
index 362efb42..1fce3801 100644
--- a/src/gflownet/envs/basic_graph_ctx.py
+++ b/src/gflownet/envs/basic_graph_ctx.py
@@ -38,6 +38,7 @@ class BasicGraphContext(GraphBuildingEnvContext):
     def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False):
         self.max_nodes = max_nodes
         self.output_gid = output_gid
+        self.use_graph_cache = False

         self.node_attr_values = {
             "v": [0, 1],  # Imagine this is as colors
@@ -159,9 +160,9 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
             type_idx = self.bck_action_type_order.index(action.action)
         return (type_idx, int(row), int(col))

-    def graph_to_Data(self, g: Graph) -> gd.Data:
+    def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data:
         """Convert a networkx Graph to a torch geometric Data instance"""
-        if self.graph_data is not None and False:
+        if self.graph_data is not None and self.use_graph_cache:
             # This caching achieves two things, first we'll speed things up
             gidx = self.get_graph_idx(g)
             if gidx in self._cache:
@@ -207,7 +208,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
                 gid=gid,
             )
         )
-        if self.graph_data is not None:
+        if self.graph_data is not None and self.use_graph_cache:
             self._cache[gidx] = data
         return data

diff --git a/src/gflownet/tasks/basic_graph_task.py b/src/gflownet/tasks/basic_graph_task.py
index 849162f4..36df986e 100644
--- a/src/gflownet/tasks/basic_graph_task.py
+++ b/src/gflownet/tasks/basic_graph_task.py
@@ -14,6 +14,7 @@
 from torch_scatter import scatter_logsumexp
 from tqdm import tqdm

+from gflownet.algo.config import TBVariant
 from gflownet.algo.flow_matching import FlowMatching
 from gflownet.algo.trajectory_balance import TrajectoryBalance
 from gflownet.config import Config
@@ -272,7 +273,7 @@ def set_default_hps(self, cfg: Config):
         cfg.model.num_layers = 8
         cfg.algo.valid_offline_ratio = 0
         cfg.algo.tb.do_correct_idempotent = True  # Important to converge to the true p(x)
-        cfg.algo.tb.do_subtb = True
+        cfg.algo.tb.variant = TBVariant.SubTB1
         cfg.algo.tb.do_parameterize_p_b = False
         cfg.algo.illegal_action_logreward = -30  # Although, all states are legal here, this shouldn't matter
         cfg.num_workers = 8
@@ -293,6 +294,7 @@ def setup(self):
         self.env = GraphBuildingEnv()
         self._data = load_two_col_data(self.cfg.task.basic_graph.data_root, max_nodes=max_nodes)
         self.ctx = BasicGraphContext(max_nodes, num_cond_dim=1, graph_data=self._data, output_gid=True)
+        self.ctx.use_graph_cache = mcfg.do_tabular_model
         self._do_supervised = self.cfg.task.basic_graph.do_supervised

         self.training_data = TwoColorGraphDataset(
@@ -326,12 +328,6 @@ def setup(self):
             model = TabularHashingModel(self.exact_prob_cb)
             if 0:
                 model.set_values(self.exact_prob_cb)
-            if 0:  # reload_bit
-                model.load_state_dict(
-                    torch.load(
-                        "/mnt/ps/home/CORP/emmanuel.bengio/rs/gfn/gflownet/src/gflownet/tasks/logs/basic_graphs/run_6n_4/model_state.pt"
-                    )["models_state_dict"][0]
-                )
         else:
             model = GraphTransformerGFN(
                 self.ctx,
@@ -362,7 +358,6 @@ def setup(self):
         algo = self.cfg.algo.method
         if algo == "TB" or algo == "subTB":
             self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg)
-            self.algo.graph_sampler.sample_temp = 100
         elif algo == "FM":
             self.algo = FlowMatching(self.env, self.ctx, self.rng, self.cfg)
         self.task = BasicGraphTask(
@@ -611,18 +606,8 @@ def compute_cache(self, tqdm_disable=None):
             bs = states[bi : bi + mbs]
             bD = states_Data[bi : bi + mbs]
             indices = list(range(bi, bi + len(bs)))
-            # TODO: if the environment's masks are well designed, this non_terminal business shouldn't be necessary
-            # non_terminals = [(i, j, k) for i, j, k in zip(bs, bD, indices) if not self.is_terminal(i)]
-            # if not len(non_terminals):
-            #     self.precomputed_batches.append(None)
-            #     self.precomputed_indices.append(None)
-            #     continue
-            # bs, bD, indices = zip(*non_terminals)
             batch = self.ctx.collate(bD).to(dev)
             self.precomputed_batches.append(batch)
-
-            # with torch.no_grad():
-            #     cat, *_, mo = self.trial.model(batch, ones[:len(bs)])
             actions = [[] for i in range(len(bs))]
             offset = 0
             for u, i in enumerate(ctx.action_type_order):
@@ -752,11 +737,11 @@ def get_bck_trajectory_test_split(self, r, seed=142857):
         while len(test_set) < n:
             i0 = np.random.randint(len(self.states))
             s0 = self.states[i0]
-            if len(s0.nodes) < 7:  # TODO: unhardcode this
+            if len(s0.nodes) < 7:  # TODO: unhardcode this?
                 continue
             s = s0
             idx = i0
-            while len(s.nodes) > 5:  # TODO: unhardcode this
+            while len(s.nodes) > 5:  # TODO: unhardcode this?
                 test_set.add(idx)
                 actions = [
                     (u, a.item(), b.item())
@@ -886,35 +871,6 @@ def build_validation_data_loader(self) -> DataLoader:

 def main():
     # Launch a test job
-    hps = {
-        "num_training_steps": 20000,
-        "validate_every": 100,
-        "num_workers": 16,
-        "log_dir": "./logs/basic_graphs/run_6n_19",
-        "model": {"num_layers": 2, "num_emb": 256},
-        # WARNING: SubTB is detached targets!
-        "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.995, "lr_decay": 1e10},
-        # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
-        # "opt": {"opt": "SGD", "learning_rate": 0.3, "momentum": 0},
-        "algo": {"global_batch_size": 4096, "tb": {"do_subtb": True}, "max_nodes": 6},
-        "task": {
-            "basic_graph": {"do_supervised": False, "do_tabular_model": True}
-        },  # Change this to launch a supervised job
-    }
-
-    hps = {
-        "num_training_steps": 20000,
-        "validate_every": 100,
-        "num_workers": 16,
-        "log_dir": "./logs/basic_graphs/run_6n_27",
-        "model": {"num_layers": 2, "num_emb": 256},
-        # WARNING: SubTB is detached targets! -- not
-        "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
-        # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
-        # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
-        "algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4},
-        "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}},  #
-    }

     hps = {
         "num_training_steps": 20000,
@@ -922,13 +878,10 @@ def main():
         "num_workers": 0,
         "log_dir": "./logs/basic_graphs/run_6n_pb2",
         "model": {"num_layers": 2, "num_emb": 256},
-        # WARNING: SubTB is detached targets! -- not
-        "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
-        # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
-        # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
+        "opt": {"adam_eps": 1e-8, "learning_rate": 3e-4},
         "algo": {
-            "global_batch_size": 512,
-            "tb": {"do_subtb": True, "do_parameterize_p_b": False},
+            "global_batch_size": 64,
+            "tb": {"variant": "SubTB1", "do_parameterize_p_b": False},
             "max_nodes": 6,
             "offline_ratio": 0 / 4,
         },
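Note (not part of the patch): the diff replaces the boolean `do_subtb` switch with the `TBVariant` enum and gates the `graph_to_Data` cache behind an explicit `use_graph_cache` flag; previously the cache read path was dead code (`and False`) while the write path always ran, and both are now controlled by the same flag, which `setup()` turns on only when `do_tabular_model` is set. Below is a minimal sketch of the resulting configuration surface, assuming `Config()` constructs with its defaults; the specific values are illustrative, not prescribed by the patch.

```python
from gflownet.algo.config import TBVariant
from gflownet.config import Config

cfg = Config()
# The former boolean switch `cfg.algo.tb.do_subtb = True` is now an explicit variant:
cfg.algo.tb.variant = TBVariant.SubTB1
cfg.algo.tb.do_parameterize_p_b = False

# Dict-style hyperparameters (as consumed by main() above) spell the same choice as a string:
hps = {"algo": {"tb": {"variant": "SubTB1", "do_parameterize_p_b": False}}}
```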