diff --git a/tests/make_it_batch/experiment.py b/tests/make_it_batch/experiment.py index 3bc90150..0f972b13 100644 --- a/tests/make_it_batch/experiment.py +++ b/tests/make_it_batch/experiment.py @@ -274,7 +274,7 @@ def tag(self): def training_pipeline(self, **kwargs): ppo_steps = int(1e6) lr = 2.5e-4 - num_mini_batch = 2 if not torch.cuda.is_available() else 6 + num_mini_batch = self.task_batch_size if not torch.cuda.is_available() else 6 update_repeats = 4 num_steps = self.MAX_STEPS metric_accumulate_interval = self.MAX_STEPS * 1 # Log every 10 max length tasks