Use faster kernel 5 for layernorm forward
gordicaleksa committed Jun 16, 2024
1 parent 0e69e3a commit b966d9d
Showing 1 changed file with 49 additions and 33 deletions.
82 changes: 49 additions & 33 deletions llmc/layernorm.cuh
@@ -17,50 +17,65 @@ E.g., the layernorms are connected to the residuals so we += in layernorm backward
 // ----------------------------------------------------------------------------
 // CUDA kernels
 
-__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
+__global__ void layernorm_forward_kernel5(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
                                           const floatX* __restrict__ inp, const floatX* __restrict__ weight,
                                           const floatX* __restrict__ bias, int N, int C) {
-    int lane_id = threadIdx.x % WARP_SIZE;
-    int warp_id = threadIdx.x / WARP_SIZE;
-    int num_warps = blockDim.x / WARP_SIZE;
-
-    int idx = blockIdx.x * num_warps + warp_id;
-    if(idx >= N) { return; } // guard
-
-    // the row of input that this group of threads is responsible for
-    const floatX* x = inp + idx * C;
-
-    // 2 warp level reductions instead of block level in order to prevent 2 syncthreads
-    // mean
-    float sum = 0.0f;
-    for (int i = lane_id; i < C; i += WARP_SIZE) {
-        sum += (float)x[i];
-    }
-    sum = warpReduceSum(sum);
-    float m = sum / C;
-    if(lane_id == 0 && mean != nullptr) {
+    __shared__ float shared_sum[WARP_SIZE]; // block_size max is 1024 = 32 * 32 warps
+    __shared__ float shared_sum2[WARP_SIZE]; // warps will be writing into shared memory after warp-reduce
+
+    int num_warps = blockDim.x / WARP_SIZE;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+
+    int idx = blockIdx.x; // simply one block per row
+    const floatX* x = inp + idx * C; // the row of input that this group of threads is responsible for
+
+    // thread coarsening through the row, reduce the sum in series
+    float thread_sum = 0.0; // stores sum(x)
+    float thread_sum2 = 0.0; // stores sum(x**2)
+    for (int i = threadIdx.x; i < C; i += blockDim.x) {
+        float xi = x[i];
+        thread_sum += xi;
+        thread_sum2 += xi * xi;
+    }
+
+    // warp-level reduction
+    float warp_sum = warpReduceSum(thread_sum); // sum(x)
+    float warp_sum2 = warpReduceSum(thread_sum2); // sum(x**2)
+    // store the warp-level reduction in shared memory
+    if (lane_id == 0) {
+        shared_sum[warp_id] = warp_sum;
+        shared_sum2[warp_id] = warp_sum2;
+    }
+    __syncthreads();
+    // load results from shared memory to threads, pad with zeros for threads that are out of bounds
+    warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f;
+    warp_sum2 = (lane_id < num_warps) ? shared_sum2[lane_id] : 0.0f;
+    // now reduce the warp-level reductions
+    float block_sum = warpReduceSum(warp_sum); // sum(x)
+    float block_sum2 = warpReduceSum(warp_sum2); // sum(x**2)
+
+    // mean, var, rstd
+    block_sum /= C; // mean(x)
+    block_sum2 /= C; // mean(x**2)
+    float m = block_sum;
+    float var = block_sum2 - m * m;
+    float s = rsqrtf(var + 1e-5f);
+
+    if (threadIdx.x == 0 && mean != nullptr) {
         __stcs(mean + idx, (floatX)m);
     }
-
-    // rstd
-    sum = 0.0f;
-    for (int i = lane_id; i < C; i += WARP_SIZE) {
-        float diff = (float)x[i] - m;
-        sum += diff * diff;
-    }
-    sum = warpReduceSum(sum);
-    float s = rsqrtf(sum / C + 1e-5f);
-    if(lane_id == 0 && rstd != nullptr) {
+    if (threadIdx.x == 0 && rstd != nullptr) {
         __stcs(rstd + idx, (floatX)s);
     }
 
     // final normalization and scaling by weight/bias
     floatX* o = out + idx * C;
-    for (int c = lane_id; c < C; c += WARP_SIZE) {
+    for (int i = threadIdx.x; i < C; i += blockDim.x) {
         // load and store using the .cs "streaming" hint to the compiler,
         // indicating that this data will not be reused soon, and can be streamed through the caches
         // this allows the threads to get more cache-hits for the (shared) weight and bias parameters
-        float n = s * ((float)__ldcs(x+c) - m);
-        __stcs(o+c, (floatX)(n * (float)weight[c] + (float)bias[c]));
+        float n = s * ((float)__ldcs(x+i) - m);
+        __stcs(o+i, (floatX)(n * (float)weight[i] + (float)bias[i]));
     }
 }
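
Both kernels rely on warpReduceSum, which is defined elsewhere in the repo rather than in this file; below is a minimal sketch of what such a shuffle-based warp reduction typically looks like (an assumption about its shape, not the repo's exact code). Note also that kernel 5 collapses kernel 3's two passes over the row into one by accumulating sum(x) and sum(x**2) together and using var(x) = mean(x**2) - mean(x)^2 before the rsqrtf.

// Sketch of a shuffle-based warp reduction (assumed helper; the real one lives elsewhere in llmc).
// After the loop, every lane of the warp holds the sum of all 32 lanes' inputs.
__device__ inline float warpReduceSum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);
    }
    return val;
}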

@@ -358,10 +373,11 @@ void layernorm_forward(floatX* out, floatX* mean, floatX* rstd,
                        floatX* inp, const floatX* weight, const floatX* bias,
                        int B, int T, int C, cudaStream_t stream) {
     NVTX_RANGE_FN();
-    const int block_size = 512;
+    const int block_size = 256;
+    assert(block_size % WARP_SIZE == 0 && block_size <= 1024);
     const int N = B * T;
-    const int grid_size = CEIL_DIV(N * WARP_SIZE, block_size);
-    layernorm_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, mean, rstd, inp, weight, bias, N, C);
+    const int grid_size = N;
+    layernorm_forward_kernel5<<<grid_size, block_size, 0, stream>>>(out, mean, rstd, inp, weight, bias, N, C);
     cudaCheck(cudaGetLastError());
 }
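
The launch now assigns one block to each of the N = B*T rows instead of one warp per row. As a rough illustration of the resulting shapes, with example sizes that are not taken from the commit:

// Illustrative launch-shape arithmetic (example sizes: B=8, T=1024, C=768, block_size=256).
// N                = B * T           = 8192 -> grid_size = 8192 blocks, one block per row
// num_warps        = 256 / WARP_SIZE = 8    -> 8 partial sums per row, combined via shared memory
// elems_per_thread = C / 256         = 3    -> each thread accumulates 3 row elements in series
// kernel 3, for comparison: grid_size = CEIL_DIV(8192 * 32, 512) = 512 blocks, one warp per row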

