Skip to content

Commit

Permalink
[Feature] Iteration timer includes evaluation (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Jul 10, 2024
1 parent 6225cf5 commit e16ffa0
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ def _collection_loop(self):
initial=self.n_iters_performed,
total=self.config.get_max_n_iters(self.on_policy),
)
sampling_start = time.time()

if not self.config.collect_with_grad:
iterator = iter(self.collector)
Expand All @@ -568,6 +567,7 @@ def _collection_loop(self):
for _ in range(
self.n_iters_performed, self.config.get_max_n_iters(self.on_policy)
):
iteration_start = time.time()
if not self.config.collect_with_grad:
batch = next(iterator)
else:
Expand All @@ -585,7 +585,7 @@ def _collection_loop(self):
reset_batch = step_mdp(batch[..., -1])

# Logging collection
collection_time = time.time() - sampling_start
collection_time = time.time() - iteration_start
current_frames = batch.numel()
self.total_frames += current_frames
self.mean_return = self.logger.log_collection(
Expand Down Expand Up @@ -637,22 +637,8 @@ def _collection_loop(self):
if not self.config.collect_with_grad:
self.collector.update_policy_weights_()

# Timers
# Training timer
training_time = time.time() - training_start
iteration_time = collection_time + training_time
self.total_time += iteration_time
self.logger.log(
{
"timers/collection_time": collection_time,
"timers/training_time": training_time,
"timers/iteration_time": iteration_time,
"timers/total_time": self.total_time,
"counters/current_frames": current_frames,
"counters/total_frames": self.total_frames,
"counters/iter": self.n_iters_performed,
},
step=self.n_iters_performed,
)

# Evaluation
if (
Expand All @@ -666,6 +652,20 @@ def _collection_loop(self):
self._evaluation_loop()

# End of step
iteration_time = time.time() - iteration_start
self.total_time += iteration_time
self.logger.log(
{
"timers/collection_time": collection_time,
"timers/training_time": training_time,
"timers/iteration_time": iteration_time,
"timers/total_time": self.total_time,
"counters/current_frames": current_frames,
"counters/total_frames": self.total_frames,
"counters/iter": self.n_iters_performed,
},
step=self.n_iters_performed,
)
self.n_iters_performed += 1
self.logger.commit()
if (
Expand All @@ -674,7 +674,6 @@ def _collection_loop(self):
):
self._save_experiment()
pbar.update()
sampling_start = time.time()

if self.config.checkpoint_at_end:
self._save_experiment()
Expand Down

0 comments on commit e16ffa0

Please sign in to comment.