minor fixes
bengioe committed Sep 6, 2023
1 parent 5cd4724 commit ba77d4a
Showing 3 changed files with 13 additions and 59 deletions.
src/gflownet/algo/trajectory_balance.py (2 changes: 1 addition & 1 deletion)
@@ -577,7 +577,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
P_F_sums = scatter_sum(P_F[idces + offset], dests)
P_B_sums = scatter_sum(P_B[idces + offset], dests)
F_start = F[offset : offset + T].repeat_interleave(T - ar[:T])
- F_end = F_and_R[fidces]  # .detach()
+ F_end = F_and_R[fidces]
total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
return total_loss
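For reference, subtb_loss_fast implements the SubTB(lambda=1) objective: a squared log-flow mismatch summed over every sub-trajectory, where F_and_R supplies log R(x) in place of log F(s_T) for sub-trajectories that reach the terminal state. A minimal single-trajectory sketch of the same quantity (the function name and tensor arguments are illustrative, not the repo's vectorized implementation):

import torch

def subtb1_loss_single_traj(log_F, log_PF, log_PB, log_R):
    # log_F:  (T+1,) log state flows along one trajectory s_0 .. s_T
    # log_PF: (T,)   log P_F(s_{t+1} | s_t) along that trajectory
    # log_PB: (T,)   log P_B(s_t | s_{t+1}) along that trajectory
    # log_R:  scalar tensor, log reward of the terminal object x
    T = log_PF.shape[0]
    log_F = torch.cat([log_F[:-1], log_R.reshape(1)])  # the "F_and_R" substitution
    residuals = []
    for i in range(T):
        for j in range(i + 1, T + 1):
            lhs = log_F[i] + log_PF[i:j].sum()  # F_start + P_F_sums
            rhs = log_F[j] + log_PB[i:j].sum()  # F_end + P_B_sums
            residuals.append((lhs - rhs) ** 2)
    # mean over the T*(T+1)/2 sub-trajectories (the role of the car[T] division above)
    return torch.stack(residuals).mean()

With the .detach() removed, gradients flow through F_end as well as F_start, rather than treating the end flow as a fixed regression target.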

src/gflownet/envs/basic_graph_ctx.py (7 changes: 4 additions & 3 deletions)
@@ -38,6 +38,7 @@ class BasicGraphContext(GraphBuildingEnvContext):
def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False):
self.max_nodes = max_nodes
self.output_gid = output_gid
+ self.use_graph_cache = False

self.node_attr_values = {
"v": [0, 1], # Imagine this is as colors
@@ -159,9 +160,9 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
type_idx = self.bck_action_type_order.index(action.action)
return (type_idx, int(row), int(col))

- def graph_to_Data(self, g: Graph) -> gd.Data:
+ def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance"""
- if self.graph_data is not None and False:
+ if self.graph_data is not None and self.use_graph_cache:
# This caching achieves two things, first we'll speed things up
gidx = self.get_graph_idx(g)
if gidx in self._cache:
Expand Down Expand Up @@ -207,7 +208,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
gid=gid,
)
)
- if self.graph_data is not None:
+ if self.graph_data is not None and self.use_graph_cache:
self._cache[gidx] = data
return data
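Both guards in graph_to_Data now key off the same use_graph_cache flag, so the cache is read and written under the same condition. A minimal sketch of that read/write symmetry (the _build_Data helper is hypothetical and stands in for the actual feature construction):

def graph_to_Data_cached(self, g):
    # The cache is consulted and populated under one condition, so it is either
    # fully on (e.g. the tabular setting) or fully off.
    use_cache = self.graph_data is not None and self.use_graph_cache
    if use_cache:
        gidx = self.get_graph_idx(g)
        if gidx in self._cache:
            return self._cache[gidx]
    data = self._build_Data(g)  # hypothetical: build the torch_geometric Data object
    if use_cache:
        self._cache[gidx] = data
    return data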

src/gflownet/tasks/basic_graph_task.py (63 changes: 8 additions & 55 deletions)
@@ -14,6 +14,7 @@
from torch_scatter import scatter_logsumexp
from tqdm import tqdm

+ from gflownet.algo.config import TBVariant
from gflownet.algo.flow_matching import FlowMatching
from gflownet.algo.trajectory_balance import TrajectoryBalance
from gflownet.config import Config
@@ -272,7 +273,7 @@ def set_default_hps(self, cfg: Config):
cfg.model.num_layers = 8
cfg.algo.valid_offline_ratio = 0
cfg.algo.tb.do_correct_idempotent = True # Important to converge to the true p(x)
- cfg.algo.tb.do_subtb = True
+ cfg.algo.tb.variant = TBVariant.SubTB1
cfg.algo.tb.do_parameterize_p_b = False
cfg.algo.illegal_action_logreward = -30 # All states are legal here, so this shouldn't matter
cfg.num_workers = 8
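set_default_hps now selects the loss through an enum instead of the do_subtb boolean. A sketch of what that enum might look like; the members below are an assumption based only on this diff, the real definition lives in gflownet/algo/config.py:

from enum import IntEnum

class TBVariant(IntEnum):
    TB = 0      # vanilla trajectory balance
    SubTB1 = 1  # sub-trajectory balance with lambda = 1

# old: cfg.algo.tb.do_subtb = True
# new: cfg.algo.tb.variant = TBVariant.SubTB1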
@@ -293,6 +294,7 @@ def setup(self):
self.env = GraphBuildingEnv()
self._data = load_two_col_data(self.cfg.task.basic_graph.data_root, max_nodes=max_nodes)
self.ctx = BasicGraphContext(max_nodes, num_cond_dim=1, graph_data=self._data, output_gid=True)
+ self.ctx.use_graph_cache = mcfg.do_tabular_model
self._do_supervised = self.cfg.task.basic_graph.do_supervised

self.training_data = TwoColorGraphDataset(
@@ -326,12 +328,6 @@ def setup(self):
model = TabularHashingModel(self.exact_prob_cb)
if 0:
model.set_values(self.exact_prob_cb)
- if 0: # reload_bit
- model.load_state_dict(
- torch.load(
- "/mnt/ps/home/CORP/emmanuel.bengio/rs/gfn/gflownet/src/gflownet/tasks/logs/basic_graphs/run_6n_4/model_state.pt"
- )["models_state_dict"][0]
- )
else:
model = GraphTransformerGFN(
self.ctx,
@@ -362,7 +358,6 @@ def setup(self):
algo = self.cfg.algo.method
if algo == "TB" or algo == "subTB":
self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg)
- self.algo.graph_sampler.sample_temp = 100
elif algo == "FM":
self.algo = FlowMatching(self.env, self.ctx, self.rng, self.cfg)
self.task = BasicGraphTask(
@@ -611,18 +606,8 @@ def compute_cache(self, tqdm_disable=None):
bs = states[bi : bi + mbs]
bD = states_Data[bi : bi + mbs]
indices = list(range(bi, bi + len(bs)))
- # TODO: if the environment's masks are well designed, this non_terminal business shouldn't be necessary
- # non_terminals = [(i, j, k) for i, j, k in zip(bs, bD, indices) if not self.is_terminal(i)]
- # if not len(non_terminals):
- # self.precomputed_batches.append(None)
- # self.precomputed_indices.append(None)
- # continue
- # bs, bD, indices = zip(*non_terminals)
batch = self.ctx.collate(bD).to(dev)
self.precomputed_batches.append(batch)

- # with torch.no_grad():
- # cat, *_, mo = self.trial.model(batch, ones[:len(bs)])
actions = [[] for i in range(len(bs))]
offset = 0
for u, i in enumerate(ctx.action_type_order):
@@ -752,11 +737,11 @@ def get_bck_trajectory_test_split(self, r, seed=142857):
while len(test_set) < n:
i0 = np.random.randint(len(self.states))
s0 = self.states[i0]
- if len(s0.nodes) < 7: # TODO: unhardcode this
+ if len(s0.nodes) < 7: # TODO: unhardcode this?
continue
s = s0
idx = i0
- while len(s.nodes) > 5: # TODO: unhardcode this
+ while len(s.nodes) > 5: # TODO: unhardcode this?
test_set.add(idx)
actions = [
(u, a.item(), b.item())
@@ -886,49 +871,17 @@ def build_validation_data_loader(self) -> DataLoader:

def main():
# Launch a test job
- hps = {
- "num_training_steps": 20000,
- "validate_every": 100,
- "num_workers": 16,
- "log_dir": "./logs/basic_graphs/run_6n_19",
- "model": {"num_layers": 2, "num_emb": 256},
- # WARNING: SubTB is detached targets!
- "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.995, "lr_decay": 1e10},
- # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
- # "opt": {"opt": "SGD", "learning_rate": 0.3, "momentum": 0},
- "algo": {"global_batch_size": 4096, "tb": {"do_subtb": True}, "max_nodes": 6},
- "task": {
- "basic_graph": {"do_supervised": False, "do_tabular_model": True}
- }, # Change this to launch a supervised job
- }
-
- hps = {
- "num_training_steps": 20000,
- "validate_every": 100,
- "num_workers": 16,
- "log_dir": "./logs/basic_graphs/run_6n_27",
- "model": {"num_layers": 2, "num_emb": 256},
- # WARNING: SubTB is detached targets! -- not
- "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
- # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
- # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
- "algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4},
- "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}}, #
- }
-
hps = {
"num_training_steps": 20000,
"validate_every": 100,
"num_workers": 0,
"log_dir": "./logs/basic_graphs/run_6n_pb2",
"model": {"num_layers": 2, "num_emb": 256},
# WARNING: SubTB is detached targets! -- not
"opt": {"adam_eps": 1e-8, "learning_rate": 3e-2},
# "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"},
# "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0},
"opt": {"adam_eps": 1e-8, "learning_rate": 3e-4},
"algo": {
"global_batch_size": 512,
"tb": {"do_subtb": True, "do_parameterize_p_b": False},
"global_batch_size": 64,
"tb": {"variant": "SubTB1", "do_parameterize_p_b": False},
"max_nodes": 6,
"offline_ratio": 0 / 4,
},
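The hps dict above is a nested override of the defaults set in set_default_hps (note the loss variant is now given by name, "SubTB1", rather than through do_subtb). A generic sketch of how such a nested dict can be applied to an attribute-style config; this is illustrative only, not the repo's actual config machinery:

def apply_overrides(cfg, overrides: dict):
    # Recursively copy values from a nested dict onto an attribute-style config,
    # e.g. {"algo": {"tb": {"variant": "SubTB1"}}} sets cfg.algo.tb.variant = "SubTB1".
    for key, value in overrides.items():
        if isinstance(value, dict):
            apply_overrides(getattr(cfg, key), value)
        else:
            setattr(cfg, key, value)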
