diff --git a/statistics_util/statistic_plots.py b/statistics_util/statistic_plots.py index f02132f9da..87ad379ea1 100644 --- a/statistics_util/statistic_plots.py +++ b/statistics_util/statistic_plots.py @@ -190,11 +190,11 @@ def create_statistics(self, graph_y_labels): for i, i_head in enumerate(i_first_batch): ## Flatten across heads, height, and width - flattened = i_head.view(-1) + i_head = i_head[torch.tril(torch.ones_like(i_head)) == 1] ## Calculate statistics - i_means.append(torch.nanmean(flattened).item()) - i_medians.append(torch.nanmedian(flattened).item()) + i_means.append(torch.nanmean(i_head).item()) + i_medians.append(torch.nanmedian(i_head).item()) # Standard deviation, ignoring NaNs mask = ~torch.isnan(i_head) @@ -213,8 +213,8 @@ def create_statistics(self, graph_y_labels): denominator.append(sum.item()) ## Append statistic to the input list of each head in each layer - self.stats['mean'][layer][i].append(torch.nanmean(flattened).item()) - self.stats['median'][layer][i].append(torch.nanmedian(flattened).item()) + self.stats['mean'][layer][i].append(torch.nanmean(i_head).item()) + self.stats['median'][layer][i].append(torch.nanmedian(i_head).item()) self.stats['stdev'][layer][i].append(torch.std(i_head[mask]).item()) self.stats['max'][layer][i].append(torch.max(torch.where(torch.isnan(i_head), torch.tensor(float('-inf')), i_head)).item()) self.stats['min'][layer][i].append(torch.min(torch.where(torch.isnan(i_head), torch.tensor(float('inf')), i_head)).item()) @@ -228,12 +228,12 @@ def create_statistics(self, graph_y_labels): for i, o_head in enumerate(o_first_batch): # Step 3: Flatten across heads, height, and width - flattened = o_head.view(-1) + o_head = o_head[torch.tril(torch.ones_like(o_head)) == 1] # Step 4: Calculate statistics ## Calculate statistics - o_means.append(torch.nanmean(flattened).item()) - o_medians.append(torch.nanmedian(flattened).item()) + o_means.append(torch.nanmean(o_head).item()) + o_medians.append(torch.nanmedian(o_head).item()) # Standard deviation, ignoring NaNs mask = ~torch.isnan(o_head) o_stdevs.append(torch.std(o_head[mask]).item()) @@ -247,8 +247,8 @@ def create_statistics(self, graph_y_labels): o_min_values.append(torch.min(torch.where(torch.isnan(o_head), torch.tensor(float('inf')), o_head)).item()) # Append statistic to the output list of each head in each layer - self.stats['o_mean'][layer][i].append(torch.nanmean(flattened).item()) - self.stats['o_median'][layer][i].append(torch.nanmedian(flattened).item()) + self.stats['o_mean'][layer][i].append(torch.nanmean(o_head).item()) + self.stats['o_median'][layer][i].append(torch.nanmedian(o_head).item()) self.stats['o_stdev'][layer][i].append(torch.std(o_head[mask]).item()) self.stats['o_max'][layer][i].append(torch.max(torch.where(torch.isnan(o_head), torch.tensor(float('-inf')), o_head)).item()) self.stats['o_min'][layer][i].append(torch.min(torch.where(torch.isnan(o_head), torch.tensor(float('inf')), o_head)).item())