diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index cc868758..782f0750 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - if "preferences" in trajs[0]['cond_info'].keys(): - batch.preferences = torch.stack([t['cond_info']["preferences"] for t in trajs]) - if "focus_dir" in trajs[0]['cond_info'].keys(): - batch.focus_dir = torch.stack([t['cond_info']["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]["cond_info"].keys(): + batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]["cond_info"].keys(): + batch.focus_dir = torch.stack([t["cond_info"]["focus_dir"] for t in trajs]) if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]