From 9ae307d259a106f6cf23f5c7203986738dd71d76 Mon Sep 17 00:00:00 2001 From: x1y9 Date: Mon, 8 Jul 2024 10:30:55 +0800 Subject: [PATCH] fix SparseGlobalAvgPool --- spconv/pytorch/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spconv/pytorch/pool.py b/spconv/pytorch/pool.py index 2384507..ce3f64d 100644 --- a/spconv/pytorch/pool.py +++ b/spconv/pytorch/pool.py @@ -270,7 +270,7 @@ def forward(self, input: spconv.SparseConvTensor): real_inds = out_indices[i, :counts_cpu_np[i]] real_features = input.features[real_inds] if self.is_mean: - real_features_reduced = torch.mean(real_features, dim=0)[0] + real_features_reduced = torch.mean(real_features, dim=0) else: real_features_reduced = torch.max(real_features, dim=0)[0] res_features_list.append(real_features_reduced)