Speedup bincount (#10308)
before:
```
  OP         Args                Lib   KT(GPU)    BW(GPU)                                          KT(1 CPU)   ET(1 CPU)   KT(32 CPU)   ET(32 CPU)
 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  bincount   ones(4096).int()    OF    4814.5     cast: , reduce_min: , reduce_max: , bincount:    685.9       771.7       692.0        772.1
  bincount   ones(65536).int()   OF    545205.9   cast: , reduce_min: , reduce_max: , bincount:    1458.3      1547.9      1508.3       1596.0
  bincount   ones(4096).int()    PT    183.2      -                                                18.3        21.0        21.8         24.9
  bincount   ones(65536).int()   PT    310.7      -                                                158.6       161.2       190.0        192.8
```

after:
```
  OP         Args                Lib   KT(GPU)   BW(GPU)                                          KT(1 CPU)   ET(1 CPU)   KT(32 CPU)   ET(32 CPU)
 ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  bincount   ones(4096).int()    OF    360.2     cast: , reduce_min: , reduce_max: , bincount:    617.1       688.0       624.6        694.3
  bincount   ones(65536).int()   OF    623.3     cast: , reduce_min: , reduce_max: , bincount:    1279.0      1345.1      1554.4       1633.6
  bincount   ones(4096).int()    PT    188.3     -                                                16.8        19.4        17.6         20.2
  bincount   ones(65536).int()   PT    317.5     -                                                158.9       161.4       188.1        191.3
```
marigoold authored Aug 2, 2023
1 parent 27ddd62 commit b664acc
Showing 2 changed files with 78 additions and 27 deletions.
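In short, the old OF path issued one global-memory atomic per input element, so inputs like `ones(65536)`, where every element lands in the same bin, contend on a single counter; per the tables above, the GPU kernel-time column (KT) drops from roughly 4814 to 360 for `ones(4096)` and from about 545206 to 623 for `ones(65536)`. The new kernel instead keeps a per-block histogram in shared memory and flushes it with one global atomic per bin per block whenever the number of bins fits inside a thread block, falling back to the global-memory path otherwise. Below is a minimal standalone sketch of the same idea; all names (`BincountSharedMem`, `kThreadsPerBlock`, etc.) are illustrative stand-ins, not the OneFlow identifiers in the diff.

```cuda
// bincount_sketch.cu -- standalone illustration only; names are made up for this sketch.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kThreadsPerBlock = 1024;  // stand-in for OneFlow's kCudaThreadsNumPerBlock

// Fallback (and the pre-change behavior): one global-memory atomic per input element.
__global__ void BincountGlobalMem(const int* in, int* out, int in_size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < in_size; i += blockDim.x * gridDim.x) {
    atomicAdd(&out[in[i]], 1);
  }
}

// Fast path, valid only when num_bins <= blockDim.x: each block accumulates into a
// shared-memory histogram, then flushes one global atomic per bin.
__global__ void BincountSharedMem(const int* in, int* out, int in_size, int num_bins) {
  __shared__ int block_hist[kThreadsPerBlock];
  block_hist[threadIdx.x] = 0;  // every thread zeroes its own slot
  __syncthreads();
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < in_size; i += blockDim.x * gridDim.x) {
    atomicAdd(&block_hist[in[i]], 1);  // shared-memory atomics are much cheaper under contention
  }
  __syncthreads();
  if (threadIdx.x < num_bins) { atomicAdd(&out[threadIdx.x], block_hist[threadIdx.x]); }
}

// Mirrors the dispatch rule in the diff: bin count stays an element count; bytes only for memset.
void Bincount(const int* in, int* out, int in_size, int num_bins, cudaStream_t stream) {
  cudaMemsetAsync(out, 0, num_bins * sizeof(int), stream);
  if (in_size == 0) { return; }
  const int blocks = (in_size + kThreadsPerBlock - 1) / kThreadsPerBlock;
  if (num_bins > kThreadsPerBlock) {
    BincountGlobalMem<<<blocks, kThreadsPerBlock, 0, stream>>>(in, out, in_size);
  } else {
    BincountSharedMem<<<blocks, kThreadsPerBlock, 0, stream>>>(in, out, in_size, num_bins);
  }
}

int main() {
  const int in_size = 65536, num_bins = 16;
  int *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, in_size * sizeof(int));
  cudaMallocManaged(&out, num_bins * sizeof(int));
  for (int i = 0; i < in_size; ++i) { in[i] = i % num_bins; }
  Bincount(in, out, in_size, num_bins, /*stream=*/0);
  cudaDeviceSynchronize();
  for (int b = 0; b < num_bins; ++b) { printf("bin %d: %d\n", b, out[b]); }  // expect 4096 each
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

Built with `nvcc bincount_sketch.cu`, this should print 4096 for each of the 16 bins; the shared-memory path is taken here because 16 <= 1024, just as the OneFlow kernel below takes it when `out_size <= kCudaThreadsNumPerBlock`.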
91 changes: 68 additions & 23 deletions oneflow/user/kernels/bincount_kernel.cu
@@ -23,25 +23,67 @@ namespace oneflow {
namespace user_op {
namespace {

// clang-format off
template<typename IDX, typename T>
__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr, int64_t size) {
CUDA_1D_KERNEL_LOOP(i, size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, weight[i]);
template<typename IDX, typename T, bool UseGlobalMem>
__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr,
int64_t in_size, int64_t out_size) {
if constexpr (UseGlobalMem) {
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, weight[i]);
}
} else {
__shared__ T shm[kCudaThreadsNumPerBlock];
T zero = GetZeroVal<T>();
shm[threadIdx.x] = zero;
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(shm + idx, weight[i]);
}
__syncthreads();
if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); }
}
};
// clang-format on

template<typename IDX, typename T>
__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t size) {
template<typename IDX, typename T, bool UseGlobalMem>
__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t in_size,
int64_t out_size) {
T one = GetOneVal<T>();
CUDA_1D_KERNEL_LOOP(i, size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, one);
if constexpr (UseGlobalMem) {
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, one);
}
} else {
__shared__ T shm[kCudaThreadsNumPerBlock];
T zero = GetZeroVal<T>();
shm[threadIdx.x] = zero;
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(shm + idx, one);
}
__syncthreads();
if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); }
}
};

template<typename IDX, typename T, bool UseGlobalMem>
static void BinCountDispatch(user_op::KernelComputeContext* ctx, const IDX* in_ptr,
const T* weight_ptr, T* out_ptr, int64_t in_size, int64_t out_size) {
if (weight_ptr) {
BinCountCompute<IDX, T, UseGlobalMem>
<<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, weight_ptr, out_ptr,
in_size, out_size);
} else {
BinCountCompute<IDX, T, UseGlobalMem>
<<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, out_ptr, in_size,
out_size);
}
}
template<typename IDX, typename T>
class CUDABinCountKernel final : public user_op::OpKernel {
public:
@@ -52,31 +94,34 @@ class CUDABinCountKernel final : public user_op::OpKernel {
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
size_t out_size = ctx->Attr<int64_t>("size") * sizeof(T);
size_t out_size = ctx->Attr<int64_t>("size");
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const IDX* in_ptr = in->dptr<IDX>();
T* out_ptr = out->mut_dptr<T>();
std::unique_ptr<ep::primitive::Memset> memset_primitive =
ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());
CHECK(memset_primitive);
memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size);
int64_t in_size = in->shape_view().elem_cnt();
memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size * sizeof(T));
const int64_t in_size = in->shape_view().elem_cnt();
if (in_size == 0) { return; }
const T* weight_ptr = nullptr;
if (ctx->has_input("weight", 0)) {
const T* weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr<T>();
BinCountCompute<IDX, T><<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
in_ptr, weight_ptr, out_ptr, in_size);
weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr<T>();
};
if (out_size > kCudaThreadsNumPerBlock) {
BinCountDispatch<IDX, T, true>(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size);
} else {
BinCountCompute<IDX, T>
<<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, out_ptr, in_size);
BinCountDispatch<IDX, T, false>(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size);
}
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
} // namespace oneflow
} // namespace
#define REGISTER_CUDA_BINCOUNT_KERNEL(idx_type, dtype) \
REGISTER_USER_KERNEL("bincount") \
14 changes: 10 additions & 4 deletions python/oneflow/test/modules/test_bincount.py
@@ -26,33 +26,39 @@ class TestBinCount(flow.unittest.TestCase):
@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount(test_case):
device = random_device()
x = random_tensor(1, 100, low=0, dtype=int).to(device)
x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)
result = torch.bincount(x)
return result

@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount_weight(test_case):
device = random_device()
x = random_tensor(1, 100, low=0, dtype=int).to(device)
x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)
weight = random_tensor(1, 100).to(device)
return torch.bincount(x, weights=weight)

@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount_minlength(test_case):
device = random_device()
x = random_tensor(1, 100, low=0, dtype=int).to(device)
x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)
weight = random_tensor(1, 100).to(device)
minlength = random(1, 200).to(int)
return torch.bincount(x, weights=weight, minlength=minlength)

@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount_0element(test_case):
device = random_device()
x = random_tensor(1, 0, low=0, dtype=int).to(device)
x = random_tensor(1, 0, low=0, high=65536, dtype=int).to(device)
weight = random_tensor(1, 0).to(device)
minlength = random(1, 200).to(int)
return torch.bincount(x, weights=weight, minlength=minlength)

@profile(torch.bincount)
def profile_bincount(test_case):
torch.bincount(torch.ones(4096).int())
torch.bincount(torch.ones(65536).int())
torch.bincount(torch.arange(4096).int())


if __name__ == "__main__":
unittest.main()
