From ad0db6ca0fd8c80ca4005bdca92020eef439a700 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 4 Apr 2024 07:36:11 -0600 Subject: [PATCH 1/2] fix: made focus_dir and preferences accessible at the batch level --- src/gflownet/data/data_source.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 90a2848f..cc868758 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - if "preferences" in trajs[0]: - batch.preferences = torch.stack([t["preferences"] for t in trajs]) - if "focus_dir" in trajs[0]: - batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]['cond_info'].keys(): + batch.preferences = torch.stack([t['cond_info']["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]['cond_info'].keys(): + batch.focus_dir = torch.stack([t['cond_info']["focus_dir"] for t in trajs]) if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] From 5ab80df44d18c05c3747123467cd0e364e9a4e52 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 4 Apr 2024 07:41:29 -0600 Subject: [PATCH 2/2] tox --- src/gflownet/data/data_source.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index cc868758..782f0750 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - if "preferences" in trajs[0]['cond_info'].keys(): - batch.preferences = torch.stack([t['cond_info']["preferences"] for t in trajs]) - if "focus_dir" in trajs[0]['cond_info'].keys(): - batch.focus_dir = torch.stack([t['cond_info']["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]["cond_info"].keys(): + batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]["cond_info"].keys(): + batch.focus_dir = torch.stack([t["cond_info"]["focus_dir"] for t in trajs]) if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]