Skip to content

Commit

Permalink
Compute the dead neurons
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris committed Dec 21, 2023
1 parent 9408041 commit 63f939f
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/skelcast/experiments/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def training_step(self, train_batch: NTURGBDSample):
# Calculate the saturation of the tanh output
saturated = (outputs.abs() > 0.95)
saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
# Calculate the dead neurons
dead_neurons = (outputs.abs() < 0.05)
dead_neurons_percentage = dead_neurons.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
self.optimizer.zero_grad()
loss.backward()
if self.log_gradient_info:
Expand All @@ -211,6 +214,7 @@ def training_step(self, train_batch: NTURGBDSample):

if self.logger is not None:
self.logger.add_scalar(tag='train/saturation', scalar_value=saturation_percentage.mean().item(), global_step=len(self.training_loss_per_step))
self.logger.add_scalar(tag='train/dead_neurons', scalar_value=dead_neurons_percentage.mean().item(), global_step=len(self.training_loss_per_step))

self.optimizer.step()
# Print the loss
Expand Down

0 comments on commit 63f939f

Please sign in to comment.