Skip to content

Commit

Permalink
Merge branch 'julien-fix-gpu-mem-bust' into julien-harmonize-use-of-m…
Browse files Browse the repository at this point in the history
…asks
  • Loading branch information
julienroyd committed Apr 4, 2024
2 parents fb6dac1 + 5ab80df commit 5f953c6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/gflownet/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info):
batch.num_online = sum(t.get("is_online", 0) for t in trajs)
batch.num_offline = len(trajs) - batch.num_online
batch.extra_info = batch_info
if "preferences" in trajs[0]:
batch.preferences = torch.stack([t["preferences"] for t in trajs])
if "focus_dir" in trajs[0]:
batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs])
if "preferences" in trajs[0]["cond_info"].keys():
batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs])
if "focus_dir" in trajs[0]["cond_info"].keys():
batch.focus_dir = torch.stack([t["cond_info"]["focus_dir"] for t in trajs])

if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n:
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
Expand Down

0 comments on commit 5f953c6

Please sign in to comment.