Skip to content

Commit

Permalink
fix: added detach and cpu() at the begining of create_batch()
Browse files Browse the repository at this point in the history
  • Loading branch information
julienroyd committed Apr 1, 2024
1 parent 5443f6f commit f92690a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/gflownet/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from gflownet import GFNAlgorithm, GFNTask
from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu
from gflownet.envs.graph_building_env import GraphBuildingEnvContext
from gflownet.utils.misc import get_worker_rng

Expand Down Expand Up @@ -214,6 +214,7 @@ def call_sampling_hooks(self, trajs):
return batch_info

def create_batch(self, trajs, batch_info):
trajs = detach_and_cpu(trajs)
ci = torch.stack([t["cond_info"]["encoding"] for t in trajs])
log_rewards = torch.stack([t["log_reward"] for t in trajs])
batch = self.algo.construct_batch(trajs, ci, log_rewards)
Expand Down

0 comments on commit f92690a

Please sign in to comment.