diff --git a/spconv/pytorch/pool.py b/spconv/pytorch/pool.py index 2384507..ce3f64d 100644 --- a/spconv/pytorch/pool.py +++ b/spconv/pytorch/pool.py @@ -270,7 +270,7 @@ def forward(self, input: spconv.SparseConvTensor): real_inds = out_indices[i, :counts_cpu_np[i]] real_features = input.features[real_inds] if self.is_mean: - real_features_reduced = torch.mean(real_features, dim=0)[0] + real_features_reduced = torch.mean(real_features, dim=0) else: real_features_reduced = torch.max(real_features, dim=0)[0] res_features_list.append(real_features_reduced)