
fixed tabular node ordering problem
bengioe committed Aug 10, 2023
1 parent 424ad7a commit 2040356
Showing 4 changed files with 79 additions and 9 deletions.
10 changes: 7 additions & 3 deletions src/gflownet/algo/trajectory_balance.py
@@ -88,7 +88,7 @@ def __init__(
)
if self.cfg.do_subtb:
self._subtb_max_len = self.global_cfg.algo.max_len + 2
self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info?
self._init_subtb(torch.device(cfg.device)) # TODO: where are we getting device info?

def create_training_data_from_own_samples(
self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float
@@ -138,7 +138,11 @@ def create_training_data_from_graphs(self, graphs):
trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}]
A list of trajectories.
"""
trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs]
if hasattr(self.ctx, "relabel"):
relabel = self.ctx.relabel
else:
relabel = lambda *x: x
trajs = [{"traj": [relabel(*t) for t in generate_forward_trajectory(i)]} for i in graphs]
for traj in trajs:
n_back = [
self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent)
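A minimal sketch (not part of the commit) of the optional-relabel hook that the new create_training_data_from_graphs code relies on: if the context defines relabel, every (graph, action) pair of the generated trajectory is canonicalized through it; otherwise an identity fallback leaves the pairs untouched. The function name and arguments below are illustrative.

def build_offline_trajs(ctx, graphs, generate_forward_trajectory):
    # Fall back to the identity when the context has no canonicalization hook
    relabel = getattr(ctx, "relabel", lambda g, a: (g, a))
    return [
        {"traj": [relabel(g, a) for g, a in generate_forward_trajectory(G)]}
        for G in graphs
    ]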
@@ -505,6 +509,6 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
P_F_sums = scatter_sum(P_F[idces + offset], dests)
P_B_sums = scatter_sum(P_B[idces + offset], dests)
F_start = F[offset : offset + T].repeat_interleave(T - ar[:T])
F_end = F_and_R[fidces]
F_end = F_and_R[fidces] # .detach()
total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
return total_loss
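For context on the "# .detach()" note above, a toy illustration (not from the repo) of what detaching the sub-trajectory target would do: gradients then flow only through the starting flow and the log-probability sums, never through the flow estimate used as the target.

import torch

F_start = torch.randn(5, requires_grad=True)
F_end = torch.randn(5, requires_grad=True)
residual = F_start - F_end.detach()   # the target flow is treated as a constant
loss = residual.pow(2).sum()
loss.backward()
assert F_end.grad is None             # no gradient reaches the detached target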
1 change: 1 addition & 0 deletions src/gflownet/data/sampling_iterator.py
@@ -334,6 +334,7 @@ def iterator(self):
batch.preferences = cond_info.get("preferences", None)
batch.focus_dir = cond_info.get("focus_dir", None)
batch.extra_info = extra_info
batch.trajs = trajs
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object

37 changes: 36 additions & 1 deletion src/gflownet/envs/basic_graph_ctx.py
@@ -1,3 +1,4 @@
import copy
from typing import Dict, List, Tuple

import networkx as nx
@@ -9,6 +10,7 @@
Graph,
GraphAction,
GraphActionType,
GraphBuildingEnv,
GraphBuildingEnvContext,
graph_without_edge,
)
@@ -63,13 +65,33 @@ def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=Fals
GraphActionType.RemoveNode,
GraphActionType.RemoveEdge,
]
self._env = GraphBuildingEnv()
self.device = torch.device("cpu")
self.graph_data = graph_data
self.hash_to_graphs: Dict[str, int] = {}
if graph_data is not None:
states_hash = [hashg(i) for i in graph_data]
for i, h, g in zip(range(len(graph_data)), states_hash, graph_data):
self.hash_to_graphs[h] = self.hash_to_graphs.get(h, list()) + [(g, i)]
self._cache = {}

def relabel(self, g, ga):
if ga.action != GraphActionType.Stop:
gp = self._env.step(g, ga)
ig = self.graph_data[self.get_graph_idx(g)]
rmap = nx.vf2pp_isomorphism(g, ig, "v")
ga = copy.copy(ga)
if rmap is None and not len(g):
rmap = {0: 0}
if ga.source is not None:
ga.source = rmap[ga.source]
if ga.target is not None:
ga.target = rmap[ga.target]
if ga.action != GraphActionType.Stop:
gp2 = self._env.step(ig, ga)
if not nx.is_isomorphic(gp2, gp, lambda a, b: a == b):
raise ValueError()
return copy.deepcopy(ig), ga

def get_graph_idx(self, g, default=None):
h = hashg(g)
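A toy illustration (not from the repo) of the node remapping that relabel builds with nx.vf2pp_isomorphism: the trajectory-local graph g and its canonical copy ig stored in graph_data carry the same "v" node labels, and the returned dict maps g's node ids onto ig's, which is then applied to the action's source and target. Assumes networkx >= 3.0.

import networkx as nx

g = nx.Graph()   # trajectory-local node ids
g.add_node(0, v=0)
g.add_node(1, v=1)
g.add_edge(0, 1)

ig = nx.Graph()  # canonical copy with a different node ordering
ig.add_node(1, v=0)
ig.add_node(0, v=1)
ig.add_edge(1, 0)

rmap = nx.vf2pp_isomorphism(g, ig, "v")  # here {0: 1, 1: 0}
assert rmap is not None
# An action whose target is node 1 of g would be remapped to node rmap[1] of ig.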
@@ -138,6 +160,16 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int

def graph_to_Data(self, g: Graph) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance"""
if self.graph_data is not None:
# This caching achieves two things, first we'll speed things up
gidx = self.get_graph_idx(g)
if gidx in self._cache:
return self._cache[gidx]
# And second we'll always have the same node ordering, which is necessary for the tabular model
# to work. In the non-tabular case, we're hopefully using a model that's invariant to node ordering, so this
# shouldn't cause any problems.
g = self.graph_data[gidx]

x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim - self._num_rw_feat))
x[0, -1] = len(g.nodes) == 0
remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
@@ -160,7 +192,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
non_edge_index = torch.tensor([i for i in gc.edges], dtype=torch.long).reshape((-1, 2)).T
gid = self.get_graph_idx(g) if self.output_gid else 0

return self._preprocess(
data = self._preprocess(
gd.Data(
x,
edge_index,
@@ -174,6 +206,9 @@
gid=gid,
)
)
if self.graph_data is not None:
self._cache[gidx] = data
return data

def _preprocess(self, g: gd.Data) -> gd.Data:
if self._num_rw_feat > 0:
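A sketch of the caching idea added to graph_to_Data, assuming ctx exposes get_graph_idx, graph_data, and a _cache dict, with featurize standing in for the tensorization done above: look up the state's canonical index, reuse the cached Data object when present, and always featurize the stored copy of the graph so node ordering stays deterministic for the tabular model.

def cached_graph_to_data(ctx, g, featurize):
    gidx = ctx.get_graph_idx(g)             # canonical state index via hashing
    if gidx in ctx._cache:
        return ctx._cache[gidx]
    data = featurize(ctx.graph_data[gidx])  # featurize the canonical copy, not g
    ctx._cache[gidx] = data
    return data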
40 changes: 35 additions & 5 deletions src/gflownet/tasks/basic_graph_task.py
@@ -324,10 +324,12 @@ def setup(self):
)
elif mcfg.do_tabular_model:
model = TabularHashingModel(self.exact_prob_cb)
if 1:
if 0:
model.set_values(self.exact_prob_cb)
if 0: # reload_bit
model.load_state_dict(
torch.load(
"/mnt/ps/home/CORP/emmanuel.bengio/rs/gfn/gflownet/src/gflownet/tasks/logs/basic_graphs/run_6n_9/model_state.pt"
"/mnt/ps/home/CORP/emmanuel.bengio/rs/gfn/gflownet/src/gflownet/tasks/logs/basic_graphs/run_6n_4/model_state.pt"
)["models_state_dict"][0]
)
else:
@@ -353,6 +355,8 @@ def setup(self):
self.opt = torch.optim.SGD(
params, self.cfg.opt.learning_rate, self.cfg.opt.momentum, weight_decay=self.cfg.opt.weight_decay
)
elif self.cfg.opt.opt == "RMSProp":
self.opt = torch.optim.RMSprop(params, self.cfg.opt.learning_rate, weight_decay=self.cfg.opt.weight_decay)
self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay))

algo = self.cfg.algo.method
@@ -481,6 +485,16 @@ def __call__(self, g: gd.Batch, cond_info):
logF_s,
)

def set_values(self, epc):
"""Set the values of the table to the true values of the MDP. This tabular model should have 0 error."""
for i in tqdm(range(len(epc.states))):
for neighbor in list(epc.mdp_graph.neighbors(i)):
for _, edge in epc.mdp_graph.get_edge_data(i, neighbor).items():
a, F = edge["a"], edge["F"]
self.table.data[self.slices[i][a[0]] + a[1] * self.shapes[i][a[0]][1] + a[2]] = F
self.table.data[self.slices[i][3]] = epc.mdp_graph.nodes[i]["F"]
self._logZ.data = torch.tensor(epc.mdp_graph.nodes[0]["F"]).float()

def logZ(self, cond_info: Tensor):
return self._logZ.tile(cond_info.shape[0]).reshape((-1, 1)) # Why is the reshape necessary?
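A rough sketch of the flat-table indexing that set_values relies on: each state owns contiguous slices of a single 1-D parameter vector, one slice per action type, laid out row-major over the (row, column) action indices, so the edge attribute a = (action_type, row, col) addresses exactly one entry. The names mirror the method above, but this helper is only illustrative.

def flat_index(slices, shapes, state, a):
    a_type, row, col = a                  # the "a" triple stored on each mdp_graph edge
    n_cols = shapes[state][a_type][1]
    return slices[state][a_type] + row * n_cols + col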

@@ -876,15 +890,31 @@ def main():
"num_training_steps": 20000,
"validate_every": 100,
"num_workers": 16,
"log_dir": "./logs/basic_graphs/run_6n_14",
"log_dir": "./logs/basic_graphs/run_6n_19",
"model": {"num_layers": 2, "num_emb": 256},
"opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10},
# WARNING: SubTB is detached targets!
"opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.995, "lr_decay": 1e10},
# "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
# "opt": {"opt": "SGD", "learning_rate": 0.3, "momentum": 0},
"algo": {"global_batch_size": 2048, "tb": {"do_subtb": False}, "max_nodes": 6},
"algo": {"global_batch_size": 4096, "tb": {"do_subtb": True}, "max_nodes": 6},
"task": {
"basic_graph": {"do_supervised": False, "do_tabular_model": True}
}, # Change this to launch a supervised job
}

hps = {
"num_training_steps": 20000,
"validate_every": 100,
"num_workers": 16,
"log_dir": "./logs/basic_graphs/run_6n_27",
"model": {"num_layers": 2, "num_emb": 256},
# WARNING: SubTB is detached targets! -- not
"opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
# "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
# "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
"algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4},
"task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}}, #
}
if hps["task"]["basic_graph"]["do_supervised"]:
trial = BGSupervisedTrainer(hps, torch.device("cuda"))
else:
