Skip to content

Commit

Permalink
Log the saturation percentage of the output
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris committed Dec 14, 2023
1 parent 396c11b commit 54482e4
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/skelcast/experiments/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def training_step(self, train_batch: NTURGBDSample):
self.model.train()
out = self.model.training_step(x, y, mask) # TODO: Make the other models accept a mask as well
loss = out['loss']
outputs = out['out']
# Calculate the saturation of the tanh output
saturated = (outputs.abs() > 0.95)
saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
self.optimizer.zero_grad()
loss.backward()
if self.log_gradient_info:
Expand All @@ -205,6 +209,9 @@ def training_step(self, train_batch: NTURGBDSample):
for name, ratio in self.model.gradient_update_ratios.items():
self.logger.add_scalar(tag=f'gradient/{name}_grad_update_norm_ratio', scalar_value=ratio, global_step=len(self.training_loss_per_step))

if self.logger is not None:
self.logger.add_scalar(tag='train/saturation', scalar_value=saturation_percentage.mean().item(), global_step=len(self.training_loss_per_step))

self.optimizer.step()
# Print the loss
self.training_loss_per_step.append(loss.item())
Expand Down

0 comments on commit 54482e4

Please sign in to comment.