diff --git a/oneflow/user/kernels/bincount_kernel.cu b/oneflow/user/kernels/bincount_kernel.cu index 5e6c3910237..ce80849e68c 100644 --- a/oneflow/user/kernels/bincount_kernel.cu +++ b/oneflow/user/kernels/bincount_kernel.cu @@ -23,25 +23,67 @@ namespace oneflow { namespace user_op { namespace { -// clang-format off -template -__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr, int64_t size) { - CUDA_1D_KERNEL_LOOP(i, size) { - IDX idx = *(in_ptr + i); - cuda::atomic::Add(out_ptr + idx, weight[i]); +template +__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr, + int64_t in_size, int64_t out_size) { + if constexpr (UseGlobalMem) { + CUDA_1D_KERNEL_LOOP(i, in_size) { + IDX idx = *(in_ptr + i); + cuda::atomic::Add(out_ptr + idx, weight[i]); + } + } else { + __shared__ T shm[kCudaThreadsNumPerBlock]; + T zero = GetZeroVal(); + shm[threadIdx.x] = zero; + __syncthreads(); + CUDA_1D_KERNEL_LOOP(i, in_size) { + IDX idx = *(in_ptr + i); + cuda::atomic::Add(shm + idx, weight[i]); + } + __syncthreads(); + if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); } } }; -// clang-format on -template -__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t size) { +template +__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t in_size, + int64_t out_size) { T one = GetOneVal(); - CUDA_1D_KERNEL_LOOP(i, size) { - IDX idx = *(in_ptr + i); - cuda::atomic::Add(out_ptr + idx, one); + if constexpr (UseGlobalMem) { + CUDA_1D_KERNEL_LOOP(i, in_size) { + IDX idx = *(in_ptr + i); + cuda::atomic::Add(out_ptr + idx, one); + } + } else { + __shared__ T shm[kCudaThreadsNumPerBlock]; + T zero = GetZeroVal(); + shm[threadIdx.x] = zero; + __syncthreads(); + CUDA_1D_KERNEL_LOOP(i, in_size) { + IDX idx = *(in_ptr + i); + cuda::atomic::Add(shm + idx, one); + } + __syncthreads(); + if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); } } }; +template +static void BinCountDispatch(user_op::KernelComputeContext* ctx, const IDX* in_ptr, + const T* weight_ptr, T* out_ptr, int64_t in_size, int64_t out_size) { + if (weight_ptr) { + BinCountCompute + <<stream()->As()->cuda_stream()>>>(in_ptr, weight_ptr, out_ptr, + in_size, out_size); + } else { + BinCountCompute + <<stream()->As()->cuda_stream()>>>(in_ptr, out_ptr, in_size, + out_size); + } +} + template class CUDABinCountKernel final : public user_op::OpKernel { public: @@ -52,31 +94,34 @@ class CUDABinCountKernel final : public user_op::OpKernel { using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); - size_t out_size = ctx->Attr("size") * sizeof(T); + size_t out_size = ctx->Attr("size"); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const IDX* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); + std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); - memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size); - int64_t in_size = in->shape_view().elem_cnt(); + memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size * sizeof(T)); + + const int64_t in_size = in->shape_view().elem_cnt(); if (in_size == 0) { return; } + + const T* weight_ptr = nullptr; if (ctx->has_input("weight", 0)) { - const T* weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); - BinCountCompute<<stream()->As()->cuda_stream()>>>( - in_ptr, weight_ptr, out_ptr, in_size); + weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); + }; + + if (out_size > kCudaThreadsNumPerBlock) { + BinCountDispatch(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size); } else { - BinCountCompute - <<stream()->As()->cuda_stream()>>>(in_ptr, out_ptr, in_size); + BinCountDispatch(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -} // namespace oneflow +} // namespace #define REGISTER_CUDA_BINCOUNT_KERNEL(idx_type, dtype) \ REGISTER_USER_KERNEL("bincount") \ diff --git a/python/oneflow/test/modules/test_bincount.py b/python/oneflow/test/modules/test_bincount.py index 1066ba94b7f..dc9134280f2 100644 --- a/python/oneflow/test/modules/test_bincount.py +++ b/python/oneflow/test/modules/test_bincount.py @@ -26,21 +26,21 @@ class TestBinCount(flow.unittest.TestCase): @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount(test_case): device = random_device() - x = random_tensor(1, 100, low=0, dtype=int).to(device) + x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device) result = torch.bincount(x) return result @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount_weight(test_case): device = random_device() - x = random_tensor(1, 100, low=0, dtype=int).to(device) + x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device) weight = random_tensor(1, 100).to(device) return torch.bincount(x, weights=weight) @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount_minlength(test_case): device = random_device() - x = random_tensor(1, 100, low=0, dtype=int).to(device) + x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device) weight = random_tensor(1, 100).to(device) minlength = random(1, 200).to(int) return torch.bincount(x, weights=weight, minlength=minlength) @@ -48,11 +48,17 @@ def test_bincount_minlength(test_case): @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount_0element(test_case): device = random_device() - x = random_tensor(1, 0, low=0, dtype=int).to(device) + x = random_tensor(1, 0, low=0, high=65536, dtype=int).to(device) weight = random_tensor(1, 0).to(device) minlength = random(1, 200).to(int) return torch.bincount(x, weights=weight, minlength=minlength) + @profile(torch.bincount) + def profile_bincount(test_case): + torch.bincount(torch.ones(4096).int()) + torch.bincount(torch.ones(65536).int()) + torch.bincount(torch.arange(4096).int()) + if __name__ == "__main__": unittest.main()