Add global norm prod kernels to dev
gordicaleksa committed Jun 15, 2024
1 parent 2bc0b47 commit 0f6853d
Showing 2 changed files with 117 additions and 4 deletions.
27 changes: 27 additions & 0 deletions dev/cuda/common.h
@@ -5,6 +5,7 @@
#include <cublasLt.h>
#include <float.h>

#define WARP_SIZE 32U

template<class T>
__host__ __device__ T ceil_div(T dividend, T divisor) {
@@ -18,6 +19,32 @@ __device__ float warpReduceSum(float val) {
return val;
}
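For context, the blockReduce template added below plugs in warpReduceSum, whose body is elided in this hunk (only the signature in the hunk header and the closing lines survive). A minimal sketch of what such a shuffle-based warp reduction typically looks like, given as an illustration rather than the exact code in common.h:

__device__ float warpReduceSumSketch(float val) {
    // butterfly reduction: after log2(32) = 5 xor-shuffle steps every lane holds the warp total,
    // which is what lets blockReduce read a valid sum from lane 0 of each warp
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);
    }
    return val;
}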

// requires all 32 threads in the warp to be active, but should work for any block size
// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes
// because each call site gets its own shared memory array, we can avoid an extra __syncthreads() call at the end
// but if called inside a loop, that shared memory is implicitly reused across iterations, so set final_sync to true
using reduction_func_t = float (*) (float);
template<reduction_func_t warp_reduction>
__device__ inline float blockReduce(float val, bool final_sync=false, float out_of_bounds=0.0f) {
// reduces up to 1024 threads in three steps:
// 1) reduce inside each warp (shuffle), 2) gather the warp results across warps (shared memory), 3) reduce those inside one warp (shuffle)
__shared__ float shared_val[WARP_SIZE];
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;

float warp_val = warp_reduction(val);
if (lane_id == 0) { shared_val[warp_id] = warp_val; }
__syncthreads();
warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;
float block_val = warp_reduction(warp_val);

if (final_sync) {
__syncthreads(); // only needed in loops when effectively reusing shared memory etc.
}
return block_val;
}
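To make the final_sync comment above concrete: when blockReduce is called repeatedly inside a loop, the same shared_val array is implicitly reused on the next iteration, so the trailing __syncthreads() is what keeps a fast warp from overwriting values another warp is still reading. A hypothetical usage sketch (num_steps, compute_partial and results are placeholders, not part of this commit):

// hypothetical sketch of calling blockReduce in a loop with final_sync enabled
for (int step = 0; step < num_steps; step++) {
    float partial = compute_partial(step);                    // placeholder per-iteration work
    float total = blockReduce<warpReduceSum>(partial, true);  // final_sync=true guards the shared_val reuse
    if (threadIdx.x == 0) { results[step] = total; }
}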

// ----------------------------------------------------------------------------
// checking utils

94 changes: 90 additions & 4 deletions dev/cuda/global_norm.cu
@@ -89,6 +89,54 @@ __global__ void norm_kernel2(float* out, const T* data, size_t count) {
}
}

template<class T>
__global__ void norm_kernel3(float* out, const T* data, size_t count) {
size_t index = blockIdx.x * blockDim.x + threadIdx.x;
size_t grid_width = blockDim.x * gridDim.x;
float accumulator = 0.f;
for(size_t i = index; i < count; i += grid_width) {
accumulator += (float)data[i] * (float)data[i];
}
// block-level reduce
float block_sum = blockReduce<warpReduceSum>(accumulator);
if(threadIdx.x == 0) {
atomicAdd(out, block_sum);
}
}

// Same as kernel3 but without atomic adds -> since floating point addition is not associative, the
// unordered atomic adds would make the result non-deterministic; avoiding them gives us determinism.
// Roughly the same performance as kernel3.
template<class T>
__global__ void norm_kernel4(float* out, const T* data, size_t count) {
size_t index = blockIdx.x * blockDim.x + threadIdx.x;
size_t grid_width = blockDim.x * gridDim.x;
float accumulator = 0.f;
for(size_t i = index; i < count; i += grid_width) {
accumulator += (float)data[i] * (float)data[i];
}
// block-level reduce
float block_sum = blockReduce<warpReduceSum>(accumulator);
// each block writes its partial sum to out[blockIdx.x]
// we want to avoid using atomic add here so we combine this kernel with the aggregate kernel call
// that sums up the partial block sums
if(threadIdx.x == 0) {
out[blockIdx.x] = block_sum;
}
}

__global__ void global_norm_aggregate_kernel(float* out, size_t count) {
size_t index = threadIdx.x;
// grab block sums from the previous kernel, use 0. as the neutral sum element
float block_sum = (index < count) ? out[index] : 0.f;
float sum = blockReduce<warpReduceSum>(block_sum);
if(threadIdx.x == 0) {
out[0] = sum; // out[0] ends up with the final norm squared
}
}
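Note that all of these kernels accumulate the squared L2 norm, so out[0] ends up holding norm^2 rather than the norm itself. A host-side sketch of how a caller might recover the actual global norm (d_out follows the naming of the test harness below and is an illustration, not part of this commit):

// copy back the squared norm and take the square root on the host
float norm_squared;
cudaCheck(cudaMemcpy(&norm_squared, d_out, sizeof(float), cudaMemcpyDeviceToHost));
float global_norm_value = sqrtf(norm_squared);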

// ----------------------------------------------------------------------------
// kernel launchers

template<typename T>
void global_norm1(float* out, const T* values, size_t count, int block_size) {
// launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
@@ -111,17 +159,55 @@ void global_norm2(float* out, const T* values, size_t count, int block_size) {
cudaCheck(cudaGetLastError());
}

void global_norm(int kernel_num, float* out, const floatX* values, size_t count, int block_size) {
template<typename T>
void global_norm3(float* out, const T* values, size_t count, int block_size, cudaDeviceProp deviceProp) {
// launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
// having one block less than possible is a tiny performance hit, having
// one block too many is catastrophic, since it can only start once all the other
// blocks finish. anyway, I think maxThreadsPerMultiProcessor should be a multiple of 512
// on all GPUs, so the division really is going to be exact.
const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
assert(grid_size > 0); // gives a better error than letting the call below fail
norm_kernel3<<<grid_size, block_size>>>(out, values, count);
cudaCheck(cudaGetLastError());
}
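As a worked example of the sizing above (A100 numbers, chosen here as an illustration rather than taken from the commit): maxThreadsPerMultiProcessor = 2048 and multiProcessorCount = 108, so block_size = 512 gives grid_size = 2048 * 108 / 512 = 432 blocks, which exactly fills the GPU; and since 2048 is a multiple of 512, the division is indeed exact.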

template<typename T>
void global_norm4(float* out, const T* values, size_t count, int block_size, cudaDeviceProp deviceProp) {
if (block_size <= 64) {
block_size = 128; // to avoid triggering the assert below
}
// launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
// having one block less than possible is a tiny performance hit, having
// one block too many is catastrophic, since it can only start once all the other
// blocks finish. anyway, I think maxThreadsPerMultiProcessor should be a multiple of 512
// on all GPUs, so the division really is going to be exact.
const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
assert(grid_size > 0); // gives a better error than letting the call below fail
assert(grid_size < 1024); // we want to later accumulate the block sums in a single block
norm_kernel4<<<grid_size, block_size>>>(out, values, count);
cudaCheck(cudaGetLastError());
global_norm_aggregate_kernel<<<1, 1024>>>(out, grid_size);
cudaCheck(cudaGetLastError());
}

void global_norm(int kernel_num, float* out, const floatX* values, size_t count, int block_size, cudaDeviceProp deviceProp) {
switch (kernel_num) {
case 1:
return global_norm1(out, values, count, block_size);
case 2:
return global_norm2(out, values, count, block_size);
case 3:
return global_norm3(out, values, count, block_size, deviceProp);
case 4:
return global_norm4(out, values, count, block_size, deviceProp);
}
}

int main(int argc, const char **argv) {
setup_main();
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);

int C = 768;
int L = 12;
@@ -148,7 +234,7 @@ int main(int argc, const char **argv) {
// move to GPU
float* d_out;
floatX* d_inp;
cudaCheck(cudaMalloc(&d_out, sizeof(float)));
cudaCheck(cudaMalloc(&d_out, 1024 * sizeof(float))); // 1024 needed for kernel 4
cudaCheck(cudaMalloc(&d_inp, num_params * sizeof(floatX)));
cudaCheck(memcpy_convert(d_inp, inp, num_params));

@@ -157,7 +243,7 @@
int block_size = block_sizes[j];
printf("Checking block size %d.\n", block_size);
cudaCheck(cudaMemset(d_out, 0, sizeof(float)));
global_norm(kernel_num, d_out, d_inp, num_params, block_size);
global_norm(kernel_num, d_out, d_inp, num_params, block_size, deviceProp);
validate_result(d_out, &out, "out", 1, 1e-2f);
}

@@ -170,7 +256,7 @@

float elapsed_time = benchmark_kernel(repeat_times, global_norm,
kernel_num, d_out, d_inp,
num_params, block_size);
num_params, block_size, deviceProp);
size_t memory_ops = num_params * sizeof(floatX);
float memory_bandwidth = memory_ops / elapsed_time / 1e6;
